「牛客 1103B」路径计数机

是道好题

可惜卡常数

思路:暴力枚举 $(a,b)$,快速统计与之不相交的 $(c,d)$

定义 $f(u)$ 为以 $u$ 为 $\operatorname{LCA}$ 的 $(c,d)$ 个数

$tsf(u)=\sum f(v)$,其中 $v$ 在子树 $u$ 中

$csf(u)$ 为 $1$ 到 $u$ 路径上所有点的 $\sum f$

$(c,d)$ 总数量记为 $sq$

同时经过 $u$ 和 $u$ 的父节点的点对 $(c,d)$ 数量为 $vq(u)$,可用树上差分求出

考虑对于点对 $(a,b)$,记 $l=\operatorname{LCA}(a,b)$

$\operatorname{LCA}(c,d)$ 在子树 $l$ 中且与路径 $(a,b)$ 不相交的 $(c,d)$,数量为

$$tsf(l)-(csf(a)-csf(l))-(csf(b)-csf(l))-f(l)$$

$\operatorname{LCA}(c,d)$ 不在子树 $l$ 中且与路径 $(a,b)$ 不相交的 $(c,d)$,数量为

$$sq-tsf(l)-vq(l)$$
#include <bits/stdc++.h>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,no-stack-protector")
#pragma GCC optimize("inline","fast-math","unroll-loops","no-stack-protector")
// Inclusive counting loops: rep walks i upward from l to r, per walks i
// downward from l to r. Both declare i as a plain int scoped to the loop.
// The `register` storage class was dropped: C++17 removed it, so keeping it
// is a hard error on some compilers and a warning on others.
#define rep(i, l, r) for (int i = (l); i <= (r); ++i)
#define per(i, l, r) for (int i = (l); i >= (r); --i)

// Short names used throughout the file.
using std::make_pair;
using std::pair;
using pii = pair<int, int>;
using ll = long long;
using ui = unsigned int;
using ull = unsigned long long;

// Buffered reader/writer on top of stdin/stdout.
//
// In the normal build, input is pulled in MAXSIZE-byte chunks with fread and
// output is collected in pbuf, which is flushed when it fills up and once
// more from the destructor. When DEBUG is non-zero every character goes
// straight through getchar/putchar instead, so interactive typing is visible.
//
// All read() overloads stop once the input is exhausted instead of spinning
// on the end-of-input marker returned by gc().
struct IO {
#define MAXSIZE (1 << 20)
#define isdigit(x) ((x) >= '0' && (x) <= '9')
    char buf[MAXSIZE], *p1, *p2;  // input buffer; [p1, p2) is the unread part
    char pbuf[MAXSIZE], *pp;      // output buffer; pp is the next free slot
    bool eof_reached = false;     // true while the last gc() found no more input
#if DEBUG
#else
    IO() : p1(buf), p2(buf), pp(pbuf) {}
    ~IO() { fwrite(pbuf, 1, pp - pbuf, stdout); }
#endif

    // Returns the next input character, or -1 (and sets eof_reached) when
    // nothing more can be read.
    inline char gc() {
#if DEBUG  // debug build: unbuffered so that typed characters show up
        const int c = getchar();
        eof_reached = (c == EOF);
        return c;
#else
        if (p1 == p2) {
            p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin);
            eof_reached = (p1 == p2);
            if (eof_reached) return -1;
        }
        return *p1++;
#endif
    }

    // True for space, newline, carriage return and tab.
    inline bool blank(char ch) { return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t'; }

    // Parses the next number into x: skips everything up to the first digit
    // (a '-' seen on the way makes the result negative), then reads the
    // integer part and an optional ".digits" fraction. x is 0 when the input
    // ends before any digit is found.
    template <class T>
    inline void read(T &x) {
        double scale = 1;
        bool negative = false;
        x = 0;
        char ch = gc();
        // The end-of-input marker is not a digit, so without the eof check
        // this loop would never terminate on exhausted input.
        for (; !eof_reached && !isdigit(ch); ch = gc())
            if (ch == '-') negative = true;
        for (; isdigit(ch); ch = gc()) x = x * 10 + (ch - '0');
        if (ch == '.')
            for (ch = gc(); isdigit(ch); ch = gc()) scale /= 10.0, x += scale * (ch - '0');
        if (negative) x = -x;
    }

    // Reads the next blank-delimited token into s and NUL-terminates it.
    // The caller must provide enough room; s becomes "" at end of input.
    inline void read(char *s) {
        char ch = gc();
        while (!eof_reached && blank(ch)) ch = gc();
        // The end-of-input marker is not blank either; stop on it explicitly
        // so the copy cannot run past the caller's buffer.
        for (; !eof_reached && !blank(ch); ch = gc()) *s++ = ch;
        *s = 0;
    }

    // Reads the next non-blank character into c (-1 at end of input).
    inline void read(char &c) {
        for (c = gc(); blank(c); c = gc())
            ;
    }

    // Appends one character to the output, flushing the buffer when full.
    inline void push(const char &c) {
#if DEBUG  // debug build: unbuffered so that output shows up immediately
        putchar(c);
#else
        if (pp - pbuf == MAXSIZE) fwrite(pbuf, 1, MAXSIZE, stdout), pp = pbuf;
        *pp++ = c;
#endif
    }

    // Writes an integer in decimal. Digits are taken from the value as-is,
    // without negating it first, so the most negative value of a signed type
    // is printed correctly.
    template <class T>
    inline void write(T x) {
        char digits[40];
        int len = 0;
        const bool negative = x < 0;
        do {
            // For negative x the remainder is in [-9, 0]; take its magnitude.
            const int d = static_cast<int>(x % 10);
            digits[len++] = static_cast<char>('0' + (d < 0 ? -d : d));
            x /= 10;
        } while (x);
        if (negative) push('-');
        while (len) push(digits[--len]);
    }

    // Writes a NUL-terminated string.
    inline void write(const char *s) {
        while (*s != '\0') push(*(s++));
    }

    // Writes x followed by lastChar.
    template <class T>
    inline void write(T x, char lastChar) {
        write(x), push(lastChar);
    }
} io;

