容斥经典好题。
感觉思路反而很自然。将满二叉树画出来,是这样的:
我们容易发现只需要保证,图上的 所有子树中编号最小值 都不为 $a_i$ 即可,因为只有子树中编号最小的人才会与 $1$ 交手。
考虑容斥,钦定子树集合 $S$ 的最小值都是 $a_i$,其余子树任意选。然后利用 $\displaystyle g(S)=\sum_{T\subseteq S} (-1)^{|T|+1} f(T)$ 进行容斥得到答案即可。
我们从大到小依次考虑 $a_i$,因为我们只关心子树中的最小值是否为 $a_i$,从大到小恰好可以满足当前枚举的 $a_i$ 是最小的。设计状态 $f_{i,S}$ 表示 选取到第 $i$,钦定集合为 $S$,转移很容易得到。DP时可以直接将容斥系数数算进去,再乘上组合数即可。时间复杂度 $\mathcal O(nm2^n)$。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
   | #include<iostream> #include<cstdio> #include<algorithm> using namespace std; const int N=17,M=1<<N,mod=1e9+7; template<class T> inline void read(T &x) {     x=0;int f=0;     char ch=getchar();     while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}     while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();     if(f)x=~x+1; } template<class T,class ...T1> inline void read(T &x,T1 &...x1) {     read(x),read(x1...); } int n,m; int a[N]; int f[N][M]; int fac[M],ifac[M]; inline int adj(int x){return (x>=mod)?x-mod:x;} inline int qpow(int x,int j=mod-2) {     int res=1;     while(j)     {         if(j&1)res=1ll*res*x%mod;         x=1ll*x*x%mod;         j>>=1;     }     return res; } inline int C(int n,int m) {     if(m>n)return 0;     return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod; } inline void init(int n) {     fac[0]=1;     for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;     ifac[n]=qpow(fac[n]);     for(int i=n;i;i--)ifac[i-1]=1ll*ifac[i]*i%mod; } int main() {     read(n,m);     init(1<<n);     for(int i=0;i<m;i++)read(a[i]);     f[m][0]=1;     for(int i=m-1;~i;i--)         for(int s=0;s<(1<<n);s++)         {             f[i][s]=adj(f[i][s]+f[i+1][s]);             for(int j=0;j<n;j++)                 if(!(s>>j&1))                     f[i][s|(1<<j)]=adj(f[i][s|(1<<j)]+1ll*C((1<<n)-a[i]-s,(1<<j)-1)*(mod-f[i+1][s])%mod);         }     int ans=0;     for(int s=0;s<(1<<n);s++)     {         int res=1;         for(int i=0;i<n;i++)if(s>>i&1)res=1ll*res*fac[1<<i]%mod;         ans=adj(ans+1ll*res*fac[(1<<n)-1-s]%mod*f[0][s]%mod);     }     printf("%d\n",1ll*ans*(1<<n)%mod);     return 0; }
   |