BZOJ 3224: 普通平衡树 – [Treap/Splay/线段树][模板]

题目传送门(BZOJ)
>原文链接<

题解 :

平衡树模板题。

提供两个平衡树模板，分别基于 Treap 和 Splay 实现。

upd：补充一份用权值线段树写的超简单模板（需要事先知道值域，代码中为 [-1e7, 1e7]），NOIP 要用的话就用它吧！

推荐 Splay 的学习链接：Splay-Clove_Unique

推荐 Treap 的学习链接：洛谷上的本题题解

代码 :

Segment Tree :

#include 
#include 
using namespace std;
const int N = 100100;   // sizes the segment-tree node pool (N*20 nodes)
const int INF = 1e7;    // the segment tree covers the value domain [-INF, INF]

// ---- buffered input ----
char buf[100000], *p1, *p2;

// Returns the next input byte as an unsigned-char value, or EOF once stdin is exhausted.
// (unsigned char keeps every non-EOF value valid for isdigit().)
inline int nc() {
    if (p1 == p2) {
        p1 = buf;
        p2 = buf + fread(buf, 1, sizeof buf, stdin);
        if (p1 == p2) return EOF;
    }
    return static_cast<unsigned char>(*p1++);
}

// ---- buffered output ----
char pbuf[1000000], *pp = pbuf;

// Appends one byte to the output buffer, flushing the buffer to stdout first when it is full.
// pp must be rewound after every flush; otherwise later writes run past the end of pbuf.
inline void pc(char c) {
    if (pp == pbuf + sizeof pbuf) {
        fwrite(pbuf, 1, sizeof pbuf, stdout);
        pp = pbuf;
    }
    *pp++ = c;
}

// Reads the next signed decimal integer from stdin, skipping any other characters.
// Returns 0 if the input ends before a digit is found (instead of looping forever on EOF).
inline int rd() {
    int x = 0, f = 1, ch = nc();
    while (!isdigit(ch)) {
        if (ch == EOF) return 0;
        if (ch == '-') f = -1;
        ch = nc();
    }
    while (isdigit(ch)) x = x * 10 + (ch - '0'), ch = nc();
    return x * f;
}

// Writes x in decimal followed by '\n' into the output buffer.
// Digits are produced from the unsigned magnitude, so INT_MIN is printed without overflow.
inline void wt(int x) {
    char stk[12];   // enough for every digit of a 32-bit magnitude
    int top = 0;
    unsigned u = static_cast<unsigned>(x);
    if (x < 0) {
        pc('-');
        u = 0u - u;
    }
    do {
        stk[top++] = static_cast<char>('0' + u % 10);
        u /= 10;
    } while (u);
    while (top) pc(stk[--top]);
    pc('\n');
}
// Dynamically allocated value segment tree over the integer range [l, r] supplied by the caller.
//   t[p]        : how many stored values fall inside node p's range
//   ls[p]/rs[p] : child indices, 0 meaning "no node yet"
// Node 0 is never allocated or written, so it behaves as an empty subtree (t[0] == 0).
// Nodes are created lazily along one root-to-leaf path per change() call.
int t[N*20], ls[N*20], rs[N*20], root, cnt;

// Adds c to the multiplicity of value x (c = +1 inserts one copy, c = -1 removes one),
// allocating any missing nodes on the path to x's leaf.
void change(int l, int r, int &p, int x, int c) {
    if (!p) p = ++cnt;
    t[p] += c;
    if (l == r) return;
    // >> 1 rounds toward -infinity for negative sums on two's-complement targets (guaranteed
    // since C++20); plain "/ 2" would truncate toward zero and could make mid == r.
    const int mid = (l + r) >> 1;
    if (x <= mid) change(l, mid, ls[p], x, c);
    else change(mid + 1, r, rs[p], x, c);
}

// Counts stored values strictly smaller than k inside [l, r].
int qrank(int l, int r, int p, int k) {
    // An empty subtree contributes nothing, so stop instead of descending through node 0.
    // At a leaf the only value is k itself, which is not smaller than k.
    if (!p || l == r) return 0;
    const int mid = (l + r) >> 1;
    if (k <= mid) return qrank(l, mid, ls[p], k);
    return t[ls[p]] + qrank(mid + 1, r, rs[p], k);
}

