LuoguP5327「ZJOI2019」题解

树链剖分 + 树上差分 + 线段树合并

Description

九条可怜是一个喜欢规律的女孩子。按照规律,第二题应该是一道和数据结构有关的题。

在一个遥远的国度,有 $n$ 个城市。城市之间有 $n − 1$条双向道路,这些道路保证了任何两个城市之间都能直接或者间接地到达。

在上古时代,这 $n$ 个城市之间处于战争状态。在高度闭塞的环境中,每个城市都发展出了自己的语言。而在王国统一之后,语言不通给王国的发展带来了极大的阻碍。为了改善这种情况,国王下令设计了 $m$ 种通用语,并进行了 $m$ 次语言统一工作。在第 $i$ 次统一工作中,一名大臣从城市 $s_i$ 出发,沿着最短的路径走到了 $t_i$ ​
,教会了沿途所有城市(包括 $s_i, t_i$) 使用第 $i$ 个通用语。

一旦有了共通的语言,那么城市之间就可以开展贸易活动了。两个城市 $u_i, v_i$ 之间可以开展贸易活动当且仅当存在一种通用语 $L$ 满足 $u_i$ 到 $v_i$ 最短路上的所有城市(包括 $u_i, v_i$​),都会使用 $L$.

为了衡量语言统一工作的效果,国王想让你计算有多少对城市 $(u, v)\ (u < v)$ ,他们之间可以开展贸易活动。

Samples

Input #1

1
2
3
4
5
6
7
8
5 3
1 2
1 3
3 4
3 5
3 4
1 4
2 5

Output #1

1
8

Explanation

可以开展贸易活动的城市对为 $(1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 5), (3, 4), (3, 5)$

Solution

首先, 如果两个城市$u, v$能够开展贸易, 我们称”$u$能到达$v$”或”$v$能到达$u$”.

考虑维护维护每个节点能到达的节点集合$S_u$, 显然答案是$\frac {\sum_{i=1}^n S_i}{2}$(因为这里没有保证$u < v$)

将树树剖, 这样把每一次修改的路径拆成了不超过$\log n$个区间, 如果对每一个节点的$S_u$都直接用线段树维护, 那么就需要把每个节点跑一遍$\log n$的区间覆盖, 总复杂度$\mathcal O(n^2 \log^2 n)$, 可以得到暴力分的好成绩.

容易发现, 我们现在需要做的是对树上一条路径$u \to v$的推平, 首先讨论的求出$k = \rm{lca(} u, v \rm{)}$ , 那么等同于推平$u \to k$, $k \to v$两条链.

链上修改就可以很显然的转化为树上差分打标记, 分别给$u, v, k, fa[k]$ 打上 $1, 1, -1, -1$ 的标记.

然后因为每个节点$u$需要用到它的儿子的信息, 而我们又不能直接开$n$棵线段树, 那么就可以用线段树合并来完成.

