题目
给定m,n(m<=n<=5e3),
求大小为k的多重集合,满足元素和为n,
且每种数在集合中出现的次数都小于等于m的集合数有多少个
答案对998244353取模
思路来源
官方题解
「解题报告」[ABC221H] Count Multiset - K8He - 洛谷博客
Solution-ABC221H - yllcm 的博客 - 洛谷博客
【AtCoder思维训练】ABC221H Count Multiset - QAQ - 洛谷博客
题解1
整体来说,如果没有每个次数<=m的限制,就是分拆数
1. 把多重集合转成不下降序列(单增序列),
每个序列统计一次(A1,A2,...,Ak)(A1<=A2<=...<=Ak)
2. 把不下降序列转成差分数组,
令B[1]=A[1],B[i]=A[i]-A[i-1],
对于差分数组,需要满足以下三个条件:
①
②B数组中不存在连续的m个0
③
3. 发现①做dp的时候是有后效性的,
与k相关, 第k+1次的时候需要加上前k个的和
考虑对差分数组反转,即令i=k+1-i
反转后的差分数组B,需要满足以下三个条件:
①
②B数组中不存在连续的m个0
③
f[i][j]表示当前选了i个数,总和为j的方案数
①一种方式是,对反转的差分序列后面新增一个0,
如:
原序列2 2 3,差分序列2 0 1,反转差分序列1 0 2,
此时给反转差分序列后面加一个0,得到1 0 2 0,
对应差分序列0 2 0 1,原序列0 2 2 3,即原序列前面加一个0
即f[i][j]从f[i-1][j]转移而来
②另一种方式是,对反转的差分序列的最后一个数加1,
如:
原序列2 2 3,差分序列2 0 1,反转差分序列1 0 2,
此时给反转差分序列最后一个数加1,得到1 0 3,
对应差分序列3 0 1,原序列3 3 4,即原序列整体加1
即f[i][j]从f[i][j-i]转移而来
考虑怎么加上连续最多m个0的限制,
设g[i][j]表示当前填了i个数,总和为j,序列里不含0的方案数,
给整体加1过后的序列,即不包含0,有
f的转移,要么是对f数组整体加1,
要么是钦定0的个数,从一段没有0的g数组转移过来
由于序列里没有0,最后g[i][n]即为所求
当然,可以进一步化简,
因为最后是求g数组,可以上式代入下式联立消掉f,有
,
就与官方题解中的代码一致了,前缀和优化一下,复杂度
题解2
接题解1,反转后的差分数组B,需要满足以下三个条件:
①
②B数组中不存在连续的m个0
③
直接g[i][j]表示当前选了i个数,,最后一个数即b[i]>0的方案数
考虑暴力转移,
从1到m,枚举最后一段0的连续段长度,
也就是枚举上一个非0的位置x,再枚举b[i]选择的数为w,有:
对的第一维,也就是g[x]这一维维护前缀和,
即可实现转移,复杂度
题解3
考虑直接对原序列做dp,
f[i][j]表示前i个数和为j的方案数
如:原序列1 1 2,
①每次要么新增一个1,转移到1 1 1 2,f[i][j]从f[i-1][j-1]转移
②要么令所有数都+1,使得所有数都大于等于2,转移到2 2 3,f[i][j]从f[i][j-i]转移
但是,第一种转移新增了一个1,可能会导致恰出现连续m+1个1的情况,减掉这种情况即可
出现这种情况时,前m+1个数字为1,且第m+2个数为>=2的值,
只需全局减1,即可删掉m+1个1,并且使得第m+2个数的值>=1,也就对应了f[i-(m+1)][j-i]
有
复杂度
题解4
数形结合,
如果对原序列dp,如下图所示,有三条限制,
①
②不存在超过m个xi相同
③
按照箭头视角去看这个图,
也就是先顺时针旋转90度,再翻转,
新的序列仍然有三条限制,
①
②
③
发现限制2更强了,所以可以对新序列dp,
dp[i][j]表示最后一列高为i,柱状图面积总和为j的方案数,
枚举上一列高为x,需要满足x∈[i-m,i],有:
惊奇地发现,这和题解1得到的转移式子一模一样
复杂度
代码1、代码4 O(n^2)
#include<iostream> using namespace std; const int N=5e3+10,mod=998244353; int n,m,dp[N][N],sum[N][N]; void add(int &x,int y){ x=(x+y)%mod; } int main(){ scanf("%d%d",&n,&m); dp[0][0]=sum[0][0]=1; for(int i=1;i<=n;++i){ for(int j=0;j<=n;++j){ if(j>=i){ dp[i][j]=sum[i][j-i]; if(i-m-1>=0){ add(dp[i][j],mod-sum[i-m-1][j-i]); } } sum[i][j]=(sum[i-1][j]+dp[i][j])%mod; } } for(int i=1;i<=n;++i){ printf("%d ",dp[i][n]); } return 0; }
代码2 O(n^2logn)
#include<iostream> using namespace std; const int N=5e3+10,mod=998244353; int n,m,g[N][N],sum[N][N]; void add(int &x,int y){ x=(x+y)%mod; } int main(){ scanf("%d%d",&n,&m); g[0][0]=sum[0][0]=1; for(int i=1;i<=n;++i){ for(int j=0;j<=n;++j){ for(int w=1;w*i<=j;++w){ add(g[i][j],sum[i-1][j-w*i]); if(i-m-1>=0)add(g[i][j],mod-sum[i-m-1][j-w*i]); } sum[i][j]=(sum[i-1][j]+g[i][j])%mod; } } for(int i=1;i<=n;++i){ printf("%d ",g[i][n]); } return 0; }
代码3 O(n^2)
#include<iostream> using namespace std; const int N=5e3+10,mod=998244353; int n,m,dp[N][N]; void add(int &x,int y){ x=(x+y)%mod; } int main(){ scanf("%d%d",&n,&m); dp[0][0]=1; for(int i=1;i<=n;++i){ for(int j=1;j<=n;++j){ dp[i][j]=dp[i-1][j-1]; if(j-i>=0)add(dp[i][j],dp[i][j-i]); if(i>=m+1 && j-i>=0)add(dp[i][j],mod-dp[i-(m+1)][j-i]); } } for(int i=1;i<=n;++i){ printf("%d ",dp[i][n]); } return 0; }
代码5 O(n^3)
自己乱搞了两个复杂度并不正确的做法,也贴在这里好了
这个是考虑容斥减掉不合法的答案
#include<iostream> using namespace std; const int N=5e3+10,mod=998244353; typedef long long ll; int n,m,dp[N][N],sum[N];//dp[i][j]选了i个和为j方案数 void add(int &x,int y){x=(x+y)%mod;} int main(){ scanf("%d%d",&n,&m); dp[0][0]=1; for(int l=1;l<=n;++l){ for(int i=1;i<=n;++i){ for(int j=l;j<=n;++j){ add(dp[i][j],dp[i-1][j-l]); /* for(int k=1;k<=j;++k){ add(dp[i][j],dp[i-1][j-k]); } */ } } for(int i=n;i>=m+1;--i){ for(int j=n;j-l*(m+1)>=0;--j){ add(dp[i][j],mod-dp[i-(m+1)][j-l*(m+1)]); } } } // for(int i=1;i<=n;++i){ // for(int j=1;j<=n;++j){ // printf("i:%d j:%d dp:%d ",i,j,dp[i][j]); // } // } for(int i=1;i<=n;++i){ printf("%d ",dp[i][n]); } return 0; }
代码6 O(n^3logn)
这个是纯纯暴力
#include<iostream> using namespace std; const int N=5e3+10,mod=998244353; typedef long long ll; int n,m,dp[N][N];//dp[i][j]选了i个和为j方案数 void add(int &x,int y){x=(x+y)%mod;} int main(){ scanf("%d%d",&n,&m); dp[0][0]=1; for(int i=1;i<=n;++i){ for(int j=n;j>=i;--j){ for(int k=1;k<=m;++k){ if(j-k*i<0)break; for(int l=n;l>=k;--l){ add(dp[l][j],dp[l-k][j-k*i]); } } } } for(int i=1;i<=n;++i){ printf("%d ",dp[i][n]); } return 0; }