AtCoder Beginner Contest 337 G. Tree Inversion(dfs序+树状数组+树上差分 补写法)

题目

n(n<=2e5)个点的无根树,

定义f(u)为满足v<w且w在u到v的路径上的(v,w)的的数量,允许w和u重合或者w和v重合

输出f(1),f(2),...,f(n)的值

思路来源

乱搞ac

题解

赛中直接写了启发式合并过了,

属于是忽略了dfs序可以前缀作差的性质

如果按dfs序建主席树的话,感觉可作差的性质会直观很多

其实这个思路之前做17北航多校hdu6035的时候有用到过,

2017 Chinese Multi-University Training 1(C(树形dp)+F(置换群循环节)+H(nth_element)+I(仙人掌第k大生成树)+L(组合数学+dfs))_102253i - i curse myself-CSDN博客

进入子树时查询一个值,离开子树时查询一个值,

二者作差即为子树内的增量

其实就是枚举w,考虑w对哪些位置有贡献,w是子树树根枚举到的

子树内<w的值x,x到w的路径上直连儿子是v,那么在v这棵子树以外打标记

w子树外<w的值x,那么在w这棵子树以内打标记

打标记即对应区间加,实际打差分标记,左加又减

代码1(dfs序+树状数组)

#include<bits/stdc++.h>
#include<vector>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
#define fi first
#define se second
#define pb push_back
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define scll(a) scanf("%lld",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d
",a)
#define ptlle(a) printf("%lld
",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
const int N=2e5+10;
int n,u,v,st[N],ed[N],dfn[N],c;
ll sum[N],ans[N];
vector<int>e[N];
struct BitPre{ // 求前缀和(可改为max等)
	int n,tr[N];
	void init(int _n){
		n=_n;
		memset(tr,0,(n+1)*sizeof(*tr));
	}
	void add(int x,int v){
		for(int i=x;i<=n;i+=i&-i)
		tr[i]+=v;
	}
	int sum(int x){
		int ans=0; 
		for(int i=x;i;i-=i&-i)
		ans+=tr[i];
		return ans;
	}
}tr;
void dfs(int u,int fa){
    tr.add(u,1);
    st[u]=++c;
    dfn[c]=u;
    int y1=tr.sum(u-1);
    for(auto &v:e[u]){
        if(v==fa)continue;
        int x1=tr.sum(u-1);
        dfs(v,u);
        int x2=tr.sum(u-1)-x1;
        sum[1]+=x2;
        sum[st[v]]-=x2;
        sum[ed[v]+1]+=x2;
    }
    ed[u]=c;
    int y2=tr.sum(u-1)-y1;
    int oth=u-1-y2;
    sum[st[u]]+=oth;
    sum[ed[u]+1]-=oth;
}
int main(){
    sci(n);
    tr.init(n);
    rep(i,2,n){
        sci(u),sci(v);
        e[u].pb(v);
        e[v].pb(u);
    }
    dfs(1,0);
    rep(i,1,n){
        sum[i]+=sum[i-1];
        ans[dfn[i]]=sum[i];
    }
    rep(i,1,n){
        printf("%lld%c",ans[i]," 
"[i==n]);
    }
    return 0;
}

代码2(启发式合并)

#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
#define fi first
#define se second
#define pb push_back
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define scll(a) scanf("%lld",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d
",a)
#define ptlle(a) printf("%lld
",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
const int N=2e5+10;
int n,u,v,st[N],ed[N],dfn[N],sz[N],c;
ll sum[N],ans[N];
vector<int>e[N];
unordered_map<int,int>now;
struct BitPre{ // 求前缀和(可改为max等)
	int n,tr[N];
	void init(int _n){
		n=_n;
		memset(tr,0,(n+1)*sizeof(*tr));
	}
	void add(int x,int v){
		for(int i=x;i<=n;i+=i&-i)
		tr[i]+=v;
	}
	int sum(int x){
		int ans=0; 
		for(int i=x;i;i-=i&-i)
		ans+=tr[i];
		return ans;
	}
}tr;
void dfs(int u,int fa){
    st[u]=++c;
    sz[u]=1;
    dfn[c]=u;
    for(auto &v:e[u]){
        if(v==fa)continue;
        dfs(v,u);
        sz[u]+=sz[v];
    }
    ed[u]=c;
}
void dfs(int u,int fa,bool keep){
    int mx=-1,son=-1;
    for(auto v:e[u]){
        if(v!=fa&&sz[v]>mx)
            mx=sz[v],son=v;
    }
    for(auto &v:e[u]){
        if(v!=fa&&v!=son){
            dfs(v,u,0);
        }
    }
    if(son!=-1){
        dfs(son,u,1);
        int z=tr.sum(u);
        sum[1]+=z;
        sum[st[son]]-=z;
        sum[ed[son]+1]+=z;
    }
    for(auto &v:e[u]){
        if(v!=fa&&v!=son){
            int z=0;
            for(int i=st[v];i<=ed[v];i++){
                int x=dfn[i];
                z+=(x<u);
                now[x]=1;
                tr.add(x,1);
            }
            sum[1]+=z;
            sum[st[v]]-=z;
            sum[ed[v]+1]+=z;
            // for(int i=st[v];i<=ed[v];i++){
            //     now[x]=1;
            //     tr.add(x,1);
            // }
        }
    }
    int z=tr.sum(u),oth=u-1-z;
    sum[st[u]]+=oth;
    sum[ed[u]+1]-=oth;
    // sum[1]+=z;
	// sum[st[u]+1]-=z;
	// sum[ed[u]+1]+=z; 
    now[u]=1;
    tr.add(u,1);
    if(keep==0){
        for(auto &x:now){
            tr.add(x.fi,-1);
        }
    	now.clear();
    }
}
int main(){
    sci(n);
    tr.init(n);
    rep(i,2,n){
        sci(u),sci(v);
        e[u].pb(v);
        e[v].pb(u);
    }
    dfs(1,0);
    dfs(1,0,0);
    rep(i,1,n){
        sum[i]+=sum[i-1];
        ans[dfn[i]]=sum[i];
    }
    rep(i,1,n){
        printf("%lld%c",ans[i]," 
"[i==n]);
    }
    return 0;
}