注意: 因为这里是差分, 所以区间覆盖会变成区间加, 写的时候要注意一下.

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
/* Headers */
#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<cctype>
#include<vector>
#include<cmath>
#include<array>
#include<queue>
#include<stack>
#include<map>
#define FOR(i,a,b,c) for(int i=(a);i<=(b);i+=(c))
#define ROF(i,a,b,c) for(int i=(a);i>=(b);i-=(c))
#define FORL(i,a,b,c) for(long long i=(a);i<=(b);i+=(c))
#define ROFL(i,a,b,c) for(long long i=(a);i>=(b);i-=(c))
#define FORR(i,a,b,c) for(register int i=(a);i<=(b);i+=(c))
#define ROFR(i,a,b,c) for(register int i=(a);i>=(b);i-=(c))
#define RevEdge(x) x^1
#define lowbit(x) x&(-x)
#define LeftChild(x) x<<1
#define RightChild(x) (x<<1)+1
#define CLOSE_IN() fclose(stdin);
#define CLOSE_OUT() fclose(stdout);
#define FILE_IN(x) freopen(x,"r",stdin);
#define FILE_OUT(x) freopen(x,"w",stdout);
#define IOS(x) std::ios::sync_with_stdio(x)
#define Dividing() printf("-----------------------------------\n");
namespace FastIO{
const int BUFSIZE = 1 << 20;
char ibuf[BUFSIZE],*is = ibuf,*its = ibuf;
char obuf[BUFSIZE],*os = obuf,*ot = obuf + BUFSIZE;
inline char getch(){
if(is == its)
its = (is = ibuf)+fread(ibuf,1,BUFSIZE,stdin);
return (is == its)?EOF:*is++;
}
inline int getint(){
int res = 0,neg = 0,ch = getch();
while(!(isdigit(ch) || ch == '-') && ch != EOF)
ch = getch();
if(ch == '-'){
neg = 1;ch = getch();
}
while(isdigit(ch)){
res = (res << 3) + (res << 1)+ (ch - '0');
ch = getch();
}
return neg?-res:res;
}
inline void flush(){
fwrite(obuf,1,os - obuf,stdout);
os = obuf;
}
inline void putch(char ch){
*os++ = ch;
if(os == ot) flush();
}
inline void putint(int res){
static char q[10];
if(res == 0) putch('0');
else if(res < 0){putch('-');res = -res;}
int top = 0;
while(res){
q[top++] = res % 10 + '0';
res /= 10;
}
while(top--) putch(q[top]);
}
inline void space(bool x){
if(!x) putch('\n');
else putch(' ');
}
}
inline void read(int &x){
int rt = FastIO::getint();
x = rt;
}
inline void print(int x,bool enter){
FastIO::putint(x);
FastIO::flush();
FastIO::space(enter);
}
/* definitions */
const int MAXN = 1e5 + 90;
struct Seg {
int l, r;
Seg (int l, int r) : l(l), r(r) {}
};
struct Que {
int x, y, k;
Que (int x, int y, int k) : x(x), y(y), k(k) {}
};
struct Tree {int l, r, tag, sum;}SMT[MAXN << 8];
std::vector<int> G[MAXN];
std::vector<Seg> a;
std::vector<Que> qaq[MAXN];
int depth[MAXN], fa[MAXN], son[MAXN], size[MAXN];
int top[MAXN], dfn[MAXN], root[MAXN], tot, n, m, cnt;
long long ans = 0;
/* functions */
inline void pushUp(int k, int l, int r) {
SMT[k].sum = (SMT[k].tag > 0) ? r - l + 1 : SMT[SMT[k].l].sum + SMT[SMT[k].r].sum;
}
inline void Modify(int &k, int l, int r, int x, int y, int val) {
if(!k) k = ++cnt;
if(x <= l && y >= r) SMT[k].tag += val;
else {
int mid = (l + r) >> 1;
if(x <= mid) Modify(SMT[k].l, l, mid, x, y, val);
if(y > mid) Modify(SMT[k].r, mid + 1, r, x, y, val);
} pushUp(k, l, r);
}
inline int unionx(int x, int y, int l, int r) {
if(!x || !y) return x | y;
SMT[x].tag += SMT[y].tag;
if(l < r) {
int mid = (l + r) >> 1;
SMT[x].l = unionx(SMT[x].l, SMT[y].l, l, mid);
SMT[x].r = unionx(SMT[x].r, SMT[y].r, mid + 1, r);
} pushUp(x, l, r);
return x;
}
inline void DFS1(int u, int f, int dep) {
depth[u] = dep; fa[u] = f; size[u] = 1;
for(auto v : G[u]) {
if(v == f) continue;
DFS1(v, u, dep + 1); size[u] += size[v];
if(size[v] > size[son[u]]) son[u] = v;
}
}
inline void DFS2(int u, int topf) {
dfn[u] = ++tot; top[u] = topf;
qaq[u].push_back((Que) {dfn[u], dfn[u], 1});
if(fa[u]) qaq[fa[u]].push_back((Que) {dfn[u], dfn[u], -1});
if(!son[u]) return ;
DFS2(son[u], topf);
for(auto v : G[u]) if(v != fa[u] && v != son[u]) DFS2(v, v);
}
inline int LCA(int x, int y) {
a.clear();
while(top[x] != top[y]) {
if(depth[top[x]] < depth[top[y]]) std::swap(x, y);
a.push_back((Seg) {dfn[top[x]], dfn[x]});
x = fa[top[x]];
} if(depth[x] > depth[y]) std::swap(x, y);
a.push_back((Seg) {dfn[x], dfn[y]});
return x;
}
inline void DFS(int u) {
for(auto v : G[u]) {
if(v == fa[u]) continue;
DFS(v); root[u] = unionx(root[u], root[v], 1, n);
}
for(auto w : qaq[u]) Modify(root[u], 1, n, w.x, w.y, w.k);
ans += SMT[root[u]].sum - 1;
}
int main(int argc,char *argv[]){
scanf("%d %d", &n, &m);
FOR(i, 2, n, 1) {
int u, v; scanf("%d %d", &u, &v);
G[u].push_back(v); G[v].push_back(u);
} DFS1(1, 0, 1); DFS2(1, 1);
while(m --> 0) {
int u, v, lca; scanf("%d %d", &u, &v);
lca = LCA(u, v);
for(auto it : a) {
qaq[u].push_back((Que) {it.l, it.r, 1});
qaq[v].push_back((Que) {it.l, it.r, 1});
if(fa[lca]) qaq[fa[lca]].push_back((Que) {it.l, it.r, -2});
}
} DFS(1); printf("%lld\n", ans >> 1);
return 0;
}

THE END