Liveddd's Blog

愛にできることはまだあるかい

P6944 [ICPC2018 WF] Gem Island

不错的 DP 题目,中间用 二项式反演 或者 min-max 容斥处理。加强版做法也比较厉害。

真的很厉害这道题,感觉学到了很多组合数技巧。

普通版 $(n,d\le 500,r\le n )$

设 $c_i$ 为最后每个人拥有的宝石。第 $i$ 个人拥有的第 $j$ 个宝石由前面 $j-1$ 个宝石分裂得到,所以方案为 $(c_i-1)!$。再考虑每天钦定分裂 哪个人的宝石,方案数为 $\displaystyle \binom{d}{c_1-1,c_2-1,\cdots,c_n-1}=\dfrac{d!}{\prod_{i}(c_i-1)!}$二者相乘,得到 $d!$。说明每种情况出现的次数都是 $d!$。我们只需要对每种情况前 $r$ 大的 $c_i$ 求和即可。

对于这种形式,我们容易想到使用分拆数 DP 进行求解。我们设计状态 $f(i,j)$ 表示最大的 $c_x$ 有 $i$ 个,总和为 $j$ 的方案数。我们每次枚举将最大的 $i$ 个 $c_x$ 中选取 $k$ 个加一。转移为:

我们还需要计算前 $r$ 大的和,继续设计状态 $g(i,j)$ 意义与 $f(i,j)$ 类似,但是表示表示前 $r$ 大 $c_x$ 之和。转移为:

