#include <cstdio>
#include <iostream>

// Counts the non-empty subsets of the n input values in which every pair of
// chosen values XORs to at most k, modulo 998244353.
//
// The values are stored in a binary trie over bits K..0 (most significant
// first), and the count is produced by a simultaneous descent of one or two
// trie subtrees (see solve()).
//
// Input (stdin): n k, followed by n integers.
// NOTE(review): assumes 0 <= n < N and that the values and k are
// non-negative and fit in bits K..0 (i.e. < 2^31); bits above K are never
// examined. Confirm against the task limits.

const int N = 150000 + 10;      // capacity for n (sizes p[] and the trie)
const int K = 30;               // highest bit index stored in the trie
const int mod = 998244353;
const int MAX_NODES = N * (K + 2);  // root + at most K+1 new nodes per insert

// Reads one integer from stdin. Any non-digit characters before it are
// skipped, and a '-' seen among them makes the result negative.
// If EOF is hit before a digit, x is left at 0 instead of looping forever.
template <class T>
inline void read(T &x) {
    x = 0;
    bool negative = false;
    int c = std::getchar();  // int, not char, so EOF is distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') negative = true;
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    if (negative) x = -x;
}

// Reads several integers in order, e.g. read(a, b, c).
template <class T, class... Rest>
inline void read(T &x, Rest &...rest) {
    read(x);
    read(rest...);
}

int n, k;
int tot = 1;                // trie nodes in use; node 1 is the root, 0 = "absent"
int ch[MAX_NODES][2];       // ch[u][b]: child of u along bit value b (0 if absent)
int si[MAX_NODES];          // si[u]: number of inserted values in u's subtree
int p[N];                   // p[i] = 2^i mod `mod`, filled in main()

// Reduces x from [0, 2*mod) into [0, mod).
inline int adj(int x) { return x >= mod ? x - mod : x; }

// (a + b - 1) mod `mod` for a, b in [0, mod); never negative, even when
// a + b == 0 (which plain adj(a + b - 1) would turn into -1).
inline int add_minus_one(int a, int b) { return adj(adj(a + b) + mod - 1); }

// Inserts x into the trie, walking bits K..0 and creating nodes as needed.
// Every node on the path, root included, has its subtree count incremented.
inline void insert(int x) {
    int u = 1;
    ++si[u];
    for (int i = K; i >= 0; --i) {
        const int b = (x >> i) & 1;
        if (!ch[u][b]) ch[u][b] = ++tot;
        u = ch[u][b];
        ++si[u];
    }
}

// Counts subsets (INCLUDING the empty one) modulo `mod`, comparing bits i..0.
//
// Single mode (x == y): subsets of the values under node x in which every
//   pair XORs to at most k. Reaching node x this way means every bit of k
//   above i was 0, and the values under x agree above bit i.
// Pair mode (x != y): subsets of the values under x and y together in which
//   every CROSS pair (one from x, one from y) XORs to at most k. In every
//   cross XOR the bits above i equal those of k, so only bits i..0 remain to
//   be checked. Pairs taken from the same side are already known to be fine
//   and are not constrained.
//
// A node id of 0 denotes an empty subtree (si[0] == 0).
int solve(int x = 1, int y = 1, int i = K) {
    // An empty side imposes no cross constraint: any subset of the other side.
    if (!x || !y) return p[si[x] + si[y]];

    if (x == y) {
        // All values in a leaf are equal, so every pairwise XOR is 0.
        if (i < 0) return p[si[x]];
        const int lo = ch[x][0], hi = ch[x][1];
        if (k >> i & 1) {
            // Bit i of k is 1: two values from the same child XOR to < 2^i <= k,
            // so only pairs split across the children still need checking.
            return solve(lo, hi, i - 1);
        }
        // Bit i of k is 0: a valid subset cannot take values from both
        // children. Each count includes the empty set, hence the -1.
        return add_minus_one(solve(lo, lo, i - 1), solve(hi, hi, i - 1));
    }

    // Leaves: every cross XOR equals k exactly, so every subset is allowed.
    if (i < 0) return p[si[x] + si[y]];

    const int x0 = ch[x][0], x1 = ch[x][1];
    const int y0 = ch[y][0], y1 = ch[y][1];
    if (k >> i & 1) {
        // Cross pairs with equal bit i (x0-y0, x1-y1) are below k. Only x0-y1
        // and x1-y0 need checking, and the two groups choose independently.
        return static_cast<int>(1LL * solve(x0, y1, i - 1) *
                                solve(x1, y0, i - 1) % mod);
    }
    // Bit i of k is 0: cross pairs x0-y1 and x1-y0 are forbidden.
    // Valid subsets are:
    //  - those inside x0 + y0 or inside x1 + y1 (empty set counted twice, -1);
    //  - those using only side x with both x0 and x1 non-empty;
    //  - those using only side y with both y0 and y1 non-empty.
    // p[] entries are powers of two and never 0 mod a prime, so p - 1 >= 0.
    const int same_bit = add_minus_one(solve(x0, y0, i - 1), solve(x1, y1, i - 1));
    const int only_x = static_cast<int>(1LL * (p[si[x0]] - 1) * (p[si[x1]] - 1) % mod);
    const int only_y = static_cast<int>(1LL * (p[si[y0]] - 1) * (p[si[y1]] - 1) % mod);
    return adj(same_bit + adj(only_x + only_y));
}

int main() {
    read(n, k);
    for (int i = 1; i <= n; ++i) {
        int value;
        read(value);
        insert(value);
    }
    p[0] = 1;
    for (int i = 1; i <= n; ++i) p[i] = adj(p[i - 1] << 1);
    // solve() also counts the empty subset; remove it without going negative.
    std::printf("%d\n", adj(solve() + mod - 1));
    return 0;
}