题目链接:HDU-6547 Tree
题意
wls 有一棵树,树上每个节点都有一个值 $a_i$,现在有 2 种操作:
1. 将一条链上的所有节点的值开根号向下取整;
2. 求一条链上值的和;
链指树上两点之间的最短路径(即两点间唯一的简单路径)。
思路
树链剖分裸题,区间开根号可用线段树做。$10^9$ 范围内的数最多经过 5 次开根号(向下取整)就会变为 1,而 0 和 1 开根号后保持不变。因此在线段树中维护区间最大值,若某区间最大值不超过 1,则无需再往下更新;这样每个位置只会被真正修改常数次。注意每次修改后也要向上更新区间最大值,否则剪枝失效。
树链剖分传送门:https://www.cnblogs.com/kangkang-/p/8486150.html
代码实现
#include <stdio.h>
#include <algorithm>
#include <cmath>
#include <iostream>
// Inclusive integer loop: i runs from a up to b.
#define REP(i, a, b) for (int i = a; i <= b; ++i)
using namespace std;
typedef long long LL;

// Small tolerance added before truncating a floating-point sqrt to an integer.
const double esp = 1e-8;
// Array capacity for vertices; presumably a little above the problem's N limit (TODO confirm).
const int MAXN = 110000;

// Adjacency-list edge. Edge indices start at 1; index 0 terminates a list.
// Each undirected edge is stored twice, hence MAXN << 1 slots.
struct Node {
    int to;    // endpoint of this edge
    int next;  // next edge leaving the same vertex (0 = end of list)
} edg[MAXN << 1];

// Segment-tree node covering dfs-order positions [left, right].
struct segmentTree {
    int left;
    int right;
    LL sum;   // sum of the values in the range
    LL maxx;  // maximum value in the range
} tree[MAXN << 2];

// Heavy-light decomposition tables (indexed by vertex, except rnk).
int head[MAXN];  // first edge of each vertex (0 = no edges)
int siz[MAXN];   // subtree size
int top[MAXN];   // topmost vertex of the heavy chain containing the vertex
int hson[MAXN];  // heavy child (0 = none)
int dep[MAXN];   // depth, root has depth 1
int fa[MAXN];    // parent, 0 for the root
int id[MAXN];    // dfs-order position of the vertex (1-based)
int rnk[MAXN];   // inverse of id: the vertex stored at a dfs position

