loj2554题解

首先注意到给定的区间之间要么不相交,要么嵌套,否则肯定不合法,而且一定有一个覆盖整个序列的区间. 据此我们可以发现区间之间嵌套的关系形成了一个树结构,树中每个点是一个极大的连续的区间. 因此我们只需求出$f _i$表示长为$i+1$的,不存在任何不包含最后一位的连续区间的,排列的个数.

假设$a$是$1,2,\ldots,n$的一个排列,令$p _i$表示$i$在$a$中的位置,易知$a$中一个连续的区间对应着$p$中一个连续的区间,因此$a$满足上一段所述条件等价于$p$中不存在一个不包含最大值的连续区间.

考虑跟排列有关的递推我只知道有两种思路,把$n+1$插入或者在序列某位放一个$i$,并把前面所有$\ge i$的数$+1$. 我一开始想的是后一种思路,结果没想出来. 考虑前一种思路,因为$p$中不存在一个不包含最大值的连续区间,为方便起见插入$1$而不是插入$n+1$,接下来要讨论两种情况. 为方便起见,下面先假设$n\ge 2$.

如果插入前的序列是合法的,那么插入前的序列有$f _{n-1}$种可能,而插入的位置可以是除了插入前最小值旁边的任何位置,共$n-1$种,因此这种情况的贡献为$(n-1)f _{n-1}$.

如果插入前序列不合法,那么$1$一定是放在了一个连续区间中间,设它所在的极长连续区间长度为$i$,那么这段区间插入$1$后不连续等价于它插入$1$,离散化后不存在不包含最小值的连续区间,等价于不存在不包含最d大值的连续区间,所以这个连续区间插入$1$之后共有$f _i$种可能. 把这段区间缩成一个数,然后把得到的序列离散化,剩下的序列长度为$n-i+1$,它也不能有不包含最小值的连续区间,共$f _{n-i}$种可能. 注意到插入$1$的那个连续区间在离散化之后的权值不能为最小值也不能为最大值,因而有$n-i-1$种可能. 综上,这一部分的贡献为

$$\sum _{i=2}^{n-1}f _if _{n-i}(n-i-1)=\sum _{i=2}^{n-2}(i-1)f _if _{n-i}$$

于是我们可以列出递推式

$$\begin{aligned}f _n=&\sum _{i=2}^{n-2}(i-1)f _if _{n-i}+(n-1)f _{n-1}\\=&\sum _{i=2}^{n-1}(i-1)f _if _{n-i}-(n-3)f _{n-1}\end{aligned}$$

把下标范围弄得对称一些,并加入一些修改使其对$\forall n\in \mathbb N$都成立

$$f _n=\sum _{i=1}^{n-1}(i-1)f _if _{n-i}-(n-3)f _{n-1}+[n=0]$$

这个东西就可以用分治来计算了.

这个卷积是自己卷自己,但也是可以分治的. 对于每个$(j-1)f _jf _k$,我们在$j,k$中较大一项被计算出来的时候,统计它的贡献. 把分治的长度补全为$2$的幂,假设当前的分治区间是$[l,r)$,$m=\left\lfloor\frac{l+r}{2}\right\rfloor$,统计$[l,m)$中的$f$的贡献,分两种情况讨论:

  • $l\neq 0$. 对于$[l,m)$中的每个$i$,要取一个$j$使得它们对$[m,r)$有贡献,显然$j\in[0,r-l)$,而此时必有$r-l\le l$,直接把$[l,m)$中的元素和$[0,r-l)$中的元素卷起来统计贡献就可以了.
  • $l=0$. 直接$[l,m)$中的元素卷上$[l,m)$中的元素.

第一种情况实现精细的话可以用5次而不是6次ntt,说不定还可以更少.

