Liveddd's Blog

愛にできることはまだあるかい

二项式反演

经典有用的反演。

1.定义

有这么几个比较重要的式子:

(1)f(n)=i=0n(1)i(ni)g(i)g(n)=i=0n(1)i(ni)f(i)(2)f(n)=i=0n(ni)g(i)g(n)=i=0n(1)ni(ni)f(i)(3)f(n)=i=nm(1)i(in)g(i)g(n)=i=nm(1)i(in)f(i)(4)f(n)=i=nm(in)g(i)g(n)=i=nm(1)in(in)f(i)

这些正是我们进行二项式反演的基础式子,似乎能看到一些容斥的影子。但是公式的运用是浅显的,更加重要的是关于模型的构建。

另外地,利用上式暴力求 g(1),g(2),,g(n)O(n2) 的。但是可以做到更优。我们有:

f(n)=i=0n(ni)g(i)f(n)n!=i=0n1(ni)!g(n)i!

F(x),G(x)f(n),g(n) 的 EGF。并且有 ex=n>0xnn!。我们辨认出卷积的形式,得到:

(5)F(x)=G(x)exG(x)=F(x)ex

利用 (5) 式我们可以用 NTT 加速卷积求 g(1),g(2),,g(n),时间复杂度为 O(nlogn)

2.例题

1.P4859 已经没有什么好害怕的了

相当求使 a>b 恰好有 k 对的方案数。注意到 ”恰好“,联想到转化为 “至少” 来求解(”至少“ 限制更宽松,在多数情况下是相对更好求的)。设计 f(n) 表示至少 n 个满足 a>b 的方案数, g(n) 表示恰好 n 个满足 a>b 的方案数。我们要求的就是 g(k)。对于这两个函数,我们有关系式:

f(k)=i=kn(ik)g(i)

使用式子 (4) 进行二项式反演:

g(k)=i=kn(1)ik(ik)f(i)

考虑设 fi,j 表示前 i 个之中选出 j 组满足 a>b,相当于钦定了 ja>b,剩下的可以随便选。先将 A,B 排序,处理出 bj<ai 的个数 cnti。并且仍然按照此顺序进行 DP,容易得到转移方程:

fi,j=fi1,j+(cntij+1)×fi1,j1

