Liveddd's Blog

愛にできることはまだあるかい

二项式反演

经典有用的反演。

1.定义

有这么几个比较重要的式子:

这些正是我们进行二项式反演的基础式子,似乎能看到一些容斥的影子。但是公式的运用是浅显的,更加重要的是关于模型的构建。

另外地,利用上式暴力求 $g(1),g(2),\dots,g(n)$ 是 $\mathcal O(n^2)$ 的。但是可以做到更优。我们有:

设 $F(x),G(x)$ 为 $f(n),g(n)$ 的 EGF。并且有 $\displaystyle e^x=\sum_{n>0}\frac{x_n}{n!}$。我们辨认出卷积的形式,得到:

利用 $(5)$ 式我们可以用 NTT 加速卷积求 $g(1),g(2),\dots,g(n)$,时间复杂度为 $\mathcal O(n\log n)$。

2.例题

1.P4859 已经没有什么好害怕的了

相当求使 $a>b$ 恰好有 $k$ 对的方案数。注意到 ”恰好“,联想到转化为 “至少” 来求解(”至少“ 限制更宽松,在多数情况下是相对更好求的)。设计 $f(n)$ 表示至少 $n$ 个满足 $a>b$ 的方案数, $g(n)$ 表示恰好 $n$ 个满足 $a>b$ 的方案数。我们要求的就是 $g(k)$。对于这两个函数,我们有关系式:

使用式子 $(4)$ 进行二项式反演:

考虑设 $f_{i,j}$ 表示前 $i$ 个之中选出 $j$ 组满足 $a>b$,相当于钦定了 $j$ 组 $a>b$,剩下的可以随便选。先将 $A,B$ 排序,处理出 $b_j<a_i$ 的个数 $cnt_i$。并且仍然按照此顺序进行 DP,容易得到转移方程:

