// 알고리즘&자료구조 (Algorithms & Data Structures)
// 이진검색트리 변형 : 트립 (Binary search tree variant: Treap)
// 아, 그래요?
// 2022. 12. 9. 13:36
#include <iostream>
#include <cstdlib>
#include <utility>
using namespace std;
using KeyType = int;

// Treap node. `key` is the search key; `priority` is a random value drawn from
// rand() when the node is created; `size` caches how many nodes the subtree
// rooted here contains (this node included). The tree routines keep keys in
// BST order and larger priorities closer to the root.
struct Node {
    KeyType key;
    int priority;
    int size;
    Node* left;
    Node* right;

    Node(const KeyType& _key)
        : key(_key), priority(rand()), size(1), left(nullptr), right(nullptr) {}

    // Recomputes `size` from the current children (either may be null).
    void calcSize() {
        const int leftCount = (left != nullptr) ? left->size : 0;
        const int rightCount = (right != nullptr) ? right->size : 0;
        size = 1 + leftCount + rightCount;
    }

    // Child setters keep the cached `size` consistent with the new shape.
    void setLeft(Node* newLeft) {
        left = newLeft;
        calcSize();
    }
    void setRight(Node* newRight) {
        right = newRight;
        calcSize();
    }
};
typedef pair<Node*, Node*> NodePair;
NodePair split(Node* root, keyType key) {
if (root == NULL) return NodePair(NULL, NULL);
if (root->key < key) {
NodePair rs = split(root->right, key);
root->setRight(rs.first);
return NodePair(root, rs.second);
} else {
NodePair ls = split(root->left, key);
root->setLeft(ls.second);
return NodePair(ls.first, root);
}
}
// Joins two treaps into one and returns the combined root.
// Precondition (not checked here): all keys in `a` precede all keys in `b`.
// The root with the larger priority stays on top; on a priority tie `b` wins.
Node* merge(Node* a, Node* b) {
    if (!a || !b) return a ? a : b;

    const bool aOnTop = b->priority < a->priority;
    if (!aOnTop) {
        // b stays the root; everything in `a` belongs in b's left subtree.
        b->setLeft(merge(a, b->left));
        return b;
    }
    // a stays the root; everything in `b` belongs in a's right subtree.
    a->setRight(merge(a->right, b));
    return a;
}
// Inserts `newNode` into the treap rooted at `root` and returns the new root.
// `newNode` is presumably a freshly constructed node without children (its
// children are overwritten if it ends up as a subtree root). Because erase()
// frees nodes with `delete`, nodes should be allocated with `new`.
// A key equal to an existing one is routed into the left subtree.
Node* insert(Node* root, Node* newNode) {
    if (!root) return newNode;

    if (newNode->priority > root->priority) {
        // newNode outranks this subtree's root, so it takes its place: the old
        // subtree is cut around newNode->key and hung beneath it.
        NodePair halves = split(root, newNode->key);
        newNode->setLeft(halves.first);
        newNode->setRight(halves.second);
        return newNode;
    }

    // Otherwise descend toward newNode's key position.
    if (root->key < newNode->key)
        root->setRight(insert(root->right, newNode));
    else
        root->setLeft(insert(root->left, newNode));
    return root;
}
// Removes one node whose key equals `key` from the treap rooted at `root` and
// returns the new root. The removed node's subtrees are merged in its place and
// the node itself is freed with `delete` (so nodes must come from `new`).
// If no node has that key, the tree shape is left unchanged.
Node* erase(Node* root, KeyType key) {
    if (root == nullptr) return nullptr;

    if (root->key != key) {
        // Recurse into the only subtree that can contain `key`.
        if (root->key > key)
            root->setLeft(erase(root->left, key));
        else
            root->setRight(erase(root->right, key));
        return root;
    }

    Node* replacement = merge(root->left, root->right);
    delete root;
    return replacement;
}
// Returns the node holding the k-th smallest key (1-based) in the treap rooted
// at `root`, using the cached subtree sizes to pick a direction at each level.
// Returns nullptr when `root` is null or `k` lies outside [1, root->size];
// the previous version dereferenced a null child in those cases.
Node* kth(Node* root, int k) {
    while (root != nullptr) {
        const int leftSize = root->left ? root->left->size : 0;
        if (k <= leftSize) {
            root = root->left;
        } else if (k == leftSize + 1) {
            return root;
        } else {
            // Skip the left subtree and this node, then rank within the right.
            k -= leftSize + 1;
            root = root->right;
        }
    }
    return nullptr;
}
// Returns how many keys in the treap rooted at `root` are strictly less than
// `key`. A node with key < `key` contributes itself plus its whole left
// subtree (via the cached size). A missing left child counts as 0; the
// previous version dereferenced it unconditionally and crashed.
int countLessThan(Node* root, KeyType key) {
    int count = 0;
    while (root != nullptr) {
        if (root->key >= key) {
            // Everything here and to the right is >= key; look left only.
            root = root->left;
        } else {
            const int leftSize = root->left ? root->left->size : 0;
            count += 1 + leftSize;
            root = root->right;
        }
    }
    return count;
}