NEC Programming Contest 2022 (AtCoder Beginner Contest 267) Ex. Odd Sum(NTT)

题目

给定长为n(n<=1e5)的序列a,第i个数ai(1<=ai<=10)

求从n个数中选出奇数个数,使得选出的这些数的和为m(m<=1e6)的方案数

答案对998244353取模

思路来源

夏老师

题解

其实感觉有点子集反演,钦定popcount的位数的意思

注意到ai很小,只有不超过10,

所以先构造10个数的多项式,多项式内部由组合数去计算

由于需要区分奇数个还是偶数个,

每种数拆成两个多项式,分别对应奇/偶的情况,

然后ntt优化卷积即可,

a和b两种数合并的时候,

计0为偶数个,1为奇数个,

那么,合并后的数c,满足:

c0=a0*b0+a1*b1

c1=a1*b0+a0*b1

这样每合并两个数需要卷积四次,合并10个数需要36次,

复杂度O(m*logm*36),但显然跑不满

一个优化方式是,

可以用0*1+1*0求得奇数次的方案,用(0+1)*(0+1)求得总的方案,

二者作差得到偶数次的方案,即:

(0+1)*(0+1)-0*1-1*0

这样两种数合并的时候只需要卷积三次

代码1

//#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<vector>
#include<cmath>
#include<algorithm>
#include<random>
using namespace std;
#define ll long long
#define ull unsigned ll
const int N = 1<<20, P = 998244353;
const int Primitive_root = 3;//先用Get_root求出来原根然后当const用 
struct Z{
    int x;
    Z(const int _x=0):x(_x){}
    Z operator +(const Z &r)const{ return x+r.x<P?x+r.x:x+r.x-P;}
    Z operator -(const Z &r)const{ return x<r.x?x-r.x+P:x-r.x;}
    Z operator -()const{ return x?P-x:0;}
    Z operator *(const Z &r)const{ return static_cast<ull>(x)*r.x%P;}
    Z operator +=(const Z &r){ return x=x+r.x<P?x+r.x:x+r.x-P, *this;}
    Z operator -=(const Z &r){ return x=x<r.x?x-r.x+P:x-r.x, *this;}
    Z operator *=(const Z &r){ return x=static_cast<ull>(x)*r.x%P, *this;}
    friend Z Pow(Z, int);
    pair<Z,Z> Mul(pair<Z,Z> x, pair<Z,Z> y, Z f)const{
        return make_pair(
            x.first*y.first+x.second*y.second*f,
            x.second*y.first+x.first*y.second
        );
    }
    Z Quadratic_residue()const{
        if(x<=1) return x;
        if(Pow((Z)x, (P-1)/2).x!=1) return -1;
        Z y, f;
        mt19937 rng(20030226);
        do y=rng()%(x-1)+1; while(Pow(f=y*y-x, (P-1)/2).x==1);
        pair<Z,Z> ans=make_pair(1, 0), t=make_pair(y, 1);
        for(int i=(P+1)/2; i; i>>=1, t=Mul(t, t, f)) if(i&1) ans=Mul(ans, t, f);
        return min(ans.first.x, P-ans.first.x);
    }
};
Z Pow(Z x, int y=P-2){
    Z ans=1;
    for(; y; y>>=1, x=x*x) if(y&1) ans=ans*x;
    return ans;
}
namespace Poly{
    Z w[N<<1];
    Z Inv[N];
    vector<Z> ans;
    vector<vector<Z>> p;
    ull F[N];
    int Get_root(){
		static int pr[N],cnt;
		int n=P-1,sz=(int)(sqrt(n)),root=-1;
		for(int i=2;i<=sz;++i){if(n%i==0)pr[cnt++]=i;while(n%i==0)n/=i;}
		if(n>1)pr[cnt++]=n;
		for(int i=1;i<P;++i){
			if(Pow((Z)i,P-1).x==1){
				bool fl=true;
				for(int j=0;j<cnt;++j){
					if(Pow(i,(P-1)/pr[j]).x==1){
						fl=false;break;
					}
				}
				if(fl){root=i;break;}
			}
		}
		return root;
	}
    void Init(){
    	//printf("root:%d
",Primitive_root=Get_root()); 先求出来原根然后当const用 
        for(int i=1; i<N; i<<=1){
            w[i]=1;
            Z t=Pow((Z)Primitive_root, (P-1)/i/2);
            for(int j=1; j<i; ++j) w[i+j]=w[i+j-1]*t;//这里w开N应该就可,忽略最后一步?
        }
        Inv[1]=1;
        for(int i=2; i<N; ++i) Inv[i]=Inv[P%i]*(P-P/i);
    }
    int Get(int x){ int n=1; while(n<=x) n<<=1; return n;}
    int Mod(int x){ return x<P?x:x-P;}
    void DFT(vector<Z> &f, int n){
        if((int)f.size()!=n) f.resize(n);
        for(int i=0, j=0; i<n; ++i){
            F[i]=f[j].x;
            for(int k=n>>1; (j^=k)<k; k>>=1);
        }
        if(n<=4){
            for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){
                Z *W=w+i;
                ull *F0=F+j, *F1=F+j+i;
                for(int k=j; k<j+i; ++k, ++W, ++F0, ++F1){
                    ull t=(*F1)*(W->x)%P;
                    (*F1)=*F0+P-t, (*F0)+=t;
                }
            }
        }
        else{
            for(int j=0; j<n; j+=2){
                int t=F[j+1];
                F[j+1]=Mod(F[j]+P-t), F[j]=Mod(F[j]+t);
            }
            for(int j=0; j<n; j+=4){
                int t0=F[j+2], t1=F[j+3]*w[3].x%P;
                F[j+2]=F[j]+P-t0, F[j]+=t0;
                F[j+3]=F[j+1]+P-t1, F[j+1]+=t1;
            }
            for(int i=4; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){
                Z *W=w+i;
                ull *F0=F+j, *F1=F+j+i;
                for(int k=j; k<j+i; k+=4, W+=4, F0+=4, F1+=4){
                    int t0=(W->x)**F1%P;
                    int t1=(W+1)->x**(F1+1)%P;
                    int t2=(W+2)->x**(F1+2)%P;
                    int t3=(W+3)->x**(F1+3)%P;
                    *F1=*F0+P-t0, *F0+=t0;
                    *(F1+1)=*(F0+1)+P-t1, *(F0+1)+=t1;
                    *(F1+2)=*(F0+2)+P-t2, *(F0+2)+=t2;
                    *(F1+3)=*(F0+3)+P-t3, *(F0+3)+=t3;
                }
            }
        }
        for(int i=0; i<n; ++i) f[i]=F[i]%P;
    }
    void IDFT(vector<Z> &f, int n){
        f.resize(n), reverse(f.begin()+1, f.end()), DFT(f, n);
        Z I=1;
        for(int i=1; i<n; i<<=1) I*=(P+1)/2;
        for(int i=0; i<n; ++i) f[i]*=I;
    }
    vector<Z> operator +(const vector<Z> &f, const vector<Z> &g){
        vector<Z> ans=f;
        ans.resize(max(f.size(), g.size()));
        for(int i=0; i<(int)g.size(); ++i) ans[i]+=g[i];
        return ans;
    }
    vector<Z> operator *(const vector<Z> &f, const vector<Z> &g){
        static vector<Z> F, G;
        F=f, G=g;
        int p=Get(f.size()+g.size()-2);
        DFT(F, p), DFT(G, p);
        for(int i=0; i<p; ++i) F[i]*=G[i];
        IDFT(F, p);
        return F.resize(f.size()+g.size()-1), F;
    }
    vector<Z> operator *(const vector<Z> &f, Z g){
        vector<Z> ans=f;
        for(Z &i:ans) i*=g;
        return ans;
    }
}
using namespace Poly;
Z fac[N],ifac[N];
vector<Z>a[11][2],tmp[2];
int n,m,v,cnt[11],c;
void init(int n){
	fac[0]=1;
    for(int i=1;i<=n;++i){
		fac[i]=fac[i-1]*i;
    } 
	ifac[n]=Pow(fac[n]);
    for(int i=n;i;--i){
    	ifac[i-1]=ifac[i]*i;
    }
}
Z C(int x,int y){
    if(x<0 || y<0 || x<y)return 0;
    return fac[x]*ifac[y]*ifac[x-y];
}
int main(){
    Init();
    init(N-1);
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i){
        scanf("%d",&v);
        cnt[v]++;
    }
    for(int i=1;i<=10;++i){
        if(cnt[i]){
            ++c;
            int w=cnt[i];
            a[c][w&1].resize(w*i+1);
            a[c][w&1^1].resize((w-1)*i+1);
            for(int j=0;j<=w;++j){
                //printf("i:%d w:%d j:%d c:%d sz0:%d sz1:%d
",i,w,j,C(w,j),a[c][0].size(),a[c][1].size());
                a[c][j&1][j*i]=C(w,j);
            }
        }
    }
    // for(int i=1;i<=c;++i){
    //     for(int j=0;j<2;++j){
    //         printf("i:%d j:%d
