模拟赛的 T2，多敲了两行代码，结果成功爆掉了～
写线段树合并的时候一定要注意：不能随意新开节点。
code:
#include <bits/stdc++.h>
#define N 100009
#define ll long long
#define setIO(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
using namespace std;
int n,edges;
int A[N],hd[N],to[N<<1],nex[N<<1],kk[N],rt[N],ans1[N];
ll val[N<<1];
ll ans2[N];
// Inserts the directed edge u -> v with weight c at the front of u's
// adjacency list (forward-star storage: hd = list head, nex = next edge id).
void add(int u,int v,int c)
{
    ++edges;
    to[edges]=v;
    val[edges]=c;
    nex[edges]=hd[u];   // link the previous head behind the new edge
    hd[u]=edges;
}
// Dynamically allocated segment tree over positions [l,r] with a lazy "add"
// tag and an in-place merge of two trees.
// Node 0 is the null node: it is never handed out by newnode(), so all of its
// fields stay zero and it can be read safely as an "empty subtree".
struct Segment_Tree
{
#define lson p[x].ls
#define rson p[x].rs
    int tot;   // index of the most recently allocated node (nodes are 1..tot)
    struct Node
    {
        int ls,rs;   // child indices, 0 = absent
        int size;    // number of occupied leaves in this subtree
        ll dis;      // leaf: value combined by merge(); reported by dfss()
        ll num;      // leaf: set to 1 by update(), summed by merge()
        ll rt;       // leaf: grows by d*num on every mark(x,d)
        ll mul;      // lazy amount not yet pushed down to the children
        ll maxx;     // maximum leaf dis inside this subtree
        // NOTE(review): dis/num/rt look like "sum of pairwise distances",
        // "vertex count" and "sum of distances to the current subtree root"
        // per key -- confirm against the problem statement.
    }
    // merge() reuses existing nodes, so only update() allocates: at most one
    // root-to-leaf path per call, i.e. <= 18 nodes while r-l+1 < N (< 2^17).
    // N*20 therefore covers up to N update() calls (dfs makes one per vertex).
    p[N*20];

    // Hands out the next unused node index.
    int newnode() { return ++tot; }

    // Applies a lazy increment d to node x: records it for the children and
    // adds d once per counted entry to rt.
    void mark(int x,ll d)
    {
        p[x].mul+=d;
        p[x].rt+=d*p[x].num;
    }

    // Pushes the pending tag of x to its existing children and clears it.
    // l and r are unused; they are kept so existing call sites stay valid.
    void pushdown(int x,int l,int r)
    {
        if(p[x].mul)
        {
            if(lson) mark(lson, p[x].mul);
            if(rson) mark(rson, p[x].mul);
            p[x].mul=0;
        }
    }

    // Merges tree v into tree u (both covering [l,r]) and returns the root of
    // the result. The result reuses u's nodes and may adopt subtrees of v, so
    // neither input tree may be used on its own after the call.
    int merge(int l,int r,int u,int v)
    {
        if(!u||!v) return u+v;   // at most one side exists: take it as is
        pushdown(u,l,r);
        pushdown(v,l,r);
        // dis must be combined first: it needs the old num/rt of both sides.
        p[u].dis+=p[v].dis+p[u].num*p[v].rt+p[u].rt*p[v].num;
        p[u].num+=p[v].num;
        p[u].rt+=p[v].rt;
        if(l==r)
        {
            if(p[u].num) p[u].size=1;
            p[u].maxx=p[u].dis;
            return u;
        }
        int mid=(l+r)>>1;
        p[u].ls=merge(l,mid,p[u].ls,p[v].ls);
        p[u].rs=merge(mid+1,r,p[u].rs,p[v].rs);
        p[u].size=p[p[u].ls].size+p[p[u].rs].size;
        p[u].maxx=std::max(p[p[u].ls].maxx, p[p[u].rs].maxx);
        return u;
    }

    // Returns the smallest position in [l,r] whose leaf dis equals the
    // subtree maximum p[x].maxx (ties are resolved towards the left).
    int solve(int l,int r,int x)
    {
        if(l==r) return l;
        int mid=(l+r)>>1;
        pushdown(x,l,r);
        // Go left only if the left child is non-empty and attains the maximum.
        if(p[lson].size && p[lson].maxx==p[x].maxx) return solve(l,mid,lson);
        return solve(mid+1,r,rson);
    }

    // Marks position pp as occupied, creating the path to it if needed.
    // The leaf gets size=1 and num=1 (assigned, not incremented).
    void update(int &x,int l,int r,int pp)
    {
        if(!x) x=newnode();
        if(l==r)
        {
            p[x].size=1;
            p[x].num=1;
            return;
        }
        pushdown(x, l, r);
        int mid=(l+r)>>1;
        if(pp<=mid) update(lson,l,mid,pp);
        else update(rson,mid+1,r,pp);
        p[x].maxx=std::max(p[lson].maxx, p[rson].maxx);
        p[x].size=p[lson].size+p[rson].size;
    }

    // Returns dis of the kth occupied leaf (1-based, by increasing position).
    // The caller is expected to ensure 1 <= kth <= p[x].size.
    ll dfss(int l,int r,int x,int kth)
    {
        if(l==r) return p[x].dis;
        int mid=(l+r)>>1;
        pushdown(x,l,r);
        int sz=p[lson].size;
        if(sz>=kth) return dfss(l,mid,lson,kth);
        return dfss(mid+1,r,rson,kth-sz);
    }
#undef lson
#undef rson
}seg;
// Post-order walk of the subtree of u.
//   ff : parent of u (0 for the root)
//   pp : weight of the edge between ff and u
// Builds rt[u] from u's own value and its children's trees, stores both
// answers for u, then shifts the tree by pp and merges it into rt[ff].
void dfs(int u,int ff,int pp)
{
    // Start from a tree that holds only u's own value.
    seg.update(rt[u],1,n,A[u]);
    for(int e=hd[u];e!=0;e=nex[e])
    {
        const int child=to[e];
        if(child!=ff) dfs(child,u,val[e]);   // the child merges itself into rt[u]
    }

    // ans2: dis of the kk[u]-th occupied position, or -1 if there are fewer.
    const bool enough=(seg.p[rt[u]].size>=kk[u]);
    ans2[u]=enough ? seg.dfss(1,n,rt[u],kk[u]) : -1;
    // ans1: smallest position whose dis is the maximum in this tree.
    ans1[u]=seg.solve(1,n,rt[u]);

    // Apply the edge weight before handing the tree over to the parent.
    seg.mark(rt[u],1ll*pp);
    rt[ff]=seg.merge(1,n,rt[u],rt[ff]);
}
int main()
{
int i,j;
// setIO("input");
scanf("%d",&n);
for(i=1;i<n;++i)
{
int u,v,c;
scanf("%d%d%d",&u,&v,&c), add(u,v,c), add(v,u,c);
}
for(i=1;i<=n;++i) scanf("%d",&A[i]);
for(i=1;i<=n;++i) scanf("%d",&kk[i]);
dfs(1,0,0);
for(i=1;i<=n;++i) printf("%d %lld\n",ans1[i],ans2[i]);
return 0;
}
知识兔