BZOJ 5293: [BJOI2018]求和 – [树链剖分]


题目传送门(BZOJ)
>原文链接<

题解 :

注意到K很小, 又因为深度不会超过n, 所以我们把1\sim n的所有1\sim 50次方都预处理好

dis[i][k]为在k次方意义下的, 点i的深度前缀和.

那么对于一次询问(x, y, k) 答案就是

dis[x][k]+dis[y][k] – dis[LCA(x, y)][k] – dis[fa[LCA(x, y)]][k]

另外呢, 由于C++资瓷把下标写到外面, 为了营造喜剧效果, 我所有的数组无论嵌套还是多维的, 都下标写到了外面

鬼畜代码, 阅读愉快 !

代码 :

#include 
#include 
#include 
#include 
#include 
#define ll long long
#define max(a, b) (a>b?a:b)
using namespace std;
const int N = 310000, K = 60;
const int mod = 998244353;
inline char nc() {
    static char buf[100000], *p1, *p2;
    return p1==p2&&(p2=(p1=buf)+fread(buf, 1, 100000, stdin), p1==p2)?EOF:*p1++;
}
int rd() {
    int x = 0; char c = nc();
    while(!isdigit(c)) c = nc();
    while(isdigit(c)) x = (x<<1)+(x<<3)+(c^48), c = nc();
    return x;
}
int dep[N], top[N], fa[N], son[N], siz[N], maxinum, dis[N][K];
int pwr[N][K], n;
struct Edge {
    int to;
    Edge *next;
}*h[N], e[N<<1];
void Add_Edge(int u, int v) {
    static int _; Edge *
    tmp = &(++_)[e]; tmp->to = v; tmp->next = u[h]; h[u] = tmp;
    tmp = &(++_)[e]; tmp->to = u; tmp->next = h[v]; v[h] = tmp;
}
void dfs(int p) {
    dep[p] = p[fa][dep] + 1;
    for(int i = 1; i <= 50; i ++ ) i[p[dis]] = (i[p[fa][dis]] + pwr[dep[p]-1][i])%mod;
    p[siz] = 1;
    maxinum = max(maxinum, p[dep]);
    for(Edge *i = p[h]; i; i = i->next) {
        int to = i->to;
        if(to != p[fa]) {
            to[fa] = p;
            dfs(to);
            p[siz] += to[siz];
            if(to[siz] > p[son][siz]) p[son] = to;    
        }
    }
}
void dfs(int p, int t) {
    p[top] = t;
    if(p[son]) dfs(son[p], t);
    for(Edge *i = p[h];i;i=i->next) {
        int to = i -> to;
        if(to!=p[fa]&&to!=p[son]) {
            dfs(to, to);
        }
    }
}
void Init() {
    n = rd();
    for(int i = 2; i <= n; i ++ ) {
        int u = rd(), v = rd();
        Add_Edge(u, v);
    }
    for(int i = 1; i <= n; i ++ ) {
        0[i[pwr]] = 1;
        for(int j = 1; j <= 50; j ++ ) {
            j[i[pwr]] = ((ll)(j-1)[i[pwr]] * i)%mod;
        }
    }
    dfs(1);
    dfs(1, 1);
}
void solve(int x, int y, int k) {
    int ret = 0;
    ret += (k[x[dis]]+ k[y[dis]]) %mod;
    while(x[top] != y[top]) {
        if(x[top][dep] < y[top][dep]) swap(x, y);
            x = x[top][fa];
    }
    int ans = (x[dep] < y[dep]) ? x : y;
    ret -= k[ans[dis]];
    ret -= k[ans[fa][dis]];
    ((ret%=mod)+=mod)%=mod;
    printf("%d\n", ret);
}
int main() {
    Init(); int q = rd();
    while(q -- ) {
        int x = rd(), y = rd(), k = rd();
        solve(x, y, k);
    }
}

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注