杭电多校 (6) (2023-09-22)

Pair Sum and Perfect Square (扫描线)

题意

给出一个长度为 $n$ 的排列 $P$。给出 $q$ 次询问,每次询问给出 $l$, $r$,回答在 $[l, r]$ 上二元组 $(i, j)$ 满足 $i \le j$ 且 $A_i + A_j$ 是完全平方数的数量。

$1 \le n,q \le 1e5$, $1 \le l \le r \le n$

思路

这道题比较关键的一点实际上是能想到对排列 $P$ 每一个元素 $A_i$ 找到和他匹配的 $A_j$ 的复杂度是 $\sqrt n$ 的。所以可以在 $n \cdot \sqrt n$ 的时间内找到每个点和它对应的元素。

区间询问容易想到扫描线。可以按顺序处理出 $[0, k]$ 上满足条件的二元组 $(i, j)$ 的数量,然后将二元组的 $i$ 存入权值线段树,这样将询问离线对右端点排序后,对于询问 $[l_i, r_i]$ 就可以用树状数组 $log$ 复杂度查询答案了。

注意这里用线段树会被卡常。

Code
#include <bits/stdc++.h>

#define all(x)      x.begin(), x.end()
#define NL          std::cout << '\n'

using lnt = long long;

struct BIT {
  std::vector<int> t; int n;
  BIT(int nn) : t(nn+5), n(nn+5) {}

  void add(int x, int p) {
    for (++p; p <= n; p += p & -p) {
      t[p-1] += x;
    }
  }

  int ask(int p) {
    int ret = 0;
    for (; p > 0; p -= p & -p) {
      ret += t[p-1];
    }
    return ret;
  }
};

signed main() {
  std::ios::sync_with_stdio(false); std::cin.tie(nullptr);
  int t;
  std::cin >> t;
  while (t--) {
    int n; std::cin >> n;
    std::vector<int> v(n);
    for (auto &x : v) { std::cin >> x; }
    int q; std::cin >> q;
    std::vector<std::array<int, 3>> qry(q);
    int o = 0;
    for (auto &x : qry) {
      std::cin >> x[0] >> x[1];
      x[2] = o++;
    }
    std::sort(all(qry), [](auto a, auto b) {
      return a[1] == b[1] ? a[0] < b[0]: a[1] < b[1];
    });
    BIT bit(n);
    int curr = 0;
    std::vector<int> pos(n*2, -1);
    std::vector<int> ret(q);
    for (auto [l, r, idx] : qry) {
      while (curr != n && curr != r) {
        for (int x = 1; x*x < n*2; ++x) {
          if (v[curr] > x*x) { continue; }
          int p = pos[x*x-v[curr]];
          if (~p) {
            bit.add(1, p);
          }
        }
        pos[v[curr]] = curr;
        curr++;
      }
      ret[idx] = bit.ask(r) - bit.ask(l-1);
    }
    for(auto x : ret) { std::cout << x; NL; }
  }
  return 0;;
}

Tree (点分治)

题意

给出一棵有 $n$ 个点的树,每个节点有一个颜色 $c_i \in \{'a', 'b', 'c'\}$。求满足路径上各个颜色数量相等的路径的条数。

$1 \le n \le 1e5$

思路

树上路径数量统计,考虑点分治。

对于一条路径上相应颜色数量的三元组 $(a, b, c)$,考虑合并两条路径时,对于如何快速地找到匹配的满足条件的另一条路径。一种思路是记两两颜色数量的差值,即 $(a-b, b-c, a-c)$,易证与之互补的路径的差值为 $(-(a-b), -(b-c), -(a-c))$;另一种思路是对每个颜色分别赋值 $x, y, z$,使得假定颜色数量分别为 $a,b,c$ 的情况下当且仅当 $a = b = c$ 时 $ ax + by + cz = 0 $,这样就将题意转化为统计点权和为 $0$ 的路径数量了。

Code
#include <bits/stdc++.h>

#define all(x)      x.begin(), x.end()
#define NL          std::cout << '\n'

using lnt = long long;

signed main() {
  std::ios::sync_with_stdio(false); std::cin.tie(nullptr);
  int n; std::cin >> n;
  std::string s; std::cin >> s;
  std::vector<lnt> val(n);
  for (int i = 0; i < n; ++i) {
    switch(s[i]) {
      case 'a': val[i] = 1e6; break;
      case 'b': val[i] = 1; break;
      case 'c': val[i] = -(1e6+1); break;
    }
  }
  std::vector<std::vector<int>> son(n);
  for (int i = 1; i < n; ++i) {
    int a, b; std::cin >> a >> b; --a, --b;
    son[a].emplace_back(b);
    son[b].emplace_back(a);
  }
  std::vector<char> dead(n);

  std::vector<int> sz(n);
  std::function<void(int, int, int, int&)> getctr = [&](int u, int fa, int n, int &ctr) {
    sz[u] = 1; int mx = 0;
    for (auto x : son[u]) {
      if (dead[x] || fa == x) { continue; }
      getctr(x, u, n, ctr);
      if (~ctr) { return; }
      mx = std::max(mx, sz[x]);
      sz[u] += sz[x];
    }
    mx = std::max(mx, n - sz[u]);
    if (mx * 2 <= n) {
      ctr = u;
      if (~fa) { sz[fa] = n - sz[u]; }
    }
  };

  std::vector<lnt> dis(n);
  std::map<lnt, int> cnt, tmp;
  std::function<void(int, int)> initDis = [&](int u, int fa) {
    dis[u] = dis[fa] + val[u];
    tmp[dis[u]]++;
    for (auto x : son[u]) {
      if (x == fa || dead[x]) { continue; }
      initDis(x, u);
    }
  };

  lnt ret = 0;
  std::function<void(int)> solve = [&](int u) {
    for (auto v : son[u]) {
      if (dead[v]) { continue; }
      dis[u] = 0; initDis(v, u);
      for (auto [x, c] : tmp) {
        if (x + val[u] == 0) { ret += c; }
        if (cnt.find(-(x + val[u])) != cnt.end()) { ret += c*cnt[-(x + val[u])]; }
      }
      for (auto [x, c] : tmp) {
        cnt[x] += c;
      }
      tmp.clear();
    }
    cnt.clear();
  };


  std::function<void(int u)> divide = [&](int u) {
    dead[u] = 1;
    solve(u);
    for (auto v : son[u]) {
      if (dead[v]) { continue; }
      int ctr = -1;
      getctr(v, u, sz[v], ctr);
      divide(ctr);
    }
  };

  int ctr = -1;
  getctr(0, -1, n, ctr);
  divide(ctr);
  std::cout << ret;
  return 0;
}