loj2304题解

首先恰好为$k$不好处理,显然要转化为$\le k$的概率减去$\le k-1$的概率.

看上去就是个dp,让$f _i$表示共$i$列,安全泳池面积$\le k$的概率,考虑怎么递推. 假设最底下一行有一个格子是危险的,那么显然这一列两边互不干扰可以分开计算. 但是这里有一个问题,底下一行可能没有危险的格子. 于是我们把状态改为$f _{i,j}$表示共$i$列,下面$j$行均安全,第$j+1$行至少有一个危险的格子,且安全泳池面积$\le k$的概率. 记$g _{i,j}=\sum _{w\ge j}f _{i,w}$,即共$i$列,下面$j$行均安全,且安全泳池面积$\le k$的概率.

那么我们就可以枚举第$j+1$行最后一个危险格子的位置,得到递推式

$$f _{i,j}=\sum _{r\lt i,~r(j+1)\le k}g _{r,j+1}g _{i-r-1,j}q^j(1-q)$$

注意到$\forall ij>k,~f _{i,j}=0$,直接dp就可以得到70分.

考虑怎么拿剩下的30分. 显然当$i\gt k$时只需要考虑$f _{i,0}$,而且此时$f _{i,0}=g _{i,0}$. 而$r\le\frac{k}{j+1}\le k$,所以当$i\ge 2k+2$时,递推式可以写成

$$g _{i,0}=\sum _{r=0}^k(1-q)g _{r,1}g _{i-r-1,0}$$

这是一个常系数线性齐次递推式,前面暴力递推到$i=2k+1$,后面的用特征多项式可以做到$\mathcal O(k^2\log n)$.

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ele int
#define ll long long
using namespace std;
#define maxk 4010
#define MOD 998244353
ele n,K,x,y,q,f[maxk][maxk];
inline ele pw(ele a,ele x){
ele ans=1,tmp=a%MOD;
for (; x; x>>=1,tmp=(ll)tmp*tmp%MOD)
if (x&1) ans=(ll)ans*tmp%MOD;
return ans;
}
inline void mul(ele K,ele *a,ele *b){
static ele c[maxk];
memset(c,0,sizeof(c));
for (int i=0; i<=K; ++i)
for (int j=0; j<=K; ++j)
(c[i+j]+=(ll)a[i]*b[j]%MOD)%=MOD;
for (int i=K*2; i>K; --i)
for (int j=0; j<=K; ++j)
(c[i-j-1]+=(ll)c[i]*f[j][1]%MOD*(MOD+1-q)%MOD)%=MOD;
memcpy(a,c,sizeof(ele)*(K+1));
}
inline ele calc(ele K){
for (int i=0; i<=K+1; ++i) f[0][i]=1;
for (int i=1; i<=n && i<=K*2+1; ++i){
for (int j=0; i*j<=K; ++j){
f[i][j]=0;
ele tmp=(ll)pw(q,j)*(MOD+1-q)%MOD;
for (int r=0; r<i && r*(j+1)<=K; ++r){
ele t1=(ll)f[r][j+1]*tmp%MOD*f[i-r-1][j]%MOD;
(f[i][j]+=t1)%=MOD;
}
}
for (int j=K/i-1; ~j; --j) (f[i][j]+=f[i][j+1])%=MOD;
}
if (n<=K*2+1) return f[n][0];
static ele a[maxk],b[maxk];
memset(a,0,sizeof(a)); memset(b,0,sizeof(b));
a[0]=1; b[1]=1;
for (ele n1=n-K-1; n1; n1>>=1,mul(K,b,b))
if (n1&1) mul(K,a,b);
ele ans=0;
for (int i=0; i<=K; ++i)
(ans+=(ll)a[i]*f[K+1+i][0]%MOD)%=MOD;
return ans;
}
int main(){
scanf("%d%d%d%d",&n,&K,&x,&y);
q=(ll)x*pw(y,MOD-2)%MOD;
ele t1=calc(K);
ele t2=calc(K-1);
(t1+=MOD-t2)%=MOD;
printf("%d\n",t1);
return 0;
}