1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <vector>

// Capacity of every work array, and the prime modulus used for all counts.
const int N = 5e5 + 10, mod = 998244353;

/// Reads one decimal integer from stdin into `x`.
///
/// Leading non-digit characters are skipped; if a '-' is seen among them the
/// value is negated. `x` is always reset to 0 first.
///
/// @return true if at least one digit was consumed, false if end-of-file was
///         reached before any digit (in which case `x` stays 0).
template <class T>
inline bool read(T &x) {
    x = 0;
    bool negative = false;
    bool any_digit = false;
    // getchar() returns an int so that EOF is distinguishable from every
    // character; storing it in a char would make the skip loop spin forever
    // on truncated input.
    int ch = std::getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') negative = true;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        any_digit = true;
        ch = std::getchar();
    }
    if (negative) x = -x;
    return any_digit;
}

/// Reads several integers in order; stops at the first one that fails.
///
/// @return true only if every argument was read successfully.
template <class T, class... T1>
inline bool read(T &x, T1 &...x1) {
    return read(x) && read(x1...);
}

int n, k, m;
int l[N], r[N], x[N];  // constraint i is (l[i], r[i], x[i]), 1-based
int cnt[N], f[N];
int q[N];

/// Reduces a value known to lie in [0, 2*mod) into [0, mod).
inline int adj(int x) { return x >= mod ? x - mod : x; }

/// True when `idx` is a valid subscript for the N-element work arrays.
inline bool in_range(long long idx) { return idx >= 0 && idx < N; }

/// Runs the per-bit counting DP over positions 0..n+1 and returns f[n+1]
/// modulo `mod`.
///
/// For every constraint whose value has bit `bit` set, the range
/// [l[i], r[i]] is marked through a difference array in `cnt`; any position
/// with a non-zero count gets f = 0. For every constraint whose value has the
/// bit clear, q[r[i]+1] records the largest such l[i]; when the scan reaches
/// position i, all states f[j] with j < q[i] are removed from the running
/// sum. Read as "f[i] = number of ways in which i is the most recent position
/// holding a 0 bit", with positions 0 and n+1 acting as sentinels.
///
/// Precondition: n + 1 < N and every l[i], r[i] + 1 is a valid array index
/// (main() checks this before calling).
inline int solve(int bit) {
    // Only indices 0..n+1 are ever read below, so clearing that prefix is
    // equivalent to clearing the whole arrays and avoids touching ~6 MB per
    // bit when n is small.
    const std::size_t bytes = (static_cast<std::size_t>(n) + 2) * sizeof(int);
    std::memset(cnt, 0, bytes);
    std::memset(q, 0, bytes);
    std::memset(f, 0, bytes);

    for (int i = 1; i <= m; i++) {
        if ((x[i] >> bit) & 1) {
            cnt[l[i]]++;
            cnt[r[i] + 1]--;
        } else {
            q[r[i] + 1] = std::max(q[r[i] + 1], l[i]);
        }
    }

    int left = 0;  // smallest index whose f value is still part of `sum`
    int sum = 1;   // running sum of f[left..i-1] modulo mod
    f[0] = 1;
    for (int i = 1; i <= n + 1; i++) {
        cnt[i] += cnt[i - 1];  // prefix-sum the difference array in place
        while (left < q[i]) {
            sum = adj(sum - f[left] + mod);
            f[left] = 0;
            left++;
        }
        f[i] = cnt[i] ? 0 : sum;
        sum = adj(sum + f[i]);
    }
    return f[n + 1];
}

/// Reads `n k m` followed by m triples `l r x`, multiplies the per-bit
/// results of solve() for bits 0..k-1 modulo `mod`, and prints the product.
///
/// Exits with status 1 (and a message on stderr) if the input is truncated or
/// would index outside the fixed-size arrays.
int main() {
    // Shifting an int by its full width or more is undefined, so k may not
    // exceed the width of int.
    const int max_bits = std::numeric_limits<int>::digits + 1;

    const bool header_ok = read(n, k, m) && n >= 0 && in_range(n + 1LL) &&
                           m >= 0 && in_range(m) && k <= max_bits;
    if (!header_ok) {
        std::fputs("invalid input\n", stderr);
        return 1;
    }

    for (int i = 1; i <= m; i++) {
        if (!read(l[i], r[i], x[i]) || !in_range(l[i]) ||
            !in_range(r[i] + 1LL)) {
            std::fputs("invalid input\n", stderr);
            return 1;
        }
    }

    int ans = 1;
    for (int i = 0; i < k; i++) ans = 1ll * ans * solve(i) % mod;
    std::printf("%d\n", ans);
    return 0;
}
|