Liveddd's Blog

愛にできることはまだあるかい

LOJ#2988. 「CTSC2016」萨菲克斯 · 阿瑞

神仙 后缀数组 + 容斥 + DP 题目,非常巧妙。

先尝试转化为相对能做的限制条件。我们通过后缀数组容易得到每个字符之间的大小关系。具体地,考虑后缀数组 $p$ 以及排名数组 $\operatorname{rk}$,若 $\operatorname{rk}_{p_i+1}<\operatorname{rk}_{p_{i+1}+1}$,那么有 $s_{p_i}\le s_{p_{i+1}}$,否则有 $s_{p_i}< s_{p_{i+1}}$。最终我们会得到若干不等式链形如 $s_{p_1}\otimes s_{p_2}\otimes\cdots \otimes s_{p_n} $,其中 $\otimes $ 为 $<,\le $ 之一 。不难得出该条件是充要的,那么所有这样的形成的不等式链与后缀数组构成双射。

考虑合法的不等式链的条件。先将不等式链用 $<$ 划分为若干段,从小到大考虑每一种类,每次尝试将当前种类字符填在一段上,未填满就全部填上,填满就把当前种类剩下的字符扔掉,重复该过程直到填完整个字符串。我们尝试通过这个过程入手求解。

当我们尝试按照上面过程计算方案,设每一段大小分别为 $a_i$,得到方案为 $\displaystyle \frac{n!}{\prod_{i}a_i!}$。先对 $\displaystyle \frac{1}{\prod_{i}a_i!}$ 求解,最后再乘上 $n$。

但是这样显然会算重,因为中间可能会存在某些 $\le$ 的位置被当作 $<$ 算过一次。于是考虑容斥,考虑将若干个 $<$ 容斥为 $\le $ 即可。于是考虑一个经典的容斥 DP,设状态为 $f(i,j,k)$,表示当前考虑到从小到大第 $i$ 种字符,一共填了 $j$ 个字符,当前 $\le$ 段已经填了 $k$ 个字符,有下面四种转移:

  1. 直接将全部字符接在当前段上:$\displaystyle f(i,j,k)\gets f(i-1,j-c_i,k-c_i)$。
  2. 将一部分字符接在当前段,再添加一个 $<$,开启下一段:$\displaystyle f(i,j,0)\gets \frac{1}{k!}\sum_{l=1}^{c_i} f(i-1,j-l,k-l)$。
  3. 将一部分字符接在当前段,再添加一个 $<$,不过容斥为 $\le $,继续填当前段: $\displaystyle f(i,j,k)\gets -\sum_{l=1}^{c_i} f(i-1,j-l,k-l)$。

转移前缀和优化即可,时间复杂度为 $\mathcal O(n^2m)$。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include<iostream>
#include<cstdio>
using namespace std;
const int N=5e2+10,mod=1e9+7;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,m;
int c[N],fac[N],ifac[N];
int f[N][N],sum[N][N];

inline void adj(int &x){x+=x>>31&mod;}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
int main()
{
freopen("suffix.in","r",stdin);
freopen("suffix.out","w",stdout);
read(n,m);
for(int i=1;i<=m;i++)
{
read(c[i]);
if(!c[i])i--,m--;
}
fac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
ifac[n]=qpow(fac[n]);
for(int i=n;i;i--)ifac[i-1]=1ll*ifac[i]*i%mod;
f[0][0]=1;
int ans=0;
for(int i=1;i<=m;i++)
{
int c=::c[i];
for(int j=0;j<=n;j++)
for(int k=0;k<=j;k++)
{
sum[j][k]=f[j][k];
if(j&&k)adj(sum[j][k]+=sum[j-1][k-1]-mod);
f[j][k]=0;
}
for(int j=1;j<=n;j++)
for(int k=1;k<=j;k++)
{
f[j][k]=mod-sum[j-1][k-1];
adj(f[j][0]+=1ll*ifac[k]*sum[j-1][k-1]%mod-mod);
if(j>=c&&k>=c)adj(f[j][k]+=sum[j-c][k-c]%mod-mod);
if(j>c&&k>c)adj(f[j][0]-=1ll*ifac[k]*sum[j-c-1][k-c-1]%mod);
}
adj(ans+=f[n][0]-mod);
}
printf("%d\n",1ll*ans*fac[n]%mod);
return 0;
}