最终答案为 $\displaystyle \dfrac{\sum_i g(i,d)}{\sum_i f(i,d)}$。总的时间复杂度为 $\mathcal O(nd^2)$。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include<iostream>
#include<cstdio>
using namespace std;
const int N=1e3+10;
template<class T>
inline void read(T &x)
{
x=0;bool f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,d,r;
double C[N][N],f[N][N],g[N][N];
inline void init(int n)
{
for(int i=0;i<=n;i++)
{
C[i][0]=1;
for(int j=1;j<=i;j++)C[i][j]=C[i-1][j-1]+C[i-1][j];
}
}
int main()
{
read(n,d,r);
init(n+d);
f[n][0]=1;
for(int i=n;i;i--)
for(int j=0;j<d;j++)
for(int k=1;k<=i&&j+k<=d;k++)
f[k][j+k]+=f[i][j]*C[i][k],g[k][j+k]+=(g[i][j]+min(r,k)*f[i][j])*C[i][k];
double ans=0;
for(int i=1;i<=n;i++)ans+=g[i][d];
printf("%.9lf\n",ans/C[n+d-1][n-1]+r);
return 0;
}

能不能再给力一点?

加强版 $(n,d\le 1.5\times 10^7,r\le n ,p=998244353)$

有两个 DP 数组不方便考虑,我们尝试到值上进行考虑。将 $c_i$ 的贡献拆成 $\sum_{j=1}^d[c_i\ge j]$。我们设 $f(i,j)$ 表示恰好有 $i$ 个数大于等于 $j$ 的方案书。最终答案为 $\displaystyle \sum_{i=1}^n \sum_{j=1}^d \min(i,r) f(i,j)$。

注意到“恰好”的限制,容易想到容斥。直接钦定 有 $i$ 个数大于等于 $j$,设为 $g(i,j)$,利用隔板法得到 $\displaystyle g(i,j)=\binom{n}{i}\binom{d-ij+n-1}{n-1}$。于是根据 $\displaystyle g(i,j)=\sum_{k=i}^n\binom{k}{i}f(k,j)$,直接二项式反演得到:

对答案简单推导:

对于每个 $k$ 与处理出 $h(k)=\sum_{j=1}^d g(k,j)$ 即可,时间复杂度可以做到 $\mathcal O(n^2)$。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include<iostream>
#include<cstdio>
using namespace std;
const int N=5e3+10,mod=998244353;
template<class T>
inline void read(T &x)
{
x=0;bool f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,d,r;
int fac[N<<1],ifac[N<<1];
int h[N];
int ans;
inline int adj(int x){return x>=mod?x-mod:x;}
inline void upd(int &x,int y){x+=y;x=x>=mod?x-mod:x;}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
inline void init(int n)
{
fac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
ifac[n]=qpow(fac[n]);
for(int i=n;i;i--)ifac[i-1]=1ll*ifac[i]*i%mod;
}
inline int C(int n,int m)
{
if(m>n)return 0;
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
int main()
{
read(n,d,r);
init(n+d);
for(int i=1;i<=n;i++)
for(int j=1;i*j<=d;j++)
upd(h[i],1ll*C(n,i)*C(d-i*j+n-1,n-1)%mod);
for(int i=1;i<=n;i++)
{
int sum=0;
for(int j=i;j<=n;j++)
{
int res=1ll*C(j,i)*h[j]%mod;
upd(sum,(j-i&1)?mod-res:res);
}
upd(ans,1ll*min(i,r)*sum%mod);
}
printf("%d\n",adj(1ll*ans*qpow(C(n+d-1,n-1))%mod+r));
return 0;
}

但是仍然不够优秀。$h(k)$ 中的组合数不好拆,但是注意到 $\displaystyle z(k)=\binom{d-k+n-1}{n-1},h(k)=\sum z(jk)$ ,于是直接用 Dirichlet 后缀和就可以求出 $h(k)$。

接下来直接枚举 $k$:

我们需要快速计算:

看起来比较困难,但还是可以做。考虑上指标反转 $\displaystyle \binom{n}{m} = (-1)^m \binom{m-n-1}{m}$。再利用斜求和 $\displaystyle \sum_{i=1}^k\binom{i+j}{i}=\binom{k+1+j}{k}$,得到:

因为 $k>0$,所以 $\binom{0}{k}=0$。然后再用一次上指标反转,得:

这是一个与上面类似得求和,不停地使用上指标反转,得:

于是这一部分可以直接 $\mathcal O(1)$ 计算。总的时间复杂度为 $\mathcal O(n+d\log\log d)$。

Code

*

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<iostream>
#include<cstdio>
using namespace std;
const int N=1.5e7+10,mod=998244353;
template<class T>
inline void read(T &x)
{
x=0;bool f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,d,r;
int fac[N<<1],ifac[N<<1];
bool vis[N];
int tot,prime[N];
int h[N];

int ans;
inline int adj(int x){return x>=mod?x-mod:x;}
inline void upd(int &x,int y){x+=y;x=x>=mod?x-mod:x;}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
void sieve(int n)
{
for(int i=2;i<=n;i++)
{
if(!vis[i])prime[++tot]=i;
for(int j=1;j<=tot&&1ll*i*prime[j]<=n;j++)
{
vis[i*prime[j]]=1;
if(!(i%prime[j]))break;
}
}
}
inline void init(int n)
{
fac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
ifac[n]=qpow(fac[n]);
for(int i=n;i;i--)ifac[i-1]=1ll*ifac[i]*i%mod;
}
inline int C(int n,int m)
{
if(m>n)return 0;
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
int main()
{
read(n,d,r);
init(n+d);
sieve(max(n,d));
for(int i=1;i<=d;i++)h[i]=C(d-i+n-1,n-1);
for(int i=tot;i;i--)
for(int j=d/prime[i];j;j--)
upd(h[j],h[j*prime[i]]);
for(int i=1;i<=n;i++)
{
h[i]=1ll*h[i]*C(n,i)%mod;
int res=C(i-2,r-1);
if(i==1)res=mod-1;
else res=(r&1?mod-res:res);
res=i&1?mod-res:res;
upd(ans,1ll*h[i]*res%mod);
}
printf("%d\n",adj(1ll*ans*qpow(C(n+d-1,n-1))%mod+r));
return 0;
}