树链剖分算法总结
树链剖分是一种把树剖分成重链和轻链,并用dfs序储存在线段树中的算法。它可以方便的处理树上路径和子树的问题。把树上数据存在线段树中的思想值得思考。
何为树链剖分?树链,就是树上路径,剖分,就是把树链剖分成轻链和重链。
记siz[v]表示以v为根的子树的节点数,dep[v]表示v的深度,top[v]表示v所在的重链的顶端节点,fa[v]表示v的父亲,son[v]表示重儿子,dfs_id[v]v的dfs序。
先介绍几个概念:
这样,很显然的我们就能发现
1.如果(v,u)为轻边,则siz[u] * 2 < siz[v];
2.从根到某一点的路径上轻链、重链的个数都不大于logn。
这两个很好的性质就可以在logn的复杂度下遍历任意一个路径。我们可以两个点同时向上跳,假如是重链就跳到top,不是就跳到父亲。直到跳到两点的top是同一个。跳的同时就可以用线段树维护一下极值、求和啥的。
图片来自网络
如何实现呢?
我们可以通过两个dfs实现
第一个
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
void dfs_1(int x,int f) { siz[x]=1; fa[x]=f; for (int i=st[x];i;i=e[i].next) if (e[i].to!=f) { dep[e[i].to]=dep[x]+1; dfs_1(e[i].to,x); siz[x]+=siz[e[i].to]; if (siz[e[i].to]>siz[son[x]]) son[x]=e[i].to; } } |
在这个dfs中,可以把siz,fa,dep,son求出来
第二个
1 2 3 4 5 6 7 8 9 10 11 12 13 |
int tot2=0; void dfs_2(int now,int tp) { pre[++tot2]=now; dfs_id[now]=tot2; top[now]=tp; if (son[now]) dfs_2(son[now],tp); for (int i=st[now];i;i=e[i].next) if (e[i].to!=son[now] && e[i].to!=fa[now]) dfs_2(e[i].to,e[i].to); } //基本思想就是现在是重链上的话,就用原来的top,不是重链上的就传自己作为top |
这个dfs可以把top,dfs_id求出来
pre就是dfs_id的反函数。。它是当构造线段树时候用的。
构造线段树
1 2 3 4 5 6 7 8 9 10 11 12 |
void build(int rt,int l,int r) { if (l==r) tree[rt].sum=tree[rt].maxs=tree[rt].mins=val[pre[r]]; else { int mid=(l+r)/2; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } } |
val数字存的是每个点的权值。因为线段树用的是dfs序,我们直接调用反函数就可以知道权值
例题是这道
Description
Ray 乐忠于旅游,这次他来到了T 城。T 城是一个水上城市,一共有 N 个景点,有些景点之间会用一座桥连接。为了方便游客到达每个景点但又为了节约成本,T 城的任意两个景点之间有且只有一条路径。换句话说, T 城中只有N − 1 座桥。Ray 发现,有些桥上可以看到美丽的景色,让人心情愉悦,但有些桥狭窄泥泞,令人烦躁。于是,他给每座桥定义一个愉悦度w,也就是说,Ray 经过这座桥会增加w 的愉悦度,这或许是正的也可能是负的。有时,Ray 看待同一座桥的心情也会发生改变。现在,Ray 想让你帮他计算从u 景点到v 景点能获得的总愉悦度。有时,他还想知道某段路上最美丽的桥所提供的最大愉悦度,或是某段路上最糟糕的一座桥提供的最低愉悦度。
Input
输入的第一行包含一个整数N,表示T 城中的景点个数。景点编号为 0…N − 1。接下来N − 1 行,每行三个整数u、v 和w,表示有一条u 到v,使 Ray 愉悦度增加w 的桥。桥的编号为1…N − 1。|w| <= 1000。输入的第N + 1 行包含一个整数M,表示Ray 的操作数目。接下来有M 行,每行描述了一个操作,操作有如下五种形式: C i w,表示Ray 对于经过第i 座桥的愉悦度变成了w。 N u v,表示Ray 对于经过景点u 到v 的路径上的每一座桥的愉悦度都变成原来的相反数。 SUM u v,表示询问从景点u 到v 所获得的总愉悦度。 MAX u v,表示询问从景点u 到v 的路径上的所有桥中某一座桥所提供的最大愉悦度。 MIN u v,表示询问从景点u 到v 的路径上的所有桥中某一座桥所提供的最小愉悦度。测试数据保证,任意时刻,Ray 对于经过每一座桥的愉悦度的绝对值小于等于1000。
Output
对于每一个询问(操作S、MAX 和MIN),输出答案。
Sample Input
3
0 1 1
1 2 2
8
SUM 0 2
MAX 0 2
N 0 1
SUM 0 2
MIN 0 2
C 1 3
SUM 0 2
MAX 0 2
Sample Output
3
2
1
-1
5
3
HINT
一共有10 个数据,对于第i (1 <= i <= 10) 个数据, N = M = i * 2000。
这题有四个操作,区间最大,区间最小,区间和,区间反转
我们该如何操作呢?
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
void doit() { int l,r,f1,f2,sum=0,maxs=-INF,mins=INF; scanf("%d%d",&l,&r),l++,r++; f1=top[l],f2=top[r]; while(f1!=f2) { if (dep[f1]<dep[f2]) swap(f1,f2),swap(l,r); if (ch=='N') res(1,1,tot2,dfs_id[f1],dfs_id[l]); else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[f1],dfs_id[l]); else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[f1],dfs_id[l]),mins); else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[f1],dfs_id[l]),maxs); l=fa[f1],f1=top[l]; } if (dep[l]>dep[r]) swap(l,r); if (l==r) { if (ch=='S') printf("%d\n",sum); if (ch=='I') printf("%d\n",mins); if (ch=='A') printf("%d\n",maxs); } else { l=son[l]; if (ch=='N') res(1,1,tot2,dfs_id[l],dfs_id[r]); else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[l],dfs_id[r]),printf("%d\n",sum); else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[l],dfs_id[r]),mins),printf("%d\n",mins); else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[l],dfs_id[r]),maxs),printf("%d\n",maxs); } } |
首先,找到这两个节点,把高度低的那个往上跳,跳的时候操作一下这个链,最后直到top一样或到同一个点
注意一下,假如不在一个点的话最后还要更新一下。
总代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
#include<bits/stdc++.h> #define INF 0x7fffffff #define M 40010 #define N 20010 using namespace std; typedef pair<int,int> Pair; struct node { int from,to,value,next; }e[M]; struct seg { int sum,mins,maxs,mark; }tree[4*N]; int tot,st[M],n,m,siz[N],son[N],fa[N],pre[3*N],top[N],dfs_id[N],dep[N],val[N]; char ch; void add(int x,int y,int z) { e[++tot].to=y; e[tot].from=x; e[tot].value=z; e[tot].next=st[x]; st[x]=tot; } void dfs_1(int x,int f) { siz[x]=1; fa[x]=f; for (int i=st[x];i;i=e[i].next) if (e[i].to!=f) { dep[e[i].to]=dep[x]+1; dfs_1(e[i].to,x); siz[x]+=siz[e[i].to]; if (siz[e[i].to]>siz[son[x]]) son[x]=e[i].to; } } int tot2=0; void dfs_2(int now,int tp) { pre[++tot2]=now; dfs_id[now]=tot2; top[now]=tp; if (son[now]) dfs_2(son[now],tp); for (int i=st[now];i;i=e[i].next) if (e[i].to!=son[now] && e[i].to!=fa[now]) dfs_2(e[i].to,e[i].to); } void re(int &a,int &b,int &c){a=-a,b=-b,c=-c;} void pushdown(int now) { if (tree[now].mark==0) return; tree[now<<1].mark^=1; tree[now<<1|1].mark^=1; swap(tree[now<<1].mins,tree[now<<1].maxs); re(tree[now<<1].sum,tree[now<<1].mins,tree[now<<1].maxs); swap(tree[now<<1|1].mins,tree[now<<1|1].maxs); re(tree[now<<1|1].sum,tree[now<<1|1].mins,tree[now<<1|1].maxs); tree[now].mark=0; } void pushup(int now) { tree[now].maxs=max(tree[now<<1].maxs,tree[now<<1|1].maxs); tree[now].mins=min(tree[now<<1].mins,tree[now<<1|1].mins); tree[now].sum=tree[now<<1].sum+tree[now<<1|1].sum; } void build(int rt,int l,int r) { if (l==r) tree[rt].sum=tree[rt].maxs=tree[rt].mins=val[pre[r]]; else { int mid=(l+r)/2; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } } void update(int rt,int l,int r,int pos,int x) { if (l==r) { tree[rt].mark=0; tree[rt].sum=tree[rt].maxs=tree[rt].mins=x; return; } pushdown(rt); int mid=(r+l)/2; if (mid>=pos) update(rt<<1,l,mid,pos,x); else update(rt<<1|1,mid+1,r,pos,x); pushup(rt); } void res(int rt,int l,int r,int L,int R) { if (L<=l && r<=R) { tree[rt].mark^=1; swap(tree[rt].maxs,tree[rt].mins); re(tree[rt].maxs,tree[rt].mins,tree[rt].sum); return; } pushdown(rt); int mid=(l+r)/2; if (mid>=L) res(rt<<1,l,mid,L,R); if (mid<R) res(rt<<1|1,mid+1,r,L,R); pushup(rt); } int get_max(int rt,int l,int r,int L,int R) { if (L<=l && r<=R) return tree[rt].maxs; pushdown(rt); int ans=-INF,mid=(r+l)/2; if (mid>=L) ans=max(ans,get_max(rt<<1,l,mid,L,R)); if (mid<R) ans=max(ans,get_max(rt<<1|1,mid+1,r,L,R)); return ans; } int get_min(int rt,int l,int r,int L,int R) { if (L<=l && r<=R) return tree[rt].mins; pushdown(rt); int ans=INF,mid=(r+l)/2; if (mid>=L) ans=min(ans,get_min(rt<<1,l,mid,L,R)); if (mid<R) ans=min(ans,get_min(rt<<1|1,mid+1,r,L,R)); return ans; } int get_sum(int rt,int l,int r,int L,int R) { if (L<=l && r<=R) return tree[rt].sum; pushdown(rt); int ans=0,mid=(r+l)/2; if (mid>=L) ans+=get_sum(rt<<1,l,mid,L,R); if (mid<R) ans+=get_sum(rt<<1|1,mid+1,r,L,R); return ans; } void doit() { int l,r,f1,f2,sum=0,maxs=-INF,mins=INF; scanf("%d%d",&l,&r),l++,r++; f1=top[l],f2=top[r]; while(f1!=f2) { if (dep[f1]<dep[f2]) swap(f1,f2),swap(l,r); if (ch=='N') res(1,1,tot2,dfs_id[f1],dfs_id[l]); else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[f1],dfs_id[l]); else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[f1],dfs_id[l]),mins); else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[f1],dfs_id[l]),maxs); l=fa[f1],f1=top[l]; } if (dep[l]>dep[r]) swap(l,r); if (l==r) { if (ch=='S') printf("%d\n",sum); if (ch=='I') printf("%d\n",mins); if (ch=='A') printf("%d\n",maxs); } else { l=son[l]; if (ch=='N') res(1,1,tot2,dfs_id[l],dfs_id[r]); else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[l],dfs_id[r]),printf("%d\n",sum); else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[l],dfs_id[r]),mins),printf("%d\n",mins); else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[l],dfs_id[r]),maxs),printf("%d\n",maxs); } } main() { scanf("%d",&n); int x,y,z; for (int i=1;i<n;i++) scanf("%d%d%d",&x,&y,&z),x++,y++, add(x,y,z),add(y,x,z); dfs_1(1,0); dfs_2(1,1); for (int i=1;i<=tot;i++) { if (dep[e[i].from]>dep[e[i].to]) swap(e[i].from,e[i].to); val[e[i].to]=e[i].value; } build(1,1,tot2); scanf("%d",&m); while(m--) { ch=getchar(); while(ch!='N'&&ch!='S'&&ch!='M'&&ch!='C') ch=getchar(); if (ch=='C') { scanf("%d%d",&x,&y); update(1,1,tot2,dfs_id[e[x<<1].to],y); } else if (ch=='S') getchar(),getchar(),doit(); else if (ch=='N') doit(); else if (ch=='M') ch=getchar(),getchar(),doit(); } } |
还有一种情况就是操作子树
其实这个更简单
我们可以观察一下一棵树的dfs序
我们可以观察到子树是在dfs序上连续的一段:dfs_id[i]+1到dfs_id[i]+siz[i]-1
然后直接线段树就行了
例题
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
#include<cstdio> #include<iostream> #define ls k*2 #define rs k*2+1 using namespace std; int n,m,l,hs,nl,lfs,root,mod; int a,b,c,d,ans,com; int s[100010],h[100010]; struct node { int k,f,d,sz,ws,p,t; }p[100010]; struct nate { int s,n; }e[200010]; struct tree { int l,r,s,f; }t[400010]; inline void in(int &ans){ans=0;bool p=false;char ch=getchar();while((ch>'9' || ch<'0')&&ch!='-') ch=getchar();if(ch=='-') p=true,ch=getchar();while(ch<='9'&&ch>='0') ans=ans*10+ch-'0',ch=getchar();if(p) ans=-ans;} void add(int x,int y) { e[++hs]=(nate){y,h[x]};h[x]=hs; } void pushdown(int k) { t[ls].f=(t[ls].f+t[k].f)%mod; t[ls].s+=(t[ls].r-t[ls].l+1)*t[k].f%mod; t[rs].f=(t[rs].f+t[k].f)%mod; t[rs].s+=(t[rs].r-t[rs].l+1)*t[k].f%mod; t[k].f=0; } void build(int k,int l,int r) { t[k].l=l;t[k].r=r; if(l==r){t[k].s=s[++nl];return;} int mid=(l+r)/2; build(ls,l,mid); build(rs,mid+1,r); t[k].s=t[ls].s+t[rs].s; } void change(int k,int l,int r,int v) { if(t[k].l==l&&t[k].r==r) { t[k].f=(t[k].f+v)%mod; t[k].s+=(t[k].r-t[k].l+1)*v%mod; return; } if(t[k].f) pushdown(k); int mid=(t[k].l+t[k].r)/2; if(l<=mid) change(ls,l,min(r,mid),v); if(r>mid) change(rs,max(l,mid+1),r,v); t[k].s=(t[ls].s+t[rs].s)%mod; } int query(int k,int l,int r) { if(t[k].l==l&&t[k].r==r) return t[k].s; if(t[k].f) pushdown(k); int mid=(t[k].l+t[k].r)/2,ans=0; if(l<=mid) ans+=query(ls,l,min(r,mid))%mod; if(r>mid) ans+=query(rs,max(l,mid+1),r)%mod; return ans%mod; } void dfs1(int k,int f,int d) { p[k].f=f;p[k].d=d;p[k].sz=1; for(int i=h[k];i;i=e[i].n) if(e[i].s!=f){ dfs1(e[i].s,k,d+1); p[k].sz+=p[e[i].s].sz; if(p[e[i].s].sz>p[p[k].ws].sz) p[k].ws=e[i].s; } } void dfs2(int k) { s[++l]=p[k].k;p[k].p=l; if(p[k].ws) { p[p[k].ws].t=p[k].t; dfs2(p[k].ws); } for(int i=h[k];i;i=e[i].n) if(e[i].s!=p[k].ws&&e[i].s!=p[k].f) { p[e[i].s].t=e[i].s; dfs2(e[i].s); } } int main() { in(n),in(m),in(root),in(mod); for(int i=1;i<=n;i++) in(p[i].k); for(int i=1;i<n;i++) in(a),in(b),add(a,b),add(b,a); dfs1(root,root,1); dfs2(root); build(1,1,l); while(m--) { in(com); if(com==1) { in(b),in(c),in(d);d%=mod; for(;p[b].t!=p[c].t;b=p[p[b].t].f) { if(p[p[b].t].d<p[p[c].t].d) swap(b,c); change(1,p[p[b].t].p,p[b].p,d); } if(p[b].d>p[c].d) swap(b,c); change(1,p[b].p,p[c].p,d); } if(com==2) { in(b),in(c);ans=0; for(;p[b].t!=p[c].t;b=p[p[b].t].f) { if(p[p[b].t].d<p[p[c].t].d) swap(b,c); ans+=query(1,p[p[b].t].p,p[b].p),ans%=mod; } if(p[b].d>p[c].d) swap(b,c); ans+=query(1,p[b].p,p[c].p),ans%=mod; printf("%d\n",ans); } if(com==3) { in(b),in(c),c%=mod; change(1,p[b].p,p[b].p+p[b].sz-1,c); } if(com==4) { in(b),ans=0; ans=query(1,p[b].p,p[b].p+p[b].sz-1)%mod; printf("%d\n",ans); } } } |
The end.