这种题卷积的下标范围一定要想办法弄对称,这样细节处理上会方便很多.

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ele int
#define ll long long
using namespace std;
#define maxn (1<<17)
#define MOD 998244353
#define g 3
ele n,a[maxn],f[maxn];
inline ele& add(ele&a,ele b){
return a=(a+b>=MOD?a+b-MOD:a+b);
}
inline ele pw(ele a,ele x){
ele ans=1,tmp=a%MOD;
for (; x; x>>=1,tmp=(ll)tmp*tmp%MOD)
if (x&1) ans=(ll)ans*tmp%MOD;
return ans;
}
inline void ntt(ele K,ele n,ele *y){
static ele f[maxn];
f[0]=0;
for (int i=1; i<n; ++i){
f[i]=f[i>>1]>>1;
if (i&1) f[i]+=n>>1;
if (i<f[i]) swap(y[i],y[f[i]]);
}
for (int p=1; p<n; p<<=1){
ele o=pw(g,(MOD-1)/p/2);
o=~K?o:pw(o,MOD-2);
for (int i=0; i<n; i+=(p<<1)){
ele o1=1;
for (int j=i; j<i+p; ++j,o1=(ll)o1*o%MOD){
ele u=y[j],v=(ll)y[j+p]*o1%MOD;
y[j]=(u+v)%MOD;
y[j+p]=(u-v+MOD)%MOD;
}
}
}
if (!~K){
ele invn=pw(n,MOD-2);
for (int i=0; i<n; ++i) y[i]=(ll)y[i]*invn%MOD;
}
}
void solve(ele *f,ele l,ele r){
if (r-l<=1) return;
ele mid=(l+r)>>1,tmp=(r-l)<<1;
solve(f,l,mid);
static ele t1[maxn],t2[maxn],t3[maxn],t4[maxn];
if (l){
memset(t1,0,sizeof(ele)*tmp);
for (int i=l; i<mid; ++i) t1[i-l]=(ll)f[i]*(i-1)%MOD;
memset(t2,0,sizeof(ele)*tmp);
for (int i=1; i<r-l; ++i) t2[i]=f[i];
memset(t3,0,sizeof(ele)*tmp);
for (int i=2; i<r-l; ++i) t3[i]=(ll)f[i]*(i-1)%MOD;
memset(t4,0,sizeof(ele)*tmp);
for (int i=l; i<mid; ++i) t4[i-l]=f[i];
ntt(1,tmp,t1); ntt(1,tmp,t2); ntt(1,tmp,t3); ntt(1,tmp,t4);
for (int i=0; i<tmp; ++i)
t1[i]=((ll)t1[i]*t2[i]%MOD+(ll)t3[i]*t4[i]%MOD)%MOD;
ntt(-1,tmp,t1);
for (int i=mid; i<r; ++i) add(f[i],t1[i-l]);
}
else{
memset(t1,0,sizeof(ele)*tmp);
for (int i=1; i<mid; ++i) t1[i]=(ll)f[i]*(i-1)%MOD;
memset(t2,0,sizeof(ele)*tmp);
for (int i=1; i<mid; ++i) t2[i]=f[i];
ntt(1,tmp,t1); ntt(1,tmp,t2);
for (int i=0; i<tmp; ++i) t1[i]=(ll)t1[i]*t2[i]%MOD;
ntt(-1,tmp,t1);
for (int i=mid; i<r; ++i) add(f[i],t1[i]);
}
add(f[mid],MOD-(ll)f[mid-1]*(mid+MOD-3)%MOD);
solve(f,mid,r);
}
ele calc(ele i){
if (a[i]==1) return 1;
ele j=i-1,ans=1,tmp=0;
while (j>i-a[i]){
ans=(ll)ans*calc(j)%MOD;
j-=a[j];
++tmp;
}
if (j<i-a[i]) return 0;
ans=(ll)ans*f[tmp]%MOD;
return ans;
}
int main(){
ele T;
scanf("%d%d",&T,&n);
f[0]=1;
ele tmp=1;
while (tmp<=n) tmp<<=1;
solve(f,0,tmp);
while (T--){
for (int i=0; i<n; ++i) scanf("%d",a+i);
if (a[n-1]!=n) puts("0");
else printf("%d\n",calc(n-1));
}
return 0;
}