题目传送门(BZOJ)
>原文链接<

题解 :

注意到K很小, 又因为深度不会超过n, 所以我们把1\sim n的所有1\sim 50次方都预处理好

dis[i][k]为在k次方意义下的, 点i的深度前缀和.

那么对于一次询问(x, y, k) 答案就是

dis[x][k]+dis[y][k] – dis[LCA(x, y)][k] – dis[fa[LCA(x, y)]][k]

另外呢, 由于C++资瓷把下标写到外面, 为了营造喜剧效果, 我所有的数组无论嵌套还是多维的, 都下标写到了外面

鬼畜代码, 阅读愉快 !

代码 :

#include <cstdio>#include <cstring>#include <algorithm>#include <iostream>#include <cctype>#define ll long long#define max(a, b) (a>b?a:b)using namespace std;const int N = 310000, K = 60;const int mod = 998244353;inline char nc() { static char buf[100000], *p1, *p2; return p1==p2&&(p2=(p1=buf)+fread(buf, 1, 100000, stdin), p1==p2)?EOF:*p1++;}int rd() { int x = 0; char c = nc(); while(!isdigit(c)) c = nc(); while(isdigit(c)) x = (x<<1)+(x<<3)+(c^48), c = nc(); return x;}int dep[N], top[N], fa[N], son[N], siz[N], maxinum, dis[N][K];int pwr[N][K], n;struct Edge { int to; Edge *next;}*h[N], e[N<<1];void Add_Edge(int u, int v) { static int _; Edge * tmp = &(++_)[e]; tmp->to = v; tmp->next = u[h]; h[u] = tmp; tmp = &(++_)[e]; tmp->to = u; tmp->next = h[v]; v[h] = tmp;}void dfs(int p) { dep[p] = p[fa][dep] + 1; for(int i = 1; i <= 50; i ++ ) i[p[dis]] = (i[p[fa][dis]] + pwr[dep[p]-1][i])%mod; p[siz] = 1; maxinum = max(maxinum, p[dep]); for(Edge *i = p[h]; i; i = i->next) { int to = i->to; if(to != p[fa]) { to[fa] = p; dfs(to); p[siz] += to[siz]; if(to[siz] > p[son][siz]) p[son] = to; } }}void dfs(int p, int t) { p[top] = t; if(p[son]) dfs(son[p], t); for(Edge *i = p[h];i;i=i->next) { int to = i -> to; if(to!=p[fa]&&to!=p[son]) { dfs(to, to); } }}void Init() { n = rd(); for(int i = 2; i <= n; i ++ ) { int u = rd(), v = rd(); Add_Edge(u, v); } for(int i = 1; i <= n; i ++ ) { 0[i[pwr]] = 1; for(int j = 1; j <= 50; j ++ ) { j[i[pwr]] = ((ll)(j-1)[i[pwr]] * i)%mod; } } dfs(1); dfs(1, 1);}void solve(int x, int y, int k) { int ret = 0; ret += (k[x[dis]]+ k[y[dis]]) %mod; while(x[top] != y[top]) { if(x[top][dep] < y[top][dep]) swap(x, y); x = x[top][fa]; } int ans = (x[dep] < y[dep]) ? x : y; ret -= k[ans[dis]]; ret -= k[ans[fa][dis]]; ((ret%=mod)+=mod)%=mod; printf("%d\n", ret);}int main() { Init(); int q = rd(); while(q -- ) { int x = rd(), y = rd(), k = rd(); solve(x, y, k); }}