而我们要求的 f(i)=(ni)!fn,i。然后带入二项式反演的式子即可得到 g(k)

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=2e3+10,mod=1e9+9;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,k;
int a[N],b[N],cnt[N];
int fac[N],ifac[N];
int f[N];
inline int adj(int x){return x>=mod?x-mod:x;}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
inline int C(int n,int m)
{
if(m>n)return 0;
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
int main()
{
read(n,k);
if(n+k&1)return puts("0"),0;
k=n+k>>1;
for(int i=1;i<=n;i++)read(a[i]);
for(int i=1;i<=n;i++)read(b[i]);
sort(a+1,a+1+n),sort(b+1,b+1+n);
int p=0;
for(int i=1;i<=n;i++)
{
while(p+1<=n&&a[i]>b[p+1])p++;
cnt[i]=p;
}
fac[0]=ifac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=qpow(fac[i]);
f[0]=1;
for(int i=1;i<=n;i++)
for(int j=min(i,cnt[i]);j;j--)
f[j]=adj(f[j]+1ll*(cnt[i]-j+1)*f[j-1]%mod);
int ans=0;
for(int i=k;i<=n;i++)
{
int res=(i-k&1)?mod-1ll*f[i]*fac[n-i]%mod*C(i,k)%mod:
1ll*f[i]*fac[n-i]%mod*C(i,k)%mod;
ans=adj(ans+res);
}
printf("%d\n",ans);
return 0;
}

2.P6478 [NOI Online #2 提高组]游戏

再次注意到 ”恰好“ 的限制条件。考虑像上一道题一样转化为 ”至少“。考虑设计 f(i) 表示至少 i 个非平局回合, g(i) 表示恰好 i 个非平局回合。还是用上面的式子进行二项式反演。考虑用 DP 求出 f(i)

考虑树上背包,设计 fx,i 表示以 x 为根的子树中钦定了 i 个非平局回合,容易得到转移方程:

fx,ij=0ifx,ij×fy,jfx,i(size0/1[x](i1))×fx,i1

f(i) 即为 (ni)!f1,i,带入公式即可求出 g(i)。总时间复杂度 O(n2)

需要注意的是背包转移时的细节:转移形式是加法卷积,对于初始的 fx,i 应当为 fx,i×fy,0,而非 fx,i×(fy,0+1),因为我们必须选完 y 子树里的点(只不过没有非平局回合)。更为保险的做法是开一个副本作为计算当前答案,算完再复制一遍。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include<iostream>
#include<cstdio>
using namespace std;
const int N=5e3+10,mod=998244353;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n;
int tot,head[N],ver[N<<1],ne[N<<1];
int col[N],si[N][2];
int f[N][N],fac[N],ifac[N];
inline int adj(int x){return (x>=mod)?x-mod:x;}
inline void add(int u,int v)
{
ver[++tot]=v;
ne[tot]=head[u];
head[u]=tot;
}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
inline int C(int n,int m)
{
if(m>n)return 0;
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
void dfs(int x,int fa)
{
f[x][0]=1;
for(int i=head[x];i;i=ne[i])
{
int y=ver[i];
if(y==fa)continue;
dfs(y,x);
int six=min(si[x][0],si[x][1]),siy=min(si[y][0],si[y][1]);
for(int j=six+siy;~j;j--)
for(int k=max(j-six,0);k<=min(siy,j);k++)
{
if(!k)f[x][j]=1ll*f[x][j]*f[y][0]%mod;
else f[x][j]=adj(f[x][j]+1ll*f[x][j-k]*f[y][k]%mod);
}

si[x][0]+=si[y][0],si[x][1]+=si[y][1];
}
si[x][col[x]]++;
int six=min(si[x][0],si[x][1]);
for(int i=six;i;i--)
f[x][i]=adj(f[x][i]+1ll*(si[x][col[x]^1]-(i-1))*f[x][i-1]%mod);
}
int main()
{
read(n);
for(int i=1;i<=n;i++)
scanf("%1d",&col[i]);
for(int i=1;i<n;i++)
{
int u,v;
read(u,v);
add(u,v),add(v,u);
}
n>>=1;
dfs(1,0);
fac[0]=ifac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=qpow(fac[i]);
for(int i=0;i<=n;i++)
{
int ans=0;
for(int j=i;j<=n;j++)
{
int res=(j-i&1)?mod-1ll*f[1][j]*fac[n-j]%mod*C(j,i)%mod:
1ll*f[1][j]*fac[n-j]%mod*C(j,i)%mod;
ans=adj(ans+res);
}
printf("%d\n",ans);
}
return 0;
}

通过上面两个例题,容易发现很多情况下我们能够通过二项式反演将 ”恰好“ 的限制转化为 ”至少“ 的限制,再进一步求解。这也是解决这一类题的基本思路。换个角度来看,是一个当函数与集合具体有哪些元素无关,而与其中元素个数有关的情况下,进行的一个容斥。这一点能很好的将二者而这联系起来。

3.CF1707D Partial Virtual Trees

先考虑直接 DP,但发现并不好做。为什么?因为每次操作必须删除一个点,但是合并子树时只要有一棵子树在这一次操作内进行删点就行了。

于是考虑如何去掉这个限制。设计 f(i) 表示至多进行 i 次操作将树删空的方案数,或者换一种说法,进行的 i 次操作其中有操作可以不删点,但是最终仍需要删空。设计 g(i) 表示恰好进行 i 次操作将树删空的方案数。通过枚举进行删点的 i 次操作,我们可以得到:

f(n)=i=0n(ni)g(i)

这恰是二项式反演的形式!我们使用式子 (2) 得到:

g(n)=i=0n(1)nif(i)

还是考虑 DP 求出来 f(i),设计 fx,i 表示将子树 x 内操作 i 次(可以不删点)之后为空的方案数。考虑两种情况:1.将所有子节点删完后,删除 x;2.将所有子节点删到只有 1 棵子树 y 中有点,再删去 x,再将剩下的节点删完。于是有转移方程:

fx,iyj=1ify,jfx,iyfy,i×k=1i1 tyj=1kft,j

前缀和一下辅助转移,这一部分是 O(n2) 的。所以总的时间复杂度是 O(n2)

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
using namespace std;
const int N=2e3+10;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,mod;
int fac[N],ifac[N];
int f[N][N];
int sum[N][N],pre[N][N],suf[N][N];

vector<int>e[N];
inline int adj(int x){return x>=mod?x-mod:x;}
inline void add(int u,int v)
{
e[u].push_back(v);
}
inline qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
inline int C(int n,int m)
{
if(m>n)return 0;
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
void dfs(int x,int fa)
{
int si=e[x].size();
for(int i=0;i<si;i++)
{
int y=e[x][i];
if(y==fa)continue;
dfs(y,x);
}
for(int i=0;i<si;i++)
{
int y=e[x][i];
if(y==fa)
{
for(int j=1;j<=n;j++)pre[i][j]=suf[i][j]=1;
continue;
}
for(int j=1;j<=n;j++)pre[i][j]=suf[i][j]=sum[y][j];
}
for(int i=1;i<si;i++)
for(int j=1;j<=n;j++)
pre[i][j]=1ll*pre[i-1][j]*pre[i][j]%mod;
for(int i=si-2;i>=0;i--)
for(int j=1;j<=n;j++)
suf[i][j]=1ll*suf[i+1][j]*suf[i][j]%mod;
for(int i=1;i<=n;i++)f[x][i]=adj(f[x][i]+pre[si-1][i]);
if(x!=1)
{
for(int i=0;i<si;i++)
{
int y=e[x][i];
if(y==fa)continue;
int res=0;
for(int j=1;j<=n;j++)
{
f[x][j]=adj(f[x][j]+1ll*f[y][j]*res%mod);
res=adj(res+1ll*((i)?pre[i-1][j]:1)*((i+1<si)?suf[i+1][j]:1)%mod);
}
}
for(int i=1;i<=n;i++)sum[x][i]=adj(f[x][i]+sum[x][i-1]);
}
}
int main()
{
read(n,mod);
for(int i=1;i<n;i++)
{
int u,v;
read(u,v);
add(u,v),add(v,u);
}
fac[0]=ifac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=qpow(fac[i]);
dfs(1,0);
for(int i=1;i<n;i++)
{
int ans=0;
for(int j=1;j<=i;j++)
{
int res=(i-j&1)?mod-1ll*f[1][j]*C(i,j)%mod:
1ll*f[1][j]*C(i,j)%mod;
ans=adj(ans+res);
}
printf("%d ",ans);
}

return 0;
}

Gitalk 加载中 ...