题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=3252
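显而易见，每次都应该选当前剩余价值最高的那条从根到叶子的链。
选中之后，链上每个之前没有被选过的点，都要让它子树里所有叶子结点对应的链值减去它的价值（这部分价值已经计入答案，之后不能再重复获得）。
按DFS序给叶子编号后，每棵子树内的叶子恰好是一段连续区间，因此用线段树维护“区间加”和“全局最大值（及其位置）”即可。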
代码如下：
// BZOJ 3252: choose at most m root-to-leaf chains of a tree rooted at 1 so
// that the total weight of the distinct nodes they cover is maximal.
//
// Greedy: repeatedly take the leaf whose root path currently has the largest
// remaining weight. Every node on that path that was not taken before is
// marked, and its weight is subtracted from every leaf in its subtree, because
// later chains through those leaves no longer gain it. Leaves are numbered in
// DFS preorder, so each subtree owns a contiguous interval of leaf indices and
// a segment tree with range-add / global-max supports both operations.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <vector>

typedef long long ll;

const int MAXN = 200020;

// Reads a decimal integer from stdin. Each '-' seen before the first digit
// flips the sign. Stops cleanly at EOF instead of spinning forever.
inline int rd() {
    int x = 0, y = 1;
    int c = std::getchar();  // int, not char, so EOF is distinguishable
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') y = -y;
        c = std::getchar();
    }
    while (c != EOF && std::isdigit(c)) {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    return x * y;
}

// Segment tree over leaf indices 1..tot.
struct node {
    int l, r;  // leaf-index range covered by this node
    int num;   // leaf index holding Max (leftmost on ties)
    ll Max;    // largest remaining chain weight in [l, r]
    ll lazy;   // pending add not yet pushed to the children
} tree[MAXN * 4];

int n, m;
int lnum[MAXN], rnum[MAXN];  // first / last leaf index inside a node's subtree
int fa[MAXN];                // parent of each node (0 for the root)
int tot;                     // number of leaves
int totree[MAXN];            // leaf index -> node id
ll sum[MAXN];                // weight of the path root..x, inclusive
ll val[MAXN];                // node weights
ll ans;
bool vis[MAXN];              // node already counted in ans
std::vector<int> v[MAXN];    // children lists, in input order

// Fills sum[], lnum[], rnum[] and numbers the leaves 1..tot in DFS preorder
// (children visited in input order), recording totree[leaf index] = node.
// It is iterative: a recursive DFS would nest as deep as the tree, and a
// chain of ~2e5 nodes can overflow the call stack.
void dfs(int root) {
    std::vector<int> order;  // nodes in preorder
    order.reserve(n);
    std::vector<int> stk(1, root);
    sum[root] += val[root];
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        if (v[x].empty()) {
            lnum[x] = rnum[x] = ++tot;
            totree[tot] = x;
            continue;
        }
        lnum[x] = MAXN;
        rnum[x] = 0;
        // Push children in reverse so they pop in input order, which keeps
        // the same leaf numbering a recursive DFS would produce.
        for (int i = (int)v[x].size() - 1; i >= 0; i--) {
            int to = v[x][i];
            sum[to] += sum[x] + val[to];
            stk.push_back(to);
        }
    }
    // Reverse preorder handles every child before its parent, so children's
    // leaf intervals are final when the parent's interval is merged.
    for (int i = (int)order.size() - 1; i >= 0; i--) {
        int x = order[i];
        for (size_t j = 0; j < v[x].size(); j++) {
            int to = v[x][j];
            lnum[x] = std::min(lnum[x], lnum[to]);
            rnum[x] = std::max(rnum[x], rnum[to]);
        }
    }
}

// Recomputes node x from its children. On ties the left child wins, so the
// reported position is the leftmost leaf attaining the maximum.
inline void pushup(int x) {
    const node &lc = tree[x * 2], &rc = tree[x * 2 + 1];
    if (lc.Max >= rc.Max) {
        tree[x].num = lc.num;
        tree[x].Max = lc.Max;
    } else {
        tree[x].num = rc.num;
        tree[x].Max = rc.Max;
    }
}

// Moves x's pending add down to its two children (no-op on a leaf).
inline void pushdown(int x) {
    if (tree[x].l == tree[x].r || tree[x].lazy == 0) return;
    ll add = tree[x].lazy;
    tree[x * 2].Max += add;
    tree[x * 2].lazy += add;
    tree[x * 2 + 1].Max += add;
    tree[x * 2 + 1].lazy += add;
    tree[x].lazy = 0;
}

// Builds node x over leaf indices [l, r]. Leaf i starts with the full root
// path weight of node totree[i].
void build(int x, int l, int r) {
    tree[x].l = l;
    tree[x].r = r;
    tree[x].lazy = 0;
    if (l == r) {
        tree[x].Max = sum[totree[l]];
        tree[x].num = l;
        return;
    }
    int mid = (l + r) / 2;
    build(x * 2, l, mid);
    build(x * 2 + 1, mid + 1, r);
    pushup(x);
}

// Adds k to every leaf in [l, r]. Requires tree[x].l <= l <= r <= tree[x].r.
void change(int x, int l, int r, ll k) {
    if (tree[x].l == l && tree[x].r == r) {
        tree[x].Max += k;
        tree[x].lazy += k;
        return;
    }
    pushdown(x);
    int mid = (tree[x].l + tree[x].r) / 2;
    if (r <= mid) {
        change(x * 2, l, r, k);
    } else if (l > mid) {
        change(x * 2 + 1, l, r, k);
    } else {
        change(x * 2, l, mid, k);
        change(x * 2 + 1, mid + 1, r, k);
    }
    pushup(x);
}

int main() {
    n = rd();
    m = rd();
    for (int i = 1; i <= n; i++) val[i] = (ll)rd();
    for (int i = 1; i < n; i++) {
        int x = rd(), y = rd();
        v[x].push_back(y);  // x is read as the parent of y
        fa[y] = x;
    }
    dfs(1);
    // With at least as many chains as leaves, every leaf (and therefore every
    // node on a root-to-leaf path) can be covered.
    if (m >= tot) {
        for (int i = 1; i <= n; i++) ans += val[i];
        std::printf("%lld\n", ans);
        return 0;
    }
    build(1, 1, tot);
    while (m-- > 0) {  // "> 0" guards against a negative m
        int now = totree[tree[1].num];  // leaf with the heaviest remaining chain
        ans += tree[1].Max;
        // Climb until past the root (fa == 0) or until reaching a node taken
        // earlier. Each earlier climb also reached the root or a taken node,
        // so every ancestor of a taken node is already taken.
        while (now && !vis[now]) {
            vis[now] = 1;
            change(1, lnum[now], rnum[now], -val[now]);
            now = fa[now];
        }
    }
    std::printf("%lld\n", ans);
    return 0;
}