神仙 后缀数组 + 容斥 + DP 题目,非常巧妙。
先尝试转化为相对能做的限制条件。我们通过后缀数组容易得到每个字符之间的大小关系。具体地,考虑后缀数组 $p$ 以及排名数组 $\operatorname{rk}$,若 $\operatorname{rk}_{p_i+1}<\operatorname{rk}_{p_{i+1}+1}$,那么有 $s_{p_i}\le s_{p_{i+1}}$,否则有 $s_{p_i}< s_{p_{i+1}}$。最终我们会得到若干不等式链形如 $s_{p_1}\otimes s_{p_2}\otimes\cdots \otimes s_{p_n} $,其中 $\otimes $ 为 $<,\le $ 之一 。不难得出该条件是充要的,那么所有这样的形成的不等式链与后缀数组构成双射。
考虑合法的不等式链的条件。先将不等式链用 $<$ 划分为若干段,从小到大考虑每一种类,每次尝试将当前种类字符填在一段上,未填满就全部填上,填满就把当前种类剩下的字符扔掉,重复该过程直到填完整个字符串。我们尝试通过这个过程入手求解。
当我们尝试按照上面过程计算方案,设每一段大小分别为 $a_i$,得到方案为 $\displaystyle \frac{n!}{\prod_{i}a_i!}$。先对 $\displaystyle \frac{1}{\prod_{i}a_i!}$ 求解,最后再乘上 $n$。
但是这样显然会算重,因为中间可能会存在某些 $\le$ 的位置被当作 $<$ 算过一次。于是考虑容斥,考虑将若干个 $<$ 容斥为 $\le $ 即可。于是考虑一个经典的容斥 DP,设状态为 $f(i,j,k)$,表示当前考虑到从小到大第 $i$ 种字符,一共填了 $j$ 个字符,当前 $\le$ 段已经填了 $k$ 个字符,有下面四种转移:
- 直接将全部字符接在当前段上:$\displaystyle f(i,j,k)\gets f(i-1,j-c_i,k-c_i)$。
- 将一部分字符接在当前段,再添加一个 $<$,开启下一段:$\displaystyle f(i,j,0)\gets \frac{1}{k!}\sum_{l=1}^{c_i} f(i-1,j-l,k-l)$。
- 将一部分字符接在当前段,再添加一个 $<$,不过容斥为 $\le $,继续填当前段: $\displaystyle f(i,j,k)\gets -\sum_{l=1}^{c_i} f(i-1,j-l,k-l)$。
转移前缀和优化即可,时间复杂度为 $\mathcal O(n^2m)$。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
| #include<iostream> #include<cstdio> using namespace std; const int N=5e2+10,mod=1e9+7; template<class T> inline void read(T &x) { x=0;int f=0; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();} while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); if(f)x=~x+1; } template<class T,class ...T1> inline void read(T &x,T1 &...x1) { read(x),read(x1...); } int n,m; int c[N],fac[N],ifac[N]; int f[N][N],sum[N][N];
inline void adj(int &x){x+=x>>31&mod;} inline int qpow(int x,int k=mod-2) { int res=1; while(k) { if(k&1)res=1ll*res*x%mod; x=1ll*x*x%mod; k>>=1; } return res; } int main() { freopen("suffix.in","r",stdin); freopen("suffix.out","w",stdout); read(n,m); for(int i=1;i<=m;i++) { read(c[i]); if(!c[i])i--,m--; } fac[0]=1; for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod; ifac[n]=qpow(fac[n]); for(int i=n;i;i--)ifac[i-1]=1ll*ifac[i]*i%mod; f[0][0]=1; int ans=0; for(int i=1;i<=m;i++) { int c=::c[i]; for(int j=0;j<=n;j++) for(int k=0;k<=j;k++) { sum[j][k]=f[j][k]; if(j&&k)adj(sum[j][k]+=sum[j-1][k-1]-mod); f[j][k]=0; } for(int j=1;j<=n;j++) for(int k=1;k<=j;k++) { f[j][k]=mod-sum[j-1][k-1]; adj(f[j][0]+=1ll*ifac[k]*sum[j-1][k-1]%mod-mod); if(j>=c&&k>=c)adj(f[j][k]+=sum[j-c][k-c]%mod-mod); if(j>c&&k>c)adj(f[j][0]-=1ll*ifac[k]*sum[j-c-1][k-c-1]%mod); } adj(ans+=f[n][0]-mod); } printf("%d\n",1ll*ans*fac[n]%mod); return 0; }
|