不错的 DP 题目,中间用 二项式反演 或者 min-max 容斥处理。加强版做法也比较厉害。
真的很厉害这道题,感觉学到了很多组合数技巧。
普通版 $(n,d\le 500,r\le n )$
设 $c_i$ 为最后每个人拥有的宝石。第 $i$ 个人拥有的第 $j$ 个宝石由前面 $j-1$ 个宝石分裂得到,所以方案为 $(c_i-1)!$。再考虑每天钦定分裂 哪个人的宝石,方案数为 $\displaystyle \binom{d}{c_1-1,c_2-1,\cdots,c_n-1}=\dfrac{d!}{\prod_{i}(c_i-1)!}$二者相乘,得到 $d!$。说明每种情况出现的次数都是 $d!$。我们只需要对每种情况前 $r$ 大的 $c_i$ 求和即可。
对于这种形式,我们容易想到使用分拆数 DP 进行求解。我们设计状态 $f(i,j)$ 表示最大的 $c_x$ 有 $i$ 个,总和为 $j$ 的方案数。我们每次枚举将最大的 $i$ 个 $c_x$ 中选取 $k$ 个加一。转移为:
我们还需要计算前 $r$ 大的和,继续设计状态 $g(i,j)$ 意义与 $f(i,j)$ 类似,但是表示表示前 $r$ 大 $c_x$ 之和。转移为:
最终答案为 $\displaystyle \dfrac{\sum_i g(i,d)}{\sum_i f(i,d)}$。总的时间复杂度为 $\mathcal O(nd^2)$。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
| #include<iostream> #include<cstdio> using namespace std; const int N=1e3+10; template<class T> inline void read(T &x) { x=0;bool f=0; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();} while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); if(f)x=~x+1; } template<class T,class ...T1> inline void read(T &x,T1 &...x1) { read(x),read(x1...); } int n,d,r; double C[N][N],f[N][N],g[N][N]; inline void init(int n) { for(int i=0;i<=n;i++) { C[i][0]=1; for(int j=1;j<=i;j++)C[i][j]=C[i-1][j-1]+C[i-1][j]; } } int main() { read(n,d,r); init(n+d); f[n][0]=1; for(int i=n;i;i--) for(int j=0;j<d;j++) for(int k=1;k<=i&&j+k<=d;k++) f[k][j+k]+=f[i][j]*C[i][k],g[k][j+k]+=(g[i][j]+min(r,k)*f[i][j])*C[i][k]; double ans=0; for(int i=1;i<=n;i++)ans+=g[i][d]; printf("%.9lf\n",ans/C[n+d-1][n-1]+r); return 0; }
|
能不能再给力一点?
加强版 $(n,d\le 1.5\times 10^7,r\le n ,p=998244353)$
有两个 DP 数组不方便考虑,我们尝试到值上进行考虑。将 $c_i$ 的贡献拆成 $\sum_{j=1}^d[c_i\ge j]$。我们设 $f(i,j)$ 表示恰好有 $i$ 个数大于等于 $j$ 的方案书。最终答案为 $\displaystyle \sum_{i=1}^n \sum_{j=1}^d \min(i,r) f(i,j)$。
注意到“恰好”的限制,容易想到容斥。直接钦定 有 $i$ 个数大于等于 $j$,设为 $g(i,j)$,利用隔板法得到 $\displaystyle g(i,j)=\binom{n}{i}\binom{d-ij+n-1}{n-1}$。于是根据 $\displaystyle g(i,j)=\sum_{k=i}^n\binom{k}{i}f(k,j)$,直接二项式反演得到:
对答案简单推导:
对于每个 $k$ 与处理出 $h(k)=\sum_{j=1}^d g(k,j)$ 即可,时间复杂度可以做到 $\mathcal O(n^2)$。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
| #include<iostream> #include<cstdio> using namespace std; const int N=5e3+10,mod=998244353; template<class T> inline void read(T &x) { x=0;bool f=0; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();} while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); if(f)x=~x+1; } template<class T,class ...T1> inline void read(T &x,T1 &...x1) { read(x),read(x1...); } int n,d,r; int fac[N<<1],ifac[N<<1]; int h[N]; int ans; inline int adj(int x){return x>=mod?x-mod:x;} inline void upd(int &x,int y){x+=y;x=x>=mod?x-mod:x;} inline int qpow(int x,int k=mod-2) { int res=1; while(k) { if(k&1)res=1ll*res*x%mod; x=1ll*x*x%mod; k>>=1; } return res; } inline void init(int n) { fac[0]=1; for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod; ifac[n]=qpow(fac[n]); for(int i=n;i;i--)ifac[i-1]=1ll*ifac[i]*i%mod; } inline int C(int n,int m) { if(m>n)return 0; return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod; } int main() { read(n,d,r); init(n+d); for(int i=1;i<=n;i++) for(int j=1;i*j<=d;j++) upd(h[i],1ll*C(n,i)*C(d-i*j+n-1,n-1)%mod); for(int i=1;i<=n;i++) { int sum=0; for(int j=i;j<=n;j++) { int res=1ll*C(j,i)*h[j]%mod; upd(sum,(j-i&1)?mod-res:res); } upd(ans,1ll*min(i,r)*sum%mod); } printf("%d\n",adj(1ll*ans*qpow(C(n+d-1,n-1))%mod+r)); return 0; }
|
但是仍然不够优秀。$h(k)$ 中的组合数不好拆,但是注意到 $\displaystyle z(k)=\binom{d-k+n-1}{n-1},h(k)=\sum z(jk)$ ,于是直接用 Dirichlet 后缀和就可以求出 $h(k)$。
接下来直接枚举 $k$:
我们需要快速计算:
看起来比较困难,但还是可以做。考虑上指标反转 $\displaystyle \binom{n}{m} = (-1)^m \binom{m-n-1}{m}$。再利用斜求和 $\displaystyle \sum_{i=1}^k\binom{i+j}{i}=\binom{k+1+j}{k}$,得到:
因为 $k>0$,所以 $\binom{0}{k}=0$。然后再用一次上指标反转,得:
这是一个与上面类似得求和,不停地使用上指标反转,得:
于是这一部分可以直接 $\mathcal O(1)$ 计算。总的时间复杂度为 $\mathcal O(n+d\log\log d)$。
Code
*
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
| #include<iostream> #include<cstdio> using namespace std; const int N=1.5e7+10,mod=998244353; template<class T> inline void read(T &x) { x=0;bool f=0; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();} while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); if(f)x=~x+1; } template<class T,class ...T1> inline void read(T &x,T1 &...x1) { read(x),read(x1...); } int n,d,r; int fac[N<<1],ifac[N<<1]; bool vis[N]; int tot,prime[N]; int h[N];
int ans; inline int adj(int x){return x>=mod?x-mod:x;} inline void upd(int &x,int y){x+=y;x=x>=mod?x-mod:x;} inline int qpow(int x,int k=mod-2) { int res=1; while(k) { if(k&1)res=1ll*res*x%mod; x=1ll*x*x%mod; k>>=1; } return res; } void sieve(int n) { for(int i=2;i<=n;i++) { if(!vis[i])prime[++tot]=i; for(int j=1;j<=tot&&1ll*i*prime[j]<=n;j++) { vis[i*prime[j]]=1; if(!(i%prime[j]))break; } } } inline void init(int n) { fac[0]=1; for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod; ifac[n]=qpow(fac[n]); for(int i=n;i;i--)ifac[i-1]=1ll*ifac[i]*i%mod; } inline int C(int n,int m) { if(m>n)return 0; return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod; } int main() { read(n,d,r); init(n+d); sieve(max(n,d)); for(int i=1;i<=d;i++)h[i]=C(d-i+n-1,n-1); for(int i=tot;i;i--) for(int j=d/prime[i];j;j--) upd(h[j],h[j*prime[i]]); for(int i=1;i<=n;i++) { h[i]=1ll*h[i]*C(n,i)%mod; int res=C(i-2,r-1); if(i==1)res=mod-1; else res=(r&1?mod-res:res); res=i&1?mod-res:res; upd(ans,1ll*h[i]*res%mod); } printf("%d\n",adj(1ll*ans*qpow(C(n+d-1,n-1))%mod+r)); return 0; }
|