// Capacity of every per-vertex array in this file.
constexpr int N = 3010;

// Values read at the start of main: n bounds every vertex loop, p and q are
// the two path lengths that pair distances are compared against.
int n, p, q;

// lc[i][j] caches lca(i, j) so the second pass over all pairs can reuse it.
int lc[N][N];

// Adjacency lists of the undirected input graph.
std::vector<int> g[N];

// dep[v] is the depth assigned by dfs_lca (the start vertex gets 1 while
// dep[0] stays 0). fa[0][v] is the vertex from which v was reached; the rows
// fa[1..12] are filled in main from row 0.
int dep[N], fa[13][N];

// Walks the graph depth-first from u, recording depth and parent for every
// vertex reached. The neighbour equal to u's recorded parent is not revisited.
void dfs_lca(int u) {
    dep[u] = dep[fa[0][u]] + 1;
    for (const int next : g[u]) {
        if (next != fa[0][u]) {
            fa[0][next] = u;
            dfs_lca(next);
        }
    }
}

// Returns the lowest common ancestor of u and v using the jump table fa,
// where fa[k][x] is the 2^k-th ancestor of x (0 when there is none).
int lca(int u, int v) {
    if (dep[u] < dep[v]) std::swap(u, v);

    // Bring the deeper vertex u up to the depth of v, one set bit of the
    // depth difference at a time.
    const int delta = dep[u] - dep[v];
    for (int k = 0; k <= 12; ++k) {
        if ((delta >> k) & 1) u = fa[k][u];
    }
    if (u == v) return u;

    // Lift both vertices by the largest jumps that keep them apart; they end
    // up as two distinct children of the answer.
    for (int k = 12; k >= 0; --k) {
        const int up_u = fa[k][u];
        const int up_v = fa[k][v];
        if (up_u != up_v) {
            u = up_u;
            v = up_v;
        }
    }
    return fa[0][u];
}

// f[u]   : number of counted pairs whose lowest common ancestor is u.
// tsf[u] : sum of f over the subtree of u.
// csf[u] : sum of f over the vertices from the start of the walk down to u.
// dq[u]  : per-vertex difference values filled in main.
// vq[u]  : sum of dq over the subtree of u.
ll f[N], tsf[N], csf[N], dq[N], vq[N];

// Fills csf top-down and tsf / vq bottom-up for the subtree of u.
void calc(int u) {
    const int parent = fa[0][u];
    // The parent's prefix is final before any child is visited.
    csf[u] = csf[parent] + f[u];
    tsf[u] = f[u];
    vq[u] = dq[u];
    for (const int child : g[u]) {
        if (child != parent) {
            calc(child);
            tsf[u] += tsf[child];
            vq[u] += vq[child];
        }
    }
}

int main() {
#ifdef LOCAL
freopen("input", "r", stdin);
#endif
io.read(n), io.read(p), io.read(q);
int u, v;
rep(i, 2, n){
io.read(u), io.read(v);
g[u].push_back(v), g[v].push_back(u);
}
dfs_lca(1);
rep(k, 1, 12) rep(i, 1, n) fa[k][i] = fa[k-1][fa[k-1][i]];
ll sq = 0;
rep(i, 1, n){
rep(j, 1, n){
int l = lca(i, j);
lc[i][j] = l;
int d = dep[i] + dep[j] - dep[l] * 2;
if (d == q){
sq++;
f[l]++;
dq[i]++, dq[j]++, dq[l] -= 2;
}
}
}
calc(1);
ll ans = 0;
rep(i, 1, n){
rep(j, 1, n){
int l = lc[i][j];
int d = dep[i] + dep[j] - dep[l] * 2;
if (d == p){
ans += tsf[l]-(csf[i]-csf[l])-(csf[j]-csf[l])-f[l];
ans += sq - tsf[l] - vq[l];
}
}
}
io.write(ans);
return 0;
}