神仙的二项式反演,可以转为树上拓扑序,或者直接进行组合意义考虑。
注意到恰好的限制,这很容斥。于是直接考虑直接进行二项式反演。设 $f(i)$ 表示恰好有 $i$ 个极大的数的概率,$g(i)$ 表示钦定 $i$ 个极大的数的概率,得到:
先钦定出 $k$ 个位置,并且从大到小填入,这样的方案数显然为 $n^{\underline k}m^{\underline k}l^{\underline k}$。继续,有两种求 $g(i)$ 的方法。
转化为图论
先简单考虑二维的情况。那么 $k$ 个点的限制构成若干偏序关系,我们尝试根据这些关系建图。首先这 $k$ 个点之间根据偏序关系连边。剩下的 $nml-(n-k)(m-k)(l-k)$ 个点按照与 $k$ 个极大值之间的限制连边即可。于是整个限制构成一个一棵树,对其赋权值等价于给出一种合法的拓扑序。对于大小为 $n$ 的树,设 $\operatorname{size}(x)$ 表示子树 $x$ 的大小,那么合法的拓扑序为 $\displaystyle \frac{n!}{\prod_{x} \operatorname{size}(x)}$。除了 $k$ 个极大值之外的点子树大小为 $1$,从小到大第 $i$ 个极大值对应的子树大小为 $nml-(n-i)(m-i)(l-i)$,得到:
组合意义
其实思考过程与上面一种本质上是一样的,不再赘述。最终得到 $g(i)$ 的式子是一样的。
最终时间复杂度为 $\mathcal O(n)$。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
| #include<iostream> #include<cstdio> using namespace std; const int N=5e6+10,mod=998244353; template<class T> inline void read(T &x) { x=0;int f=0; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();} while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); if(f)x=~x+1; } template<class T,class ...T1> inline void read(T &x,T1 &...x1) { read(x),read(x1...); } int T,n,m,l,k; int fac[N],ifac[N]; int inv[N]; inline void adj(int &x){x+=x>>31&mod;} inline int qpow(int x,int k=mod-2) { int res=1; while(k) { if(k&1)res=1ll*res*x%mod; x=1ll*x*x%mod; k>>=1; } return res; } inline void init(int n) { for(int i=fac[0]=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod; ifac[n]=qpow(fac[n]); for(int i=n;i;i--)ifac[i-1]=1ll*ifac[i]*i%mod; } inline int C(int n,int m) { if(m>n)return 0; return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod; } inline int calc(int i){return 1ll*(n-i)*(m-i)%mod*(l-i)%mod;} inline void solve() { read(n,m,l,k); if(n>m)swap(n,m); if(n>l)swap(n,l); int all=calc(0),tot=1; for(int i=1;i<=n;i++)tot=1ll*tot*(all-calc(i)+mod)%mod; inv[n]=qpow(tot); for(int i=n;i;i--)inv[i-1]=1ll*inv[i]*(all-calc(i)+mod)%mod; int ans=0,mul=1; for(int i=0;i<k;i++)mul=1ll*mul*calc(i)%mod; for(int i=k;i<=n;i++) { int res=1ll*C(i,k)*mul%mod*inv[i]%mod; adj(ans+=(i-k&1)?-res:res-mod); mul=1ll*mul*calc(i)%mod; } printf("%d\n",ans); } int main() { init(N-10); read(T); while(T--)solve(); return 0; }
|