而我们要求的 $f(i)=(n-i)! f_{n,i}$。然后带入二项式反演的式子即可得到 $g(k)$。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=2e3+10,mod=1e9+9;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,k;
int a[N],b[N],cnt[N];
int fac[N],ifac[N];
int f[N];
inline int adj(int x){return x>=mod?x-mod:x;}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
inline int C(int n,int m)
{
if(m>n)return 0;
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
int main()
{
read(n,k);
if(n+k&1)return puts("0"),0;
k=n+k>>1;
for(int i=1;i<=n;i++)read(a[i]);
for(int i=1;i<=n;i++)read(b[i]);
sort(a+1,a+1+n),sort(b+1,b+1+n);
int p=0;
for(int i=1;i<=n;i++)
{
while(p+1<=n&&a[i]>b[p+1])p++;
cnt[i]=p;
}
fac[0]=ifac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=qpow(fac[i]);
f[0]=1;
for(int i=1;i<=n;i++)
for(int j=min(i,cnt[i]);j;j--)
f[j]=adj(f[j]+1ll*(cnt[i]-j+1)*f[j-1]%mod);
int ans=0;
for(int i=k;i<=n;i++)
{
int res=(i-k&1)?mod-1ll*f[i]*fac[n-i]%mod*C(i,k)%mod:
1ll*f[i]*fac[n-i]%mod*C(i,k)%mod;
ans=adj(ans+res);
}
printf("%d\n",ans);
return 0;
}

2.P6478 [NOI Online #2 提高组]游戏

再次注意到 ”恰好“ 的限制条件。考虑像上一道题一样转化为 ”至少“。考虑设计 $f(i)$ 表示至少 $i$ 个非平局回合, $g(i)$ 表示恰好 $i$ 个非平局回合。还是用上面的式子进行二项式反演。考虑用 DP 求出 $f(i)$。

考虑树上背包,设计 $f_{x,i}$ 表示以 $x$ 为根的子树中钦定了 $i$ 个非平局回合,容易得到转移方程:

$f(i)$ 即为 $(n-i)!\cdot f_{1,i}$,带入公式即可求出 $g(i)$。总时间复杂度 $\mathcal O(n^2)$。

需要注意的是背包转移时的细节:转移形式是加法卷积,对于初始的 $f’_{x,i}$ 应当为 $f_{x,i}\times f_{y,0}$,而非 $f_{x,i}\times (f_{y,0}+1)$,因为我们必须选完 $y$ 子树里的点(只不过没有非平局回合)。更为保险的做法是开一个副本作为计算当前答案,算完再复制一遍。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include<iostream>
#include<cstdio>
using namespace std;
const int N=5e3+10,mod=998244353;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n;
int tot,head[N],ver[N<<1],ne[N<<1];
int col[N],si[N][2];
int f[N][N],fac[N],ifac[N];
inline int adj(int x){return (x>=mod)?x-mod:x;}
inline void add(int u,int v)
{
ver[++tot]=v;
ne[tot]=head[u];
head[u]=tot;
}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
inline int C(int n,int m)
{
if(m>n)return 0;
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
void dfs(int x,int fa)
{
f[x][0]=1;
for(int i=head[x];i;i=ne[i])
{
int y=ver[i];
if(y==fa)continue;
dfs(y,x);
int six=min(si[x][0],si[x][1]),siy=min(si[y][0],si[y][1]);
for(int j=six+siy;~j;j--)
for(int k=max(j-six,0);k<=min(siy,j);k++)
{
if(!k)f[x][j]=1ll*f[x][j]*f[y][0]%mod;
else f[x][j]=adj(f[x][j]+1ll*f[x][j-k]*f[y][k]%mod);
}

si[x][0]+=si[y][0],si[x][1]+=si[y][1];
}
si[x][col[x]]++;
int six=min(si[x][0],si[x][1]);
for(int i=six;i;i--)
f[x][i]=adj(f[x][i]+1ll*(si[x][col[x]^1]-(i-1))*f[x][i-1]%mod);
}
int main()
{
read(n);
for(int i=1;i<=n;i++)
scanf("%1d",&col[i]);
for(int i=1;i<n;i++)
{
int u,v;
read(u,v);
add(u,v),add(v,u);
}
n>>=1;
dfs(1,0);
fac[0]=ifac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=qpow(fac[i]);
for(int i=0;i<=n;i++)
{
int ans=0;
for(int j=i;j<=n;j++)
{
int res=(j-i&1)?mod-1ll*f[1][j]*fac[n-j]%mod*C(j,i)%mod:
1ll*f[1][j]*fac[n-j]%mod*C(j,i)%mod;
ans=adj(ans+res);
}
printf("%d\n",ans);
}
return 0;
}

通过上面两个例题,容易发现很多情况下我们能够通过二项式反演将 ”恰好“ 的限制转化为 ”至少“ 的限制,再进一步求解。这也是解决这一类题的基本思路。换个角度来看,是一个当函数与集合具体有哪些元素无关,而与其中元素个数有关的情况下,进行的一个容斥。这一点能很好的将二者而这联系起来。

3.CF1707D Partial Virtual Trees

先考虑直接 DP,但发现并不好做。为什么?因为每次操作必须删除一个点,但是合并子树时只要有一棵子树在这一次操作内进行删点就行了。

于是考虑如何去掉这个限制。设计 $f(i)$ 表示至多进行 $i$ 次操作将树删空的方案数,或者换一种说法,进行的 $i$ 次操作其中有操作可以不删点,但是最终仍需要删空。设计 $g(i)$ 表示恰好进行 $i$ 次操作将树删空的方案数。通过枚举进行删点的 $i$ 次操作,我们可以得到:

这恰是二项式反演的形式!我们使用式子 $(2)$ 得到:

还是考虑 DP 求出来 $f(i)$,设计 $f_{x,i}$ 表示将子树 $x$ 内操作 $i$ 次(可以不删点)之后为空的方案数。考虑两种情况:1.将所有子节点删完后,删除 $x$;2.将所有子节点删到只有 $1$ 棵子树 $y$ 中有点,再删去 $x$,再将剩下的节点删完。于是有转移方程:

前缀和一下辅助转移,这一部分是 $\mathcal O(n^2)$ 的。所以总的时间复杂度是 $\mathcal O(n^2)$。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
using namespace std;
const int N=2e3+10;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,mod;
int fac[N],ifac[N];
int f[N][N];
int sum[N][N],pre[N][N],suf[N][N];

vector<int>e[N];
inline int adj(int x){return x>=mod?x-mod:x;}
inline void add(int u,int v)
{
e[u].push_back(v);
}
inline qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
inline int C(int n,int m)
{
if(m>n)return 0;
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
void dfs(int x,int fa)
{
int si=e[x].size();
for(int i=0;i<si;i++)
{
int y=e[x][i];
if(y==fa)continue;
dfs(y,x);
}
for(int i=0;i<si;i++)
{
int y=e[x][i];
if(y==fa)
{
for(int j=1;j<=n;j++)pre[i][j]=suf[i][j]=1;
continue;
}
for(int j=1;j<=n;j++)pre[i][j]=suf[i][j]=sum[y][j];
}
for(int i=1;i<si;i++)
for(int j=1;j<=n;j++)
pre[i][j]=1ll*pre[i-1][j]*pre[i][j]%mod;
for(int i=si-2;i>=0;i--)
for(int j=1;j<=n;j++)
suf[i][j]=1ll*suf[i+1][j]*suf[i][j]%mod;
for(int i=1;i<=n;i++)f[x][i]=adj(f[x][i]+pre[si-1][i]);
if(x!=1)
{
for(int i=0;i<si;i++)
{
int y=e[x][i];
if(y==fa)continue;
int res=0;
for(int j=1;j<=n;j++)
{
f[x][j]=adj(f[x][j]+1ll*f[y][j]*res%mod);
res=adj(res+1ll*((i)?pre[i-1][j]:1)*((i+1<si)?suf[i+1][j]:1)%mod);
}
}
for(int i=1;i<=n;i++)sum[x][i]=adj(f[x][i]+sum[x][i-1]);
}
}
int main()
{
read(n,mod);
for(int i=1;i<n;i++)
{
int u,v;
read(u,v);
add(u,v),add(v,u);
}
fac[0]=ifac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=qpow(fac[i]);
dfs(1,0);
for(int i=1;i<n;i++)
{
int ans=0;
for(int j=1;j<=i;j++)
{
int res=(i-j&1)?mod-1ll*f[1][j]*C(i,j)%mod:
1ll*f[1][j]*C(i,j)%mod;
ans=adj(ans+res);
}
printf("%d ",ans);
}

return 0;
}