题解 P7518 【[省选联考 2021 A/B 卷] 宝石】
Lice
2021-04-12 08:52:23
Updated 2021/4/13:更新了代码(原代码可能无法处理 $s=o$ 并且 $W_o=1$ 的情况),以及加了一些细节说明。
### Description
给定一棵树,$n$ 个顶点,每个点有一个宝石,类型为 $W_i$,约定 $W_i\in [1,m]$。你有一个收集器,可以收集至多 $c$ 个宝石,并且收集顺序必须为 $P_1,P_2,\cdots, P_c$,其中 $P_1\sim P_c$ 互不相同。先有 $q$ 次询问,每次询问一条有向路径 $s_i\to t_i$,求依次最多可以收集多少宝石。
### Hint
$1\le n, q\le 2\times 10^5$,$1\le c\le m\le 5\times 10^4$
### Solution
不知道为什么,我遇见这种树上路径问题就很自然的想到点分治,于是考场上就有了这个神奇做法。
为方便,将 $W_x$ 的含义改为原 $W_x$ 在 $P$ 中的下标,换句话说,即“排名”。
首先离线询问,对于当前分治的连通块,处理出所有包含于块内、经过分治中心的询问的答案。注意到将路径按分治中心 $o$ 分为两部分后,两部分本质不相同(显然吧),因此分别考虑 $s\to o,o\to t$ 两部分。
对于 $s\to o$ 部分,我们希望选出的宝石尽可能多,理由不仅仅是最大化答案,还因为我们可以有余地去和 $o\to t$ 部分对接。那么不难对每个点求出其到 $o$ 的路径上最多选多少宝石,具体的,求出路径上第一个 $P_1$ 出现的位置,对于每个 $x$ 求 $W_x +1$ 对应的第一个位置即可。
对于 $o\to t$ 部分,我们希望可以与前一部分对接,那必然是从 $t$ 开始,选逐渐递减的几个,并且最后那个尽可能小。用类似上面的方法也可以求出,注意如果上面包含了 $o$ 这里就最好刻意避开。
最后考虑在后半部分中任何处理询问。很显然答案有可二分性,考虑二分答案。如何检验一个 $mid$ 是否可达?不妨先在 $t\to o$ 的路径上找到第一个 $mid$ 的位置(如果找不到当然爪巴了),然后在这个位置往上递减(每次减一)地跳,看看能跳到多小,记其为 $y$。令 $s\to o$ 路径上最多选出 $x$ 个,那么若 $x+1\ge y$,说明能接上,当前这个 $mid$ 是可行的。
最后复杂度 $O(n\log n+q\log m)$,点分一只 $\log$,后面对于每个询问做一次 $O(\log m)$ 的二分答案。
### Implementation
为什么这要一个单独说明呢?因为我在考场上细节没想清楚,导致写的很谔谔(绝对没有下面代码那么简洁),花了 3 个多小时,所以清晰的实现思路同算法一样重要。
下面是我的实现,不一定是最简洁的。
首先是传询问:每一次分治中都要对询问分类:跨越分治中心 $o$ 与否。染色法是一个好东西:我们对分治块中每棵子树染上不同的颜色,如果 $s,t$ 的颜色相同,那就是下一层递归该干的事情了(注意特判 $s=t=o$。特别是 $s=o$,需要加判断 $W_o$ 是否 $=1$)。
然后是求出一条 $x\to o$ 的路径求出最多选几个宝石。一波 Dfs 求出上面第一个 $1$ 的位置,第一个 $W_x+1$ 的位置(可以借助可撤销数组完成),通过这两个再一波 Dfs 求出每个点向上选出一个值域连续段(不一定从一开始)的长度即可,先往上找到一个 $1$,然后连着走后继。
$o\to t$ 部分异曲同工,注意最后二分答案的初始左边界为,前半部分的答案 $+1$。不 $+1$ 过不了大样例?这个我也不清楚。
### Code
```cpp
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <vector>
inline int read() {
int x = 0; char c; while (!isdigit(c = getchar()));
do x = (x << 1) + (x << 3) + c - '0'; while (isdigit(c = getchar()));
return x;
}
const int N = 2e5 + 5;
const int M = 5e4 + 5;
int n, m, c, W[N];
int rank[N];
std::vector<int> adj[N];
struct revArray {
int dat[M];
std::vector<std::pair<int, int> > his;
inline void edit(int x, int v) {
his.emplace_back(x, dat[x]);
dat[x] = v;
}
inline int operator [] (const int& x) const {
return dat[x];
}
inline void back() {
auto last = his.back();
dat[last.first] = last.second;
his.pop_back();
}
} rec;
int q, ans[N];
struct request {
int s, t, idx;
};
int maxp[N], siz[N], root;
bool centr[N];
int get_size(int x, int f) {
siz[x] = 1;
for (auto y : adj[x]) if (y != f && !centr[y])
siz[x] += get_size(y, x);
return siz[x];
}
void get_centr(int x, int f, int t) {
maxp[x] = 0;
for (auto y : adj[x]) if (y != f && !centr[y])
get_centr(y, x, t), maxp[x] = std::max(maxp[x], siz[y]);
maxp[x] = std::max(maxp[x], t - siz[x]);
if (maxp[x] < maxp[root]) root = x;
}
int color[N];
void set_color(int x, int f, int c) {
color[x] = c;
for (auto y : adj[x]) if (y != f && !centr[y])
set_color(y, x, c);
}
int nxt[N], fir[N], pre[N], inc[N], dec[N];
void get_nxt_fir(int x, int f) {
rec.edit(W[x], x);
nxt[x] = rec[W[x] + 1], fir[x] = rec[1];
for (auto y : adj[x]) if (y != f && !centr[y])
get_nxt_fir(y, x);
rec.back();
}
void get_inc(int x, int f) {
inc[x] = nxt[x] ? inc[nxt[x]] : W[x];
for (auto y : adj[x]) if (y != f && !centr[y])
get_inc(y, x);
}
void get_pre(int x, int f) {
rec.edit(W[x], x);
if (W[x]) pre[x] = rec[W[x] - 1];
for (auto y : adj[x]) if (y != f && !centr[y])
get_pre(y, x);
rec.back();
}
void get_dec(int x, int f) {
dec[x] = pre[x] ? dec[pre[x]] : W[x];
for (auto y : adj[x]) if (y != f && !centr[y])
get_dec(y, x);
}
std::vector<request> subreq[N];
std::vector<request> tosol[N];
void get_ans(int x, int f) {
if (x != f) rec.edit(W[x], x);
for (auto qry : tosol[x]) {
int worst = fir[qry.s] ? inc[fir[qry.s]] : 0;
int lb = worst + 1, ub = c;
while (lb <= ub) {
int mid = (lb + ub) >> 1;
if (!rec[mid] || dec[rec[mid]] - 1 > worst) ub = mid - 1;
else lb = mid + 1, ans[qry.idx] = mid;
}
ans[qry.idx] = std::max(ans[qry.idx], worst);
}
for (auto y : adj[x]) if (y != f && !centr[y])
get_ans(y, x);
if (x != f) rec.back();
}
void clear_all(int x, int f) {
pre[x] = nxt[x] = fir[x] = 0;
color[x] = inc[x] = dec[x] = 0;
tosol[x].clear();
for (auto y : adj[x]) if (y != f && !centr[y])
clear_all(y, x);
}
void solve(int x, std::vector<request>& req) {
if (req.empty()) return;
maxp[root = 0] = N;
get_centr(x, 0, get_size(x, 0));
centr[root] = 1;
std::vector<request> cur;
for (auto y : adj[root]) if (!centr[y])
set_color(y, root, y);
for (auto qry : req)
if (color[qry.s] == color[qry.t]) {
if (!color[qry.s]) ans[qry.idx] = (W[root] == 1);
else subreq[color[qry.s]].push_back(qry);
} else cur.push_back(qry);
rec.edit(W[root], root);
fir[root] = W[root] == 1 ? root : 0;
for (auto y : adj[root]) if (!centr[y])
get_nxt_fir(y, root);
rec.back();
inc[root] = W[root];
for (auto y : adj[root]) if (!centr[y])
get_inc(y, root);
for (auto y : adj[root]) if (!centr[y])
get_pre(y, root);
for (auto y : adj[root]) if (!centr[y])
get_dec(y, root);
for (auto qry : cur)
tosol[qry.t].push_back(qry);
get_ans(root, root);
clear_all(root, 0);
std::vector<std::vector<request> > copy;
for (auto y : adj[root]) if (!centr[y])
copy.push_back(subreq[y]), subreq[y].clear();
auto it = copy.begin();
for (auto y : adj[root]) if (!centr[y])
solve(y, *it), ++it;
}
signed main() {
n = read(), m = read(), c = read();
for (int i = 1; i <= c; i++) rank[read()] = i;
for (int i = 1; i <= n; i++) W[i] = rank[read()];
for (int i = 1; i < n; i++) {
int u = read(), v = read();
adj[u].push_back(v), adj[v].push_back(u);
}
std::vector<request> req(q = read());
for (int i = 0; i < q; i++)
req[i].s = read(), req[i].t = read(), req[i].idx = i + 1;
solve(1, req);
for (int i = 1; i <= q; i++)
printf("%d\n", ans[i]);
return 0;
}
```