Pair Sum and Perfect Square (扫描线)
题意
给出一个长度为 $n$ 的排列 $P$。给出 $q$ 次询问,每次询问给出 $l$, $r$,回答在 $[l, r]$ 上二元组 $(i, j)$ 满足 $i \le j$ 且 $A_i + A_j$ 是完全平方数的数量。
$1 \le n,q \le 1e5$, $1 \le l \le r \le n$
思路
这道题比较关键的一点实际上是能想到对排列 $P$ 每一个元素 $A_i$ 找到和他匹配的 $A_j$ 的复杂度是 $\sqrt n$ 的。所以可以在 $n \cdot \sqrt n$ 的时间内找到每个点和它对应的元素。
区间询问容易想到扫描线。可以按顺序处理出 $[0, k]$ 上满足条件的二元组 $(i, j)$ 的数量,然后将二元组的 $i$ 存入权值线段树,这样将询问离线对右端点排序后,对于询问 $[l_i, r_i]$ 就可以用树状数组 $log$ 复杂度查询答案了。
注意这里用线段树会被卡常。
Code
#include <bits/stdc++.h>
#define all(x) x.begin(), x.end()
#define NL std::cout << '\n'
using lnt = long long;
struct BIT {
std::vector<int> t; int n;
BIT(int nn) : t(nn+5), n(nn+5) {}
void add(int x, int p) {
for (++p; p <= n; p += p & -p) {
t[p-1] += x;
}
}
int ask(int p) {
int ret = 0;
for (; p > 0; p -= p & -p) {
ret += t[p-1];
}
return ret;
}
};
signed main() {
std::ios::sync_with_stdio(false); std::cin.tie(nullptr);
int t;
std::cin >> t;
while (t--) {
int n; std::cin >> n;
std::vector<int> v(n);
for (auto &x : v) { std::cin >> x; }
int q; std::cin >> q;
std::vector<std::array<int, 3>> qry(q);
int o = 0;
for (auto &x : qry) {
std::cin >> x[0] >> x[1];
x[2] = o++;
}
std::sort(all(qry), [](auto a, auto b) {
return a[1] == b[1] ? a[0] < b[0]: a[1] < b[1];
});
BIT bit(n);
int curr = 0;
std::vector<int> pos(n*2, -1);
std::vector<int> ret(q);
for (auto [l, r, idx] : qry) {
while (curr != n && curr != r) {
for (int x = 1; x*x < n*2; ++x) {
if (v[curr] > x*x) { continue; }
int p = pos[x*x-v[curr]];
if (~p) {
bit.add(1, p);
}
}
pos[v[curr]] = curr;
curr++;
}
ret[idx] = bit.ask(r) - bit.ask(l-1);
}
for(auto x : ret) { std::cout << x; NL; }
}
return 0;;
}
Tree (点分治)
题意
给出一棵有 $n$ 个点的树,每个节点有一个颜色 $c_i \in \{'a', 'b', 'c'\}$。求满足路径上各个颜色数量相等的路径的条数。
$1 \le n \le 1e5$
思路
树上路径数量统计,考虑点分治。
对于一条路径上相应颜色数量的三元组 $(a, b, c)$,考虑合并两条路径时,对于如何快速地找到匹配的满足条件的另一条路径。一种思路是记两两颜色数量的差值,即 $(a-b, b-c, a-c)$,易证与之互补的路径的差值为 $(-(a-b), -(b-c), -(a-c))$;另一种思路是对每个颜色分别赋值 $x, y, z$,使得假定颜色数量分别为 $a,b,c$ 的情况下当且仅当 $a = b = c$ 时 $ ax + by + cz = 0 $,这样就将题意转化为统计点权和为 $0$ 的路径数量了。
Code
#include <bits/stdc++.h>
#define all(x) x.begin(), x.end()
#define NL std::cout << '\n'
using lnt = long long;
signed main() {
std::ios::sync_with_stdio(false); std::cin.tie(nullptr);
int n; std::cin >> n;
std::string s; std::cin >> s;
std::vector<lnt> val(n);
for (int i = 0; i < n; ++i) {
switch(s[i]) {
case 'a': val[i] = 1e6; break;
case 'b': val[i] = 1; break;
case 'c': val[i] = -(1e6+1); break;
}
}
std::vector<std::vector<int>> son(n);
for (int i = 1; i < n; ++i) {
int a, b; std::cin >> a >> b; --a, --b;
son[a].emplace_back(b);
son[b].emplace_back(a);
}
std::vector<char> dead(n);
std::vector<int> sz(n);
std::function<void(int, int, int, int&)> getctr = [&](int u, int fa, int n, int &ctr) {
sz[u] = 1; int mx = 0;
for (auto x : son[u]) {
if (dead[x] || fa == x) { continue; }
getctr(x, u, n, ctr);
if (~ctr) { return; }
mx = std::max(mx, sz[x]);
sz[u] += sz[x];
}
mx = std::max(mx, n - sz[u]);
if (mx * 2 <= n) {
ctr = u;
if (~fa) { sz[fa] = n - sz[u]; }
}
};
std::vector<lnt> dis(n);
std::map<lnt, int> cnt, tmp;
std::function<void(int, int)> initDis = [&](int u, int fa) {
dis[u] = dis[fa] + val[u];
tmp[dis[u]]++;
for (auto x : son[u]) {
if (x == fa || dead[x]) { continue; }
initDis(x, u);
}
};
lnt ret = 0;
std::function<void(int)> solve = [&](int u) {
for (auto v : son[u]) {
if (dead[v]) { continue; }
dis[u] = 0; initDis(v, u);
for (auto [x, c] : tmp) {
if (x + val[u] == 0) { ret += c; }
if (cnt.find(-(x + val[u])) != cnt.end()) { ret += c*cnt[-(x + val[u])]; }
}
for (auto [x, c] : tmp) {
cnt[x] += c;
}
tmp.clear();
}
cnt.clear();
};
std::function<void(int u)> divide = [&](int u) {
dead[u] = 1;
solve(u);
for (auto v : son[u]) {
if (dead[v]) { continue; }
int ctr = -1;
getctr(v, u, sz[v], ctr);
divide(ctr);
}
};
int ctr = -1;
getctr(0, -1, n, ctr);
divide(ctr);
std::cout << ret;
return 0;
}