",i,j);
    //         for(auto &v:a[i][j]){
    //             printf("%d ",v);
    //         }
    //         puts("");
    //     }
    // }
    for(int i=2;i<=c;++i){
        tmp[0]=a[1][0]*a[i][0]+a[1][1]*a[i][1];
        tmp[1]=a[1][0]*a[i][1]+a[1][1]*a[i][0];
        tmp[0].swap(a[1][0]);
        tmp[1].swap(a[1][1]);
    }
    // for(int i=1;i<=1;++i){
    //     for(int j=0;j<2;++j){
    //         printf("i:%d j:%d
",i,j);
    //         for(auto &v:a[i][j]){
    //             printf("%d ",v);
    //         }
    //         puts("");
    //     }
    // }
    a[1][1].resize(m+1);
    printf("%d
",a[1][1][m].x);
    return 0;
}

代码2(小优化卷积次数)

少跑了100ms左右

//#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<vector>
#include<cmath>
#include<algorithm>
#include<random>
using namespace std;
#define ll long long
#define ull unsigned ll
const int N = 1<<20, P = 998244353;
const int Primitive_root = 3;//先用Get_root求出来原根然后当const用 
struct Z{
    int x;
    Z(const int _x=0):x(_x){}
    Z operator +(const Z &r)const{ return x+r.x<P?x+r.x:x+r.x-P;}
    Z operator -(const Z &r)const{ return x<r.x?x-r.x+P:x-r.x;}
    Z operator -()const{ return x?P-x:0;}
    Z operator *(const Z &r)const{ return static_cast<ull>(x)*r.x%P;}
    Z operator +=(const Z &r){ return x=x+r.x<P?x+r.x:x+r.x-P, *this;}
    Z operator -=(const Z &r){ return x=x<r.x?x-r.x+P:x-r.x, *this;}
    Z operator *=(const Z &r){ return x=static_cast<ull>(x)*r.x%P, *this;}
    friend Z Pow(Z, int);
    pair<Z,Z> Mul(pair<Z,Z> x, pair<Z,Z> y, Z f)const{
        return make_pair(
            x.first*y.first+x.second*y.second*f,
            x.second*y.first+x.first*y.second
        );
    }
    Z Quadratic_residue()const{
        if(x<=1) return x;
        if(Pow((Z)x, (P-1)/2).x!=1) return -1;
        Z y, f;
        mt19937 rng(20030226);
        do y=rng()%(x-1)+1; while(Pow(f=y*y-x, (P-1)/2).x==1);
        pair<Z,Z> ans=make_pair(1, 0), t=make_pair(y, 1);
        for(int i=(P+1)/2; i; i>>=1, t=Mul(t, t, f)) if(i&1) ans=Mul(ans, t, f);
        return min(ans.first.x, P-ans.first.x);
    }
};
Z Pow(Z x, int y=P-2){
    Z ans=1;
    for(; y; y>>=1, x=x*x) if(y&1) ans=ans*x;
    return ans;
}
namespace Poly{
    Z w[N<<1];
    Z Inv[N];
    vector<Z> ans;
    vector<vector<Z>> p;
    ull F[N];
    int Get_root(){
		static int pr[N],cnt;
		int n=P-1,sz=(int)(sqrt(n)),root=-1;
		for(int i=2;i<=sz;++i){if(n%i==0)pr[cnt++]=i;while(n%i==0)n/=i;}
		if(n>1)pr[cnt++]=n;
		for(int i=1;i<P;++i){
			if(Pow((Z)i,P-1).x==1){
				bool fl=true;
				for(int j=0;j<cnt;++j){
					if(Pow(i,(P-1)/pr[j]).x==1){
						fl=false;break;
					}
				}
				if(fl){root=i;break;}
			}
		}
		return root;
	}
    void Init(){
    	//printf("root:%d
",Primitive_root=Get_root()); 先求出来原根然后当const用 
        for(int i=1; i<N; i<<=1){
            w[i]=1;
            Z t=Pow((Z)Primitive_root, (P-1)/i/2);
            for(int j=1; j<i; ++j) w[i+j]=w[i+j-1]*t;//这里w开N应该就可,忽略最后一步?
        }
        Inv[1]=1;
        for(int i=2; i<N; ++i) Inv[i]=Inv[P%i]*(P-P/i);
    }
    int Get(int x){ int n=1; while(n<=x) n<<=1; return n;}
    int Mod(int x){ return x<P?x:x-P;}
    void DFT(vector<Z> &f, int n){
        if((int)f.size()!=n) f.resize(n);
        for(int i=0, j=0; i<n; ++i){
            F[i]=f[j].x;
            for(int k=n>>1; (j^=k)<k; k>>=1);
        }
        if(n<=4){
            for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){
                Z *W=w+i;
                ull *F0=F+j, *F1=F+j+i;
                for(int k=j; k<j+i; ++k, ++W, ++F0, ++F1){
                    ull t=(*F1)*(W->x)%P;
                    (*F1)=*F0+P-t, (*F0)+=t;
                }
            }
        }
        else{
            for(int j=0; j<n; j+=2){
                int t=F[j+1];
                F[j+1]=Mod(F[j]+P-t), F[j]=Mod(F[j]+t);
            }
            for(int j=0; j<n; j+=4){
                int t0=F[j+2], t1=F[j+3]*w[3].x%P;
                F[j+2]=F[j]+P-t0, F[j]+=t0;
                F[j+3]=F[j+1]+P-t1, F[j+1]+=t1;
            }
            for(int i=4; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){
                Z *W=w+i;
                ull *F0=F+j, *F1=F+j+i;
                for(int k=j; k<j+i; k+=4, W+=4, F0+=4, F1+=4){
                    int t0=(W->x)**F1%P;
                    int t1=(W+1)->x**(F1+1)%P;
                    int t2=(W+2)->x**(F1+2)%P;
                    int t3=(W+3)->x**(F1+3)%P;
                    *F1=*F0+P-t0, *F0+=t0;
                    *(F1+1)=*(F0+1)+P-t1, *(F0+1)+=t1;
                    *(F1+2)=*(F0+2)+P-t2, *(F0+2)+=t2;
                    *(F1+3)=*(F0+3)+P-t3, *(F0+3)+=t3;
                }
            }
        }
        for(int i=0; i<n; ++i) f[i]=F[i]%P;
    }
    void IDFT(vector<Z> &f, int n){
        f.resize(n), reverse(f.begin()+1, f.end()), DFT(f, n);
        Z I=1;
        for(int i=1; i<n; i<<=1) I*=(P+1)/2;
        for(int i=0; i<n; ++i) f[i]*=I;
    }
    vector<Z> operator +(const vector<Z> &f, const vector<Z> &g){
        vector<Z> ans=f;
        ans.resize(max(f.size(), g.size()));
        for(int i=0; i<(int)g.size(); ++i) ans[i]+=g[i];
        return ans;
    }
    vector<Z> operator -(const vector<Z> &f, const vector<Z> &g){
        vector<Z> ans=f;
        ans.resize(max(f.size(), g.size()));
        for(int i=0; i<(int)g.size(); ++i) ans[i]-=g[i];
        return ans;
    }
    vector<Z> operator *(const vector<Z> &f, const vector<Z> &g){
        static vector<Z> F, G;
        F=f, G=g;
        int p=Get(f.size()+g.size()-2);
        DFT(F, p), DFT(G, p);
        for(int i=0; i<p; ++i) F[i]*=G[i];
        IDFT(F, p);
        return F.resize(f.size()+g.size()-1), F;
    }
    vector<Z> operator *(const vector<Z> &f, Z g){
        vector<Z> ans=f;
        for(Z &i:ans) i*=g;
        return ans;
    }
}
using namespace Poly;
Z fac[N],ifac[N];
vector<Z>a[11][2],tmp[2];
int n,m,v,cnt[11],c;
void init(int n){
	fac[0]=1;
    for(int i=1;i<=n;++i){
		fac[i]=fac[i-1]*i;
    } 
	ifac[n]=Pow(fac[n]);
    for(int i=n;i;--i){
    	ifac[i-1]=ifac[i]*i;
    }
}
Z C(int x,int y){
    if(x<0 || y<0 || x<y)return 0;
    return fac[x]*ifac[y]*ifac[x-y];
}
int main(){
    Init();
    init(N-1);
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i){
        scanf("%d",&v);
        cnt[v]++;
    }
    for(int i=1;i<=10;++i){
        if(cnt[i]){
            ++c;
            int w=cnt[i];
            a[c][w&1].resize(w*i+1);
            a[c][w&1^1].resize((w-1)*i+1);
            for(int j=0;j<=w;++j){
                //printf("i:%d w:%d j:%d c:%d sz0:%d sz1:%d
",i,w,j,C(w,j),a[c][0].size(),a[c][1].size());
                a[c][j&1][j*i]=C(w,j);
            }
        }
    }
    // for(int i=1;i<=c;++i){
    //     for(int j=0;j<2;++j){
    //         printf("i:%d j:%d
",i,j);
    //         for(auto &v:a[i][j]){
    //             printf("%d ",v);
    //         }
    //         puts("");
    //     }
    // }
    for(int i=2;i<=c;++i){
        tmp[0]=(a[1][0]+a[1][1])*(a[i][0]+a[i][1]);
        tmp[1]=a[1][0]*a[i][1]+a[1][1]*a[i][0];
        tmp[0]=tmp[0]-tmp[1];
        tmp[0].swap(a[1][0]);
        tmp[1].swap(a[1][1]);
    }
    // for(int i=1;i<=1;++i){
    //     for(int j=0;j<2;++j){
    //         printf("i:%d j:%d
",i,j);
    //         for(auto &v:a[i][j]){
    //             printf("%d ",v);
    //         }
    //         puts("");
    //     }
    // }
    a[1][1].resize(m+1);
    printf("%d
",a[1][1][m].x);
    return 0;
}