int N, M, R;     // vertex count, operation count, root vertex
int A[MAXN];     // initial value of each vertex
int idx = 0;     // number of edges stored so far
int dfs_cnt = 0; // last dfs position handed out
// Reads the next (optionally negative) decimal integer from stdin.
// Skips any non-digit characters before the number; a '-' among them makes
// the result negative. Returns 0 if EOF is reached before any digit, instead
// of looping forever as the previous version did.
inline int read() {
    int x = 0, f = 1;
    // int, not char: getchar() returns EOF (-1) which must stay distinguishable
    // from every byte value, including on platforms where char is unsigned.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Appends a directed edge u -> v to u's adjacency list.
// Edges are numbered from 1 so that 0 can terminate the list.
inline void adde(int u, int v) {
    ++idx;
    edg[idx].to = v;
    edg[idx].next = head[u];
    head[u] = idx;
}
// First heavy-light pass over the subtree rooted at u.
// Records each vertex's depth (dep), parent (fa) and subtree size (siz), and
// picks the heavy child hson[u] = the child with the largest subtree.
// hson[u] == 0 means "no heavy child", the same sentinel dfs2 tests for leaves
// (vertices are numbered from 1, so 0 is never a real vertex).
// NOTE(review): recursion depth equals the tree height (up to N on a path);
// judges with a small default stack may need an iterative version.
void dfs1(int u, int father, int depth) {
    dep[u] = depth;
    fa[u] = father;
    siz[u] = 1;
    for (int i = head[u]; i; i = edg[i].next) {
        int v = edg[i].to;
        if (v == father) continue;  // do not walk back up to the parent
        dfs1(v, u, depth + 1);
        siz[u] += siz[v];
        // The first child always wins; later ones only if strictly larger.
        if (!hson[u] || siz[v] > siz[hson[u]]) hson[u] = v;
    }
}
// Second heavy-light pass: assigns dfs-order positions so that every heavy
// chain occupies a contiguous range, and records the chain top of each vertex.
// u: current vertex; t: top vertex of the chain u belongs to.
void dfs2(int u, int t) {
    top[u] = t;
    id[u] = ++dfs_cnt;
    rnk[dfs_cnt] = u;
    const int heavy = hson[u];
    if (heavy == 0) return;  // leaf: no heavy child
    dfs2(heavy, t);          // the heavy child extends the current chain
    for (int e = head[u]; e != 0; e = edg[e].next) {
        const int v = edg[e].to;
        if (v == heavy || v == fa[u]) continue;
        dfs2(v, v);          // every light child starts a new chain
    }
}
// Builds segment-tree node i over dfs positions [l, r].
// A leaf at position l holds the value of vertex rnk[l]; an inner node stores
// the sum and the maximum of its two children.
void buildtree(int i, int l, int r) {
    tree[i].left = l;
    tree[i].right = r;
    if (l == r) {
        tree[i].sum = tree[i].maxx = A[rnk[l]];
        return;
    }
    const int mid = (l + r) >> 1;
    const int lc = i << 1;
    const int rc = lc | 1;
    buildtree(lc, l, mid);
    buildtree(rc, mid + 1, r);
    tree[i].sum = tree[lc].sum + tree[rc].sum;
    tree[i].maxx = max(tree[lc].maxx, tree[rc].maxx);
}
void update(int i, int x, int y) {
if (tree[i].left > y || tree[i].right < x) return ;
if (tree[i].left == tree[i].right) {
tree[i].maxx = sqrt(tree[i].maxx) + esp;
tree[i].sum = sqrt(tree[i].sum) + esp;
return ;
}
if (tree[i].maxx == 1) return ;
int l = i << 1, r = i << 1 | 1;
update(l, x, y);
update(r, x, y);
tree[i].sum = tree[l].sum + tree[r].sum;
}
// Returns the sum of the values at dfs positions [x, y] that lie inside
// segment-tree node i's range.
LL query(int i, int x, int y) {
    const int lo = tree[i].left;
    const int hi = tree[i].right;
    if (hi < x || lo > y) return 0;                // disjoint
    if (x <= lo && hi <= y) return tree[i].sum;    // fully covered
    return query(i << 1, x, y) + query(i << 1 | 1, x, y);
}
// Applies floor(sqrt(.)) to every vertex on the tree path between u and v.
// The walk repeatedly lifts the endpoint whose chain top is deeper,
// updating that chain's contiguous dfs range.
void update_path(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        update(1, id[top[u]], id[u]);
        u = fa[top[u]];
    }
    // Now both endpoints share a chain; u is made the shallower one.
    if (dep[u] > dep[v]) swap(u, v);
    update(1, id[u], id[v]);
}
// Returns the sum of the values of all vertices on the tree path between u and v.
// The path is split into heavy-chain segments, each a contiguous dfs range.
LL query_path(int u, int v) {
    LL total = 0;
    for (; top[u] != top[v]; u = fa[top[u]]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        total += query(1, id[top[u]], id[u]);
    }
    // Same chain now; order the endpoints so u is the shallower one.
    if (dep[u] > dep[v]) swap(u, v);
    return total + query(1, id[u], id[v]);
}
// Reads N and M, the N vertex values, and N-1 undirected edges, then processes
// M operations of the form "opt x y":
//   opt 0 -> replace each value on the path x..y with floor(sqrt(value));
//   opt 1 -> print the sum of the values on the path x..y.
// Other opcodes are not handled.
int main() {
    N = read(), M = read(), R = 1;  // vertex 1 is used as the root
    REP(i, 1, N) A[i] = read();
    REP(i, 2, N) {
        int u = read(), v = read();
        adde(u, v); adde(v, u);
    }
    dfs1(R, 0, 1);  // parent 0 marks "no parent"; root depth is 1
    dfs2(R, R);
    buildtree(1, 1, N);
    while (M--) {
        int opt = read();
        switch (opt) {
            case 0: {
                int x = read(), y = read();
                update_path(x, y);
                break;
            }
            case 1: {
                int x = read(), y = read();
                printf("%lld\n", query_path(x, y));
                break;
            }
        }
    }
    return 0;
}
知识兔View Code