알고리즘&자료구조

이진 검색 트리 변형: 트립

아, 그래요? 2022. 12. 9. 13:36
#include <iostream>
#include <cstdlib>
#include <utility>

using namespace std;

typedef int KeyType;

// Treap node. Nodes are ordered by `key` as a binary search tree; `priority` is
// a random heap key (the tree functions in this file keep larger priorities
// closer to the root). `size` caches the number of nodes in the subtree rooted
// here, which the order-statistic queries rely on.
struct Node {
    KeyType key;
    int priority;
    int size;

    Node* left;
    Node* right;

    // Builds a detached leaf holding `k` with a priority drawn from rand().
    Node(const KeyType& k) : key(k), priority(rand()), size(1), left(NULL), right(NULL) {}

    // Replace a child link and refresh this node's cached subtree size.
    void setLeft(Node* child) {
        left = child;
        calcSize();
    }

    void setRight(Node* child) {
        right = child;
        calcSize();
    }

    // Recompute `size` from the children's cached sizes (one level only, not
    // a full recursive recount).
    void calcSize() {
        const int leftCount = left ? left->size : 0;
        const int rightCount = right ? right->size : 0;
        size = leftCount + rightCount + 1;
    }
};

typedef pair<Node*, Node*> NodePair;

NodePair split(Node* root, keyType key) {
    if (root == NULL) return NodePair(NULL, NULL);
    
    if (root->key < key) {
        NodePair rs = split(root->right, key);
        root->setRight(rs.first);
        return NodePair(root, rs.second);
    } else {
        NodePair ls = split(root->left, key);
        root->setLeft(ls.second);
        return NodePair(ls.first, root);
    }
}

// Joins two treaps into one and returns its root.
// The in-order sequence of the result is all of `a` followed by all of `b`, so
// callers must pass trees where every key in `a` precedes every key in `b`
// (e.g. the two halves produced by split). Not checked here.
// Either argument may be NULL.
Node* merge(Node* a, Node* b) {
    // If one side is empty the other side is already the answer.
    if (!a || !b) return a ? a : b;

    // The higher-priority root stays on top; on a tie, b wins.
    if (b->priority < a->priority) {
        a->setRight(merge(a->right, b));
        return a;
    }
    b->setLeft(merge(a, b->left));
    return b;
}

// Inserts the detached node `newNode` into the treap rooted at `root` and
// returns the (possibly new) root of that treap.
// Duplicate keys are allowed; the caller keeps ownership of the node's memory
// through the tree (erase deletes nodes it removes).
Node* insert(Node* root, Node* newNode) {
    if (!root) return newNode;

    if (newNode->priority > root->priority) {
        // newNode outranks this subtree's root, so it must become the root:
        // cut the subtree around newNode's key and hang both halves off it.
        NodePair halves = split(root, newNode->key);
        newNode->setLeft(halves.first);
        newNode->setRight(halves.second);
        return newNode;
    }

    // Otherwise descend like an ordinary BST insert (equal keys go left).
    if (root->key < newNode->key)
        root->setRight(insert(root->right, newNode));
    else
        root->setLeft(insert(root->left, newNode));
    return root;
}

// Removes one node whose key equals `key` (the first one met on the search
// path), frees it with `delete`, and returns the new root of the treap.
// If no such key exists the tree is left unchanged; an empty tree yields NULL.
Node* erase(Node* root, KeyType key) {
    if (root == NULL) return NULL;

    if (key < root->key) {
        root->setLeft(erase(root->left, key));
        return root;
    }
    if (root->key < key) {
        root->setRight(erase(root->right, key));
        return root;
    }

    // Found it: the two children (already ordered left-before-right) are
    // merged to take this node's place.
    Node* replacement = merge(root->left, root->right);
    delete root;
    return replacement;
}

// Returns the node holding the k-th smallest key (1-based) in the treap rooted
// at `root`, using the cached subtree sizes.
// Returns NULL when the tree is empty or k lies outside [1, root->size];
// the original recursive version dereferenced a NULL child in that case.
Node* kth(Node* root, int k) {
    Node* node = root;
    while (node != NULL) {
        const int leftSize = node->left ? node->left->size : 0;
        if (k <= leftSize) {
            node = node->left;
        } else if (k == leftSize + 1) {
            return node;
        } else {
            // Skip the left subtree and this node, then rank within the right.
            k -= leftSize + 1;
            node = node->right;
        }
    }
    return NULL;  // k out of range
}

// Returns how many keys in the treap rooted at `root` are strictly less than
// `key`, in O(depth) time using the cached subtree sizes.
int countLessThan(Node* root, KeyType key) {
    if (root == NULL) return 0;

    if (root->key >= key) {
        // root and everything in its right subtree are >= key.
        return countLessThan(root->left, key);
    }

    // root and its entire left subtree are < key. The left child may be NULL;
    // the original dereferenced it unconditionally here.
    const int leftSize = root->left ? root->left->size : 0;
    return 1 + leftSize + countLessThan(root->right, key);
}