// Returns the k-th smallest stored value (1-based, duplicates counted separately).
// If k < 1 the descent ends at l; if k exceeds the number of stored values it ends at r.
int qnum(int l, int r, int p, int k) {
    if (l == r) return l;
    const int mid = (l + r) >> 1;
    if (t[ls[p]] >= k) return qnum(l, mid, ls[p], k);
    return qnum(mid + 1, r, rs[p], k - t[ls[p]]);
}
int main() {
    int T = rd();
    while(T -- ) switch(rd()) {
        case 1 : change(-INF, INF, root, rd(), 1); break;
        case 2 : change(-INF, INF, root, rd(), -1); break;
        case 3 : wt(qrank(-INF, INF, root, rd())+1); break;
        case 4 : wt(qnum(-INF, INF, root, rd())); break;
        case 5 : wt(qnum(-INF, INF, root, qrank(-INF, INF, root, rd()))); break;
        case 6 : wt(qnum(-INF, INF, root, qrank(-INF, INF, root, rd()+1)+1)); break;
    }
    fwrite(pbuf, 1, pp-pbuf, stdout);
}

Treap :

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
// Capacity of the node pool tr[]. insert() takes a new slot only for a value not already in
// the treap (duplicates bump tr[p].sum), and del() never returns slots to the pool.
const int N = 1e5+10;
// Result slots: pre() writes the predecessor into t1 and suc() writes the successor into t2.
// main() resets both to 0 before each operation.
int t1, t2;
// Returns the next byte of stdin through a 100000-byte buffer; (char)EOF once input runs out.
inline char nc() {
    static char buf[100000], *p1 = buf, *p2 = buf;
    if (p1 == p2) {
        p1 = buf;
        p2 = buf + fread(buf, 1, 100000, stdin);
        if (p1 == p2) return static_cast<char>(EOF);
    }
    return *p1++;
}
inline int read() {
    int x = 0 , f = 1; char ch = nc();
    while(!isdigit(ch)) {if(ch == '-')f = -1; ch = nc();}
    while(isdigit(ch)) {x = x * 10 + ch - '0'; ch = nc();}
    return x*f;
}
// Treap node stored in the static array tr[]; index 0 stands for "no node".
struct node {
    int ls,rs,num;   // left/right child indices (0 = none) and the stored value
    int siz,sum;     // siz: copies stored in the whole subtree; sum: copies of num in this node
    int key;         // random priority; insert() rotates a child up when its key is smaller (min-heap)
}tr[N];
int cnt;             // last node index handed out by insert()
int root = 0 ;       // index of the treap root (0 while empty)
void upd(int &p) {
    int &ls = tr[p].ls, &rs = tr[p].rs;
    tr[p].siz = tr[ls].siz + tr[rs].siz + tr[p].sum;
    return;
}
void lturn(int &p) {
    int t = tr[p].rs;
    tr[p].rs = tr[t].ls;
    tr[t].ls = p;
    tr[t].siz = tr[p].siz;
    upd(p);
    p = t;
}
void rturn(int &p) {
    int t = tr[p].ls;
    tr[p].ls = tr[t].rs;
    tr[t].rs = p;
    tr[t].siz = tr[p].siz;
    upd(p);
    p = t;
}
void insert(int &p,int x) {
    if(!p) {
        p = ++cnt;
        tr[p].num = x;
        tr[p].siz = 1;
        tr[p].sum ++;
        tr[p].key = rand();
        return;
    }
    tr[p].siz ++;
    int &ls = tr[p].ls,&rs = tr[p].rs;
    if(x == tr[p].num) return tr[p].sum ++, void();
    else if(x < tr[p].num) {
        insert(ls , x);
        if(tr[ls].key < tr[p].key)
            rturn(p);
    }
    else if(x > tr[p].num) {
        insert(rs , x);
        if(tr[rs].key < tr[p].key)
            lturn(p);
    }
}
void del(int &p,int x) {
    if(!p)return;
    int &ls = tr[p].ls , &rs = tr[p].rs;
    if(x == tr[p].num) {
        if(tr[p].sum > 1) {
            tr[p].sum --;
            tr[p].siz --;
            return;
        }
        if(ls*rs == 0) p = ls + rs;
        else {
            if(tr[ls].key < tr[rs].key) rturn(p);
            else lturn(p);
            del(p, x);
        }
    }
    else {
        tr[p].siz --;
        if(x > tr[p].num) del(rs , x);
        else del(ls , x);
    }

}
// Rank of x: 1 + the number of stored values strictly smaller than x.
// If x is not stored, this is still the rank x would have (count of smaller values + 1),
// consistent with the segment-tree and Splay rank queries in this post.
int find1(int &p,int x) {
    if (!p) return 1;                   // fell off the tree: all smaller values already counted
    int &ls = tr[p].ls;
    if (x == tr[p].num) return tr[ls].siz + 1;
    if (x < tr[p].num) return find1(ls, x);
    return tr[ls].siz + tr[p].sum + find1(tr[p].rs, x);
}
// Returns the x-th smallest stored value (1-based, duplicates counted), or 0 if x is out of range.
int find2(int &p,int x) {
    int cur = p;
    while (cur) {
        const int left = tr[tr[cur].ls].siz;
        if (x <= left) {
            cur = tr[cur].ls;
        } else if (x <= left + tr[cur].sum) {
            return tr[cur].num;
        } else {
            x -= left + tr[cur].sum;
            cur = tr[cur].rs;
        }
    }
    return 0;
}
void pre(int &p,int x) {
    if(!p)return ;
    int &ls = tr[p].ls , &rs = tr[p].rs;
    if(tr[p].num < x) {
        t1 = tr[p].num;
        pre(rs , x);
    }
    else pre(ls , x);
}
void suc(int &p,int x) {
    if(!p)return ;
    int &ls = tr[p].ls , &rs = tr[p].rs;
    if(tr[p].num > x) {
        t2 = tr[p].num;
        suc(ls , x);
    }
    else suc(rs , x);
}
int main() {
    srand(20402);
    rand();                   // discard the first draw of the fixed-seed sequence
    const int n = read();
    for (int done = 0; done < n; ++done) {
        const int opt = read();
        const int x = read();
        t1 = 0;
        t2 = 0;
        switch (opt) {
        case 1: insert(root, x); break;
        case 2: del(root, x); break;
        case 3: printf("%d\n", find1(root, x)); break;
        case 4: printf("%d\n", find2(root, x)); break;
        case 5: pre(root, x); printf("%d\n", t1); break;
        case 6: suc(root, x); printf("%d\n", t2); break;
        }
    }
}

Splay :

#include 
#include 
#include 
#include 
#include 
using namespace std;
// Node-pool capacity: two sentinels plus one node per insert() (duplicates get their own node,
// and del() never recycles nodes).
const int N = 110000;
// get(p): 1 if p is the right child of its parent, 0 otherwise.
#define get(p) (ch[fa[p]][1]==p)
// ls / rs expand to the children of whatever variable named p is in scope at the use site.
#define ls ch[p][0]
#define rs ch[p][1]
// NOTE(review): inf is not referenced; main hard-codes the -1e8 / 1e8 sentinel values.
const int inf = 1e8 ;
// root: current splay root. NOTE(review): n and opt are shadowed by locals in main.
int root,n,opt;
// Buffered stdin reader. NOTE(review): main below reads with scanf, so Rd is currently unused.
class ReadIn {
    private:
    // Returns the next byte from a 100000-byte stdin buffer, or (char)EOF at end of input.
    inline char nc() {
        static char buf[100000], *p1, *p2;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
    }
    public:
    // Reads the next decimal integer, honouring a '-' sign (this problem has negative values).
    // Returns 0 if the input ends before any digit instead of looping forever.
    inline int read() {
        int x = 0, f = 1;
        char ch = nc();
        // isdigit requires an unsigned-char value (or EOF), hence the casts.
        while (!isdigit(static_cast<unsigned char>(ch))) {
            if (ch == static_cast<char>(EOF)) return 0;  // nc() hands EOF back as a char
            if (ch == '-') f = -1;
            ch = nc();
        }
        while (isdigit(static_cast<unsigned char>(ch))) {
            x = x * 10 + (ch - '0');
            ch = nc();
        }
        return x * f;
    }
    // Returns the next non-whitespace character.
    inline char getc() {
        char ch=nc();
        while(isspace(ch))ch=nc();
        return ch;
    }
}Rd;
// Splay storage indexed by node id (0 = no node): ch[p][0/1] children, siz subtree node count,
// fa parent, val stored value; cnt1 is the last id handed out by creat().
// NOTE(review): cnt[] is not referenced anywhere in this program.
int ch[N][2],siz[N],fa[N],val[N],cnt[N],cnt1;
void pushup(int p) {siz[p]=siz[ls]+siz[rs]+1;}
// Rotates p above its parent y; p takes y's place under grandparent z.
// Then it recomputes the sizes of y and p, and moves root to p if y was the root.
void rotate(int p) {
    int y=fa[p],z=fa[y],k=get(p);      // k: which side of y p hangs on
    ch[y][k]=ch[p][!k];fa[ch[y][k]]=y; // p's inner child moves under y (writes fa[0] when it is 0)
    ch[p][!k]=y,fa[y]=p,fa[p]=z;
    if(z) ch[z][ch[z][1]==y]=p;        // hook p into z on the side y used to occupy
    pushup(y),pushup(p);               // y is now below p, so it is updated first
    if(root==y)root=p;
}

// Rotates p upward until its parent is tar (tar = 0 makes p the root; rotate() updates root).
// When p and its parent lie on the same side (zig-zig), the parent is rotated first.
void splay(int p,int tar) {
    for(int f;(f=fa[p])!=tar;rotate(p))
        if(fa[f]!=tar)
            rotate(get(p)==get(f)?f:p);
}
// Allocates a fresh single-node subtree holding x and returns its id.
int creat(int x) {
    const int id = ++cnt1;
    val[id] = x;
    siz[id] = 1;
    return id;
}
// Inserts a new node holding x (duplicates get their own node).
// Descent: l ends as the node with the largest value < x, r as the node with the smallest value >= x.
// After splaying l to the root and r directly beneath it, nothing lies between them,
// so r's left slot is empty and the new node is attached there.
// Assumes -1e8 < x so that the -inf sentinel guarantees l exists.
void insert(int x) {
    int l=0,r=0,p = root ;
    while(p) {
        if(val[p]>=x) r=p,p=ls;
        else l=p,p=rs;
    }
    splay(l,0);splay(r,root);
    ch[r][0] = creat(x); fa[cnt1] = r;
    pushup(r);pushup(l);
    splay(cnt1,0); 
}
int pre(int x) {
    int p = root,ans=0;
    while(p) {
        if(val[p]>=x)p=ls;
        else ans=p,p=rs;
    }
    return val[ans];
}
int suc(int x) {
    int p = root,ans=0;
    while(p) {
        if(val[p]<=x)p=rs;
        else ans=p,p=ls;
    }
    return val[ans];
}
int aft(int p) {
    if(rs) {
        p=rs;while(ls) p = ls;
    }
    else {
        while(get(p)) p = fa[p];
        p=fa[p];
    }
    return p;
}
// Deletes one node holding x; does nothing if x is not stored.
// Descent: l ends as the node with the largest value < x, r as the node with the smallest value >= x.
// When val[r] == x, splaying l to the root and aft(r) directly beneath it leaves r as the only
// node between them, i.e. the whole left subtree of aft(r), which is then cut off.
// Assumes |x| < 1e8 so the -1e8 / 1e8 sentinels guarantee l and aft(r) exist.
void del(int x) {
    int l=0,r=0,p = root ;
    while(p) {
        if(val[p]>=x) r=p,p=ls;
        else l=p,p=rs;
    }
    // Without this check an absent x would delete its lower bound (a different value).
    if (!r || val[r] != x) return;
    r = aft(r);
    splay(l,0);splay(r,root);
    fa[ch[r][0]]=0;ch[r][0]=0;
    pushup(r),pushup(l);
}
// Counts nodes with value < x. Because the -1e8 sentinel is always counted,
// the result equals (number of real values smaller than x) + 1, i.e. the rank of x.
int find_rank(int x) {
    int below = 0;
    for (int p = root; p; ) {
        if (val[p] < x) {
            below += siz[ls] + 1;   // p and its whole left subtree are smaller
            p = rs;
        } else {
            p = ls;
        }
    }
    return below;
}

// Returns the node id holding the x-th smallest value, counting sentinel nodes too (1-based).
// Assumes 1 <= x <= siz[root]; main passes x + 1 to step over the -1e8 sentinel.
int find_num(int x) {
    int p = root;
    for (;;) {
        const int left = siz[ls];
        if (x <= left) {
            p = ls;
        } else if ((x -= left + 1) == 0) {
            return p;
        } else {
            p = rs;
        }
    }
}
int main() {
    int n,opt,x;
    scanf("%d",&n);
    // Two sentinels (-1e8 and 1e8) so every real value has a neighbouring node on both sides.
    root=creat(-100000000);
    ch[root][1]=creat(100000000);
    fa[ch[root][1]]=root;
    pushup(root);
    while (n--) {
        scanf("%d%d",&opt,&x);
        switch (opt) {
        case 1: insert(x); break;
        case 2: del(x); break;
        case 3: printf("%d\n",find_rank(x)); break;
        case 4: printf("%d\n",val[find_num(x+1)]); break;   // +1 skips the -1e8 sentinel
        case 5: printf("%d\n",pre(x)); break;
        default: printf("%d\n",suc(x)); break;
        }
    }
}

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注