题目
给定长为n(n<=1e5)的序列a,第i个数ai(1<=ai<=10)
求从n个数中选出奇数个数,使得选出的这些数的和为m(m<=1e6)的方案数
答案对998244353取模
思路来源
夏老师
题解
其实感觉有点子集反演,钦定popcount的位数的意思
注意到ai很小,只有不超过10,
所以先构造10个数的多项式,多项式内部由组合数去计算
由于需要区分奇数个还是偶数个,
每种数拆成两个多项式,分别对应奇/偶的情况,
然后ntt优化卷积即可,
a和b两种数合并的时候,
计0为偶数个,1为奇数个,
那么,合并后的数c,满足:
c0=a0*b0+a1*b1
c1=a1*b0+a0*b1
这样每合并两个数需要卷积四次,合并10个数需要36次,
复杂度,但显然跑不满
一个优化方式是,
可以用0*1+1*0求得奇数次的方案,用(0+1)*(0+1)求得总的方案,
二者作差得到偶数次的方案,即:
这样两种数合并的时候只需要卷积三次
代码1
//#include<bits/stdc++.h> #include<iostream> #include<cstdio> #include<vector> #include<cmath> #include<algorithm> #include<random> using namespace std; #define ll long long #define ull unsigned ll const int N = 1<<20, P = 998244353; const int Primitive_root = 3;//先用Get_root求出来原根然后当const用 struct Z{ int x; Z(const int _x=0):x(_x){} Z operator +(const Z &r)const{ return x+r.x<P?x+r.x:x+r.x-P;} Z operator -(const Z &r)const{ return x<r.x?x-r.x+P:x-r.x;} Z operator -()const{ return x?P-x:0;} Z operator *(const Z &r)const{ return static_cast<ull>(x)*r.x%P;} Z operator +=(const Z &r){ return x=x+r.x<P?x+r.x:x+r.x-P, *this;} Z operator -=(const Z &r){ return x=x<r.x?x-r.x+P:x-r.x, *this;} Z operator *=(const Z &r){ return x=static_cast<ull>(x)*r.x%P, *this;} friend Z Pow(Z, int); pair<Z,Z> Mul(pair<Z,Z> x, pair<Z,Z> y, Z f)const{ return make_pair( x.first*y.first+x.second*y.second*f, x.second*y.first+x.first*y.second ); } Z Quadratic_residue()const{ if(x<=1) return x; if(Pow((Z)x, (P-1)/2).x!=1) return -1; Z y, f; mt19937 rng(20030226); do y=rng()%(x-1)+1; while(Pow(f=y*y-x, (P-1)/2).x==1); pair<Z,Z> ans=make_pair(1, 0), t=make_pair(y, 1); for(int i=(P+1)/2; i; i>>=1, t=Mul(t, t, f)) if(i&1) ans=Mul(ans, t, f); return min(ans.first.x, P-ans.first.x); } }; Z Pow(Z x, int y=P-2){ Z ans=1; for(; y; y>>=1, x=x*x) if(y&1) ans=ans*x; return ans; } namespace Poly{ Z w[N<<1]; Z Inv[N]; vector<Z> ans; vector<vector<Z>> p; ull F[N]; int Get_root(){ static int pr[N],cnt; int n=P-1,sz=(int)(sqrt(n)),root=-1; for(int i=2;i<=sz;++i){if(n%i==0)pr[cnt++]=i;while(n%i==0)n/=i;} if(n>1)pr[cnt++]=n; for(int i=1;i<P;++i){ if(Pow((Z)i,P-1).x==1){ bool fl=true; for(int j=0;j<cnt;++j){ if(Pow(i,(P-1)/pr[j]).x==1){ fl=false;break; } } if(fl){root=i;break;} } } return root; } void Init(){ //printf("root:%d ",Primitive_root=Get_root()); 先求出来原根然后当const用 for(int i=1; i<N; i<<=1){ w[i]=1; Z t=Pow((Z)Primitive_root, (P-1)/i/2); for(int j=1; j<i; ++j) w[i+j]=w[i+j-1]*t;//这里w开N应该就可,忽略最后一步? } Inv[1]=1; for(int i=2; i<N; ++i) Inv[i]=Inv[P%i]*(P-P/i); } int Get(int x){ int n=1; while(n<=x) n<<=1; return n;} int Mod(int x){ return x<P?x:x-P;} void DFT(vector<Z> &f, int n){ if((int)f.size()!=n) f.resize(n); for(int i=0, j=0; i<n; ++i){ F[i]=f[j].x; for(int k=n>>1; (j^=k)<k; k>>=1); } if(n<=4){ for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){ Z *W=w+i; ull *F0=F+j, *F1=F+j+i; for(int k=j; k<j+i; ++k, ++W, ++F0, ++F1){ ull t=(*F1)*(W->x)%P; (*F1)=*F0+P-t, (*F0)+=t; } } } else{ for(int j=0; j<n; j+=2){ int t=F[j+1]; F[j+1]=Mod(F[j]+P-t), F[j]=Mod(F[j]+t); } for(int j=0; j<n; j+=4){ int t0=F[j+2], t1=F[j+3]*w[3].x%P; F[j+2]=F[j]+P-t0, F[j]+=t0; F[j+3]=F[j+1]+P-t1, F[j+1]+=t1; } for(int i=4; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){ Z *W=w+i; ull *F0=F+j, *F1=F+j+i; for(int k=j; k<j+i; k+=4, W+=4, F0+=4, F1+=4){ int t0=(W->x)**F1%P; int t1=(W+1)->x**(F1+1)%P; int t2=(W+2)->x**(F1+2)%P; int t3=(W+3)->x**(F1+3)%P; *F1=*F0+P-t0, *F0+=t0; *(F1+1)=*(F0+1)+P-t1, *(F0+1)+=t1; *(F1+2)=*(F0+2)+P-t2, *(F0+2)+=t2; *(F1+3)=*(F0+3)+P-t3, *(F0+3)+=t3; } } } for(int i=0; i<n; ++i) f[i]=F[i]%P; } void IDFT(vector<Z> &f, int n){ f.resize(n), reverse(f.begin()+1, f.end()), DFT(f, n); Z I=1; for(int i=1; i<n; i<<=1) I*=(P+1)/2; for(int i=0; i<n; ++i) f[i]*=I; } vector<Z> operator +(const vector<Z> &f, const vector<Z> &g){ vector<Z> ans=f; ans.resize(max(f.size(), g.size())); for(int i=0; i<(int)g.size(); ++i) ans[i]+=g[i]; return ans; } vector<Z> operator *(const vector<Z> &f, const vector<Z> &g){ static vector<Z> F, G; F=f, G=g; int p=Get(f.size()+g.size()-2); DFT(F, p), DFT(G, p); for(int i=0; i<p; ++i) F[i]*=G[i]; IDFT(F, p); return F.resize(f.size()+g.size()-1), F; } vector<Z> operator *(const vector<Z> &f, Z g){ vector<Z> ans=f; for(Z &i:ans) i*=g; return ans; } } using namespace Poly; Z fac[N],ifac[N]; vector<Z>a[11][2],tmp[2]; int n,m,v,cnt[11],c; void init(int n){ fac[0]=1; for(int i=1;i<=n;++i){ fac[i]=fac[i-1]*i; } ifac[n]=Pow(fac[n]); for(int i=n;i;--i){ ifac[i-1]=ifac[i]*i; } } Z C(int x,int y){ if(x<0 || y<0 || x<y)return 0; return fac[x]*ifac[y]*ifac[x-y]; } int main(){ Init(); init(N-1); scanf("%d%d",&n,&m); for(int i=1;i<=n;++i){ scanf("%d",&v); cnt[v]++; } for(int i=1;i<=10;++i){ if(cnt[i]){ ++c; int w=cnt[i]; a[c][w&1].resize(w*i+1); a[c][w&1^1].resize((w-1)*i+1); for(int j=0;j<=w;++j){ //printf("i:%d w:%d j:%d c:%d sz0:%d sz1:%d ",i,w,j,C(w,j),a[c][0].size(),a[c][1].size()); a[c][j&1][j*i]=C(w,j); } } } // for(int i=1;i<=c;++i){ // for(int j=0;j<2;++j){ // printf("i:%d j:%d ",i,j); // for(auto &v:a[i][j]){ // printf("%d ",v); // } // puts(""); // } // } for(int i=2;i<=c;++i){ tmp[0]=a[1][0]*a[i][0]+a[1][1]*a[i][1]; tmp[1]=a[1][0]*a[i][1]+a[1][1]*a[i][0]; tmp[0].swap(a[1][0]); tmp[1].swap(a[1][1]); } // for(int i=1;i<=1;++i){ // for(int j=0;j<2;++j){ // printf("i:%d j:%d ",i,j); // for(auto &v:a[i][j]){ // printf("%d ",v); // } // puts(""); // } // } a[1][1].resize(m+1); printf("%d ",a[1][1][m].x); return 0; }
代码2(小优化卷积次数)
少跑了100ms左右
//#include<bits/stdc++.h> #include<iostream> #include<cstdio> #include<vector> #include<cmath> #include<algorithm> #include<random> using namespace std; #define ll long long #define ull unsigned ll const int N = 1<<20, P = 998244353; const int Primitive_root = 3;//先用Get_root求出来原根然后当const用 struct Z{ int x; Z(const int _x=0):x(_x){} Z operator +(const Z &r)const{ return x+r.x<P?x+r.x:x+r.x-P;} Z operator -(const Z &r)const{ return x<r.x?x-r.x+P:x-r.x;} Z operator -()const{ return x?P-x:0;} Z operator *(const Z &r)const{ return static_cast<ull>(x)*r.x%P;} Z operator +=(const Z &r){ return x=x+r.x<P?x+r.x:x+r.x-P, *this;} Z operator -=(const Z &r){ return x=x<r.x?x-r.x+P:x-r.x, *this;} Z operator *=(const Z &r){ return x=static_cast<ull>(x)*r.x%P, *this;} friend Z Pow(Z, int); pair<Z,Z> Mul(pair<Z,Z> x, pair<Z,Z> y, Z f)const{ return make_pair( x.first*y.first+x.second*y.second*f, x.second*y.first+x.first*y.second ); } Z Quadratic_residue()const{ if(x<=1) return x; if(Pow((Z)x, (P-1)/2).x!=1) return -1; Z y, f; mt19937 rng(20030226); do y=rng()%(x-1)+1; while(Pow(f=y*y-x, (P-1)/2).x==1); pair<Z,Z> ans=make_pair(1, 0), t=make_pair(y, 1); for(int i=(P+1)/2; i; i>>=1, t=Mul(t, t, f)) if(i&1) ans=Mul(ans, t, f); return min(ans.first.x, P-ans.first.x); } }; Z Pow(Z x, int y=P-2){ Z ans=1; for(; y; y>>=1, x=x*x) if(y&1) ans=ans*x; return ans; } namespace Poly{ Z w[N<<1]; Z Inv[N]; vector<Z> ans; vector<vector<Z>> p; ull F[N]; int Get_root(){ static int pr[N],cnt; int n=P-1,sz=(int)(sqrt(n)),root=-1; for(int i=2;i<=sz;++i){if(n%i==0)pr[cnt++]=i;while(n%i==0)n/=i;} if(n>1)pr[cnt++]=n; for(int i=1;i<P;++i){ if(Pow((Z)i,P-1).x==1){ bool fl=true; for(int j=0;j<cnt;++j){ if(Pow(i,(P-1)/pr[j]).x==1){ fl=false;break; } } if(fl){root=i;break;} } } return root; } void Init(){ //printf("root:%d ",Primitive_root=Get_root()); 先求出来原根然后当const用 for(int i=1; i<N; i<<=1){ w[i]=1; Z t=Pow((Z)Primitive_root, (P-1)/i/2); for(int j=1; j<i; ++j) w[i+j]=w[i+j-1]*t;//这里w开N应该就可,忽略最后一步? } Inv[1]=1; for(int i=2; i<N; ++i) Inv[i]=Inv[P%i]*(P-P/i); } int Get(int x){ int n=1; while(n<=x) n<<=1; return n;} int Mod(int x){ return x<P?x:x-P;} void DFT(vector<Z> &f, int n){ if((int)f.size()!=n) f.resize(n); for(int i=0, j=0; i<n; ++i){ F[i]=f[j].x; for(int k=n>>1; (j^=k)<k; k>>=1); } if(n<=4){ for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){ Z *W=w+i; ull *F0=F+j, *F1=F+j+i; for(int k=j; k<j+i; ++k, ++W, ++F0, ++F1){ ull t=(*F1)*(W->x)%P; (*F1)=*F0+P-t, (*F0)+=t; } } } else{ for(int j=0; j<n; j+=2){ int t=F[j+1]; F[j+1]=Mod(F[j]+P-t), F[j]=Mod(F[j]+t); } for(int j=0; j<n; j+=4){ int t0=F[j+2], t1=F[j+3]*w[3].x%P; F[j+2]=F[j]+P-t0, F[j]+=t0; F[j+3]=F[j+1]+P-t1, F[j+1]+=t1; } for(int i=4; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){ Z *W=w+i; ull *F0=F+j, *F1=F+j+i; for(int k=j; k<j+i; k+=4, W+=4, F0+=4, F1+=4){ int t0=(W->x)**F1%P; int t1=(W+1)->x**(F1+1)%P; int t2=(W+2)->x**(F1+2)%P; int t3=(W+3)->x**(F1+3)%P; *F1=*F0+P-t0, *F0+=t0; *(F1+1)=*(F0+1)+P-t1, *(F0+1)+=t1; *(F1+2)=*(F0+2)+P-t2, *(F0+2)+=t2; *(F1+3)=*(F0+3)+P-t3, *(F0+3)+=t3; } } } for(int i=0; i<n; ++i) f[i]=F[i]%P; } void IDFT(vector<Z> &f, int n){ f.resize(n), reverse(f.begin()+1, f.end()), DFT(f, n); Z I=1; for(int i=1; i<n; i<<=1) I*=(P+1)/2; for(int i=0; i<n; ++i) f[i]*=I; } vector<Z> operator +(const vector<Z> &f, const vector<Z> &g){ vector<Z> ans=f; ans.resize(max(f.size(), g.size())); for(int i=0; i<(int)g.size(); ++i) ans[i]+=g[i]; return ans; } vector<Z> operator -(const vector<Z> &f, const vector<Z> &g){ vector<Z> ans=f; ans.resize(max(f.size(), g.size())); for(int i=0; i<(int)g.size(); ++i) ans[i]-=g[i]; return ans; } vector<Z> operator *(const vector<Z> &f, const vector<Z> &g){ static vector<Z> F, G; F=f, G=g; int p=Get(f.size()+g.size()-2); DFT(F, p), DFT(G, p); for(int i=0; i<p; ++i) F[i]*=G[i]; IDFT(F, p); return F.resize(f.size()+g.size()-1), F; } vector<Z> operator *(const vector<Z> &f, Z g){ vector<Z> ans=f; for(Z &i:ans) i*=g; return ans; } } using namespace Poly; Z fac[N],ifac[N]; vector<Z>a[11][2],tmp[2]; int n,m,v,cnt[11],c; void init(int n){ fac[0]=1; for(int i=1;i<=n;++i){ fac[i]=fac[i-1]*i; } ifac[n]=Pow(fac[n]); for(int i=n;i;--i){ ifac[i-1]=ifac[i]*i; } } Z C(int x,int y){ if(x<0 || y<0 || x<y)return 0; return fac[x]*ifac[y]*ifac[x-y]; } int main(){ Init(); init(N-1); scanf("%d%d",&n,&m); for(int i=1;i<=n;++i){ scanf("%d",&v); cnt[v]++; } for(int i=1;i<=10;++i){ if(cnt[i]){ ++c; int w=cnt[i]; a[c][w&1].resize(w*i+1); a[c][w&1^1].resize((w-1)*i+1); for(int j=0;j<=w;++j){ //printf("i:%d w:%d j:%d c:%d sz0:%d sz1:%d ",i,w,j,C(w,j),a[c][0].size(),a[c][1].size()); a[c][j&1][j*i]=C(w,j); } } } // for(int i=1;i<=c;++i){ // for(int j=0;j<2;++j){ // printf("i:%d j:%d ",i,j); // for(auto &v:a[i][j]){ // printf("%d ",v); // } // puts(""); // } // } for(int i=2;i<=c;++i){ tmp[0]=(a[1][0]+a[1][1])*(a[i][0]+a[i][1]); tmp[1]=a[1][0]*a[i][1]+a[1][1]*a[i][0]; tmp[0]=tmp[0]-tmp[1]; tmp[0].swap(a[1][0]); tmp[1].swap(a[1][1]); } // for(int i=1;i<=1;++i){ // for(int j=0;j<2;++j){ // printf("i:%d j:%d ",i,j); // for(auto &v:a[i][j]){ // printf("%d ",v); // } // puts(""); // } // } a[1][1].resize(m+1); printf("%d ",a[1][1][m].x); return 0; }