Binary Search Trees: Implementation, Balancing, and When to Use Them

A binary search tree (BST) stores elements such that for every node, all values in the left subtree are smaller and all values in the right subtree are larger. This ordering property enables O(log n) search, insert, and delete when the tree is balanced.

Introduction#

A binary search tree (BST) stores elements such that for every node, all values in the left subtree are smaller and all values in the right subtree are larger. This ordering property enables O(log n) search, insert, and delete — but only when the tree remains balanced. Understanding BSTs, their degeneration, and self-balancing variants is foundational knowledge for system design and algorithm interviews.

Core BST Implementation#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations
from typing import Optional, Iterator

class BSTNode:
    """One key/value entry of a binary search tree.

    ``left`` and ``right`` reference the child subtrees and start out empty.
    """

    __slots__ = ("key", "value", "left", "right")

    def __init__(self, key: int, value: object):
        self.key = key
        self.value = value
        self.left: Optional[BSTNode] = None
        self.right: Optional[BSTNode] = None


class BST:
    """Unbalanced binary search tree mapping integer keys to values.

    Smaller keys are stored in the left subtree and larger keys in the right
    subtree of every node.  No rebalancing is performed, so inserting keys in
    sorted order produces a chain as deep as the number of keys.  For that
    reason every operation here is written with loops instead of recursion:
    a degenerate tree must not exhaust the interpreter's recursion limit.
    """

    def __init__(self):
        self._root: Optional[BSTNode] = None
        self._size = 0

    def insert(self, key: int, value: object = None) -> None:
        """Store ``value`` under ``key``; an existing key has its value replaced."""
        self._root = self._insert(self._root, key, value)

    def _insert(self, node: Optional[BSTNode], key: int, value: object) -> BSTNode:
        """Insert into the subtree rooted at ``node`` and return that subtree's root.

        ``self._size`` is incremented only when a new node is created.
        """
        if node is None:
            self._size += 1
            return BSTNode(key, value)

        current = node
        while True:
            if key < current.key:
                if current.left is None:
                    current.left = BSTNode(key, value)
                    self._size += 1
                    break
                current = current.left
            elif key > current.key:
                if current.right is None:
                    current.right = BSTNode(key, value)
                    self._size += 1
                    break
                current = current.right
            else:
                current.value = value  # update existing
                break
        return node

    def search(self, key: int) -> Optional[object]:
        """Return the value stored under ``key``, or ``None`` when it is absent.

        A key whose stored value is ``None`` cannot be told apart from a
        missing key through this method.
        """
        node = self._root
        while node:
            if key == node.key:
                return node.value
            node = node.left if key < node.key else node.right
        return None

    def delete(self, key: int) -> None:
        """Remove ``key`` if present; deleting a missing key is a no-op."""
        self._root, deleted = self._delete(self._root, key)
        if deleted:
            self._size -= 1

    def _delete(self, node: Optional[BSTNode], key: int) -> tuple[Optional[BSTNode], bool]:
        """Remove ``key`` from the subtree rooted at ``node``.

        Returns ``(new_subtree_root, deleted)`` where ``deleted`` tells whether
        a node was actually removed.  The caller adjusts ``self._size``.
        """
        # Locate the node holding ``key`` together with its parent.
        parent: Optional[BSTNode] = None
        target = node
        while target is not None:
            if key < target.key:
                parent, target = target, target.left
            elif key > target.key:
                parent, target = target, target.right
            else:
                break
        if target is None:
            return node, False

        if target.left is not None and target.right is not None:
            # Two children: copy the in-order successor (smallest key in the
            # right subtree) into this node, then unlink the successor, which
            # by construction has no left child.
            successor_parent = target
            successor = target.right
            while successor.left is not None:
                successor_parent, successor = successor, successor.left
            target.key, target.value = successor.key, successor.value
            parent, target = successor_parent, successor

        # ``target`` now has at most one child; splice that child into its place.
        child = target.left if target.left is not None else target.right
        if parent is None:
            return child, True  # the subtree root itself was removed
        if parent.left is target:
            parent.left = child
        else:
            parent.right = child
        return node, True

    def _min_node(self, node: BSTNode) -> BSTNode:
        """Return the left-most (smallest-key) node of the subtree at ``node``."""
        while node.left:
            node = node.left
        return node

    def inorder(self) -> Iterator[tuple[int, object]]:
        """In-order traversal yields keys in sorted order."""
        yield from self._inorder(self._root)

    def _inorder(self, node: Optional[BSTNode]) -> Iterator[tuple[int, object]]:
        """Yield ``(key, value)`` pairs of the subtree at ``node`` in ascending key order."""
        pending: list[BSTNode] = []  # ancestors whose own entry is still to be yielded
        current = node
        while pending or current is not None:
            while current is not None:
                pending.append(current)
                current = current.left
            current = pending.pop()
            yield current.key, current.value
            current = current.right

    def __len__(self) -> int:
        """Number of distinct keys currently stored."""
        return self._size

Tree Degeneration#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Keys that arrive in ascending order always go to the right of the previous
# one, so the tree collapses into a chain and every operation costs O(n).
bst = BST()
for i in (1, 2, 3, 4, 5, 6, 7):
    bst.insert(i)

# Resulting tree (degenerate):
# 1
#  \
#   2
#    \
#     3
#      \
#       4 ...
# Height = n, search is O(n)

# The same keys inserted median-first (each subtree's middle key before the
# rest of that subtree) produce a perfectly balanced tree:
bst2 = BST()
for k in (4, 2, 6, 1, 3, 5, 7):
    bst2.insert(k)

# Resulting tree:
#     4
#    / \
#   2   6
#  / \ / \
# 1  3 5  7
# Height = 3 levels for 7 keys (about log2(n)), search is O(log n)

AVL Tree: Height-Balanced BST#

An AVL tree maintains the invariant that, for every node, the heights of its left and right subtrees differ by at most 1. Each insertion restores this invariant with at most two rotations.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class AVLNode:
    """Entry of an AVL tree; ``height`` counts the levels of the subtree rooted here."""

    __slots__ = ("key", "value", "left", "right", "height")

    def __init__(self, key: int, value: object):
        self.left: Optional[AVLNode] = None
        self.right: Optional[AVLNode] = None
        self.key, self.value = key, value
        self.height = 1  # a freshly created node is a leaf


class AVLTree:
    """Insert-only AVL tree: a BST that rotates after insertion to stay height-balanced."""

    def __init__(self):
        self._root: Optional[AVLNode] = None

    def _height(self, node: Optional[AVLNode]) -> int:
        """Height of ``node``'s subtree; an empty subtree has height 0."""
        if not node:
            return 0
        return node.height

    def _update_height(self, node: AVLNode) -> None:
        """Recompute ``node.height`` from the stored heights of its children."""
        left_height = self._height(node.left)
        right_height = self._height(node.right)
        node.height = max(left_height, right_height) + 1

    def _balance_factor(self, node: AVLNode) -> int:
        """Left height minus right height; positive means left-heavy."""
        left_height = self._height(node.left)
        right_height = self._height(node.right)
        return left_height - right_height

    def _rotate_right(self, y: AVLNode) -> AVLNode:
        """Lift ``y``'s left child above ``y`` and return it as the new subtree root."""
        pivot = y.left
        # ``y`` adopts the pivot's former right subtree, then becomes the pivot's right child.
        y.left, pivot.right = pivot.right, y
        # ``y`` now sits below the pivot, so its height must be refreshed first.
        self._update_height(y)
        self._update_height(pivot)
        return pivot

    def _rotate_left(self, x: AVLNode) -> AVLNode:
        """Lift ``x``'s right child above ``x`` and return it as the new subtree root."""
        pivot = x.right
        # ``x`` adopts the pivot's former left subtree, then becomes the pivot's left child.
        x.right, pivot.left = pivot.left, x
        # ``x`` now sits below the pivot, so its height must be refreshed first.
        self._update_height(x)
        self._update_height(pivot)
        return pivot

    def _rebalance(self, node: AVLNode, key: int) -> AVLNode:
        """Refresh ``node``'s height and rotate if the just-inserted ``key`` unbalanced it.

        Returns the root of the (possibly rotated) subtree.
        """
        self._update_height(node)
        tilt = self._balance_factor(node)

        if -1 <= tilt <= 1:
            return node  # still within the AVL tolerance

        if tilt > 1:
            # Left heavy. A key that went into the left child's right side
            # (Left-Right case) needs a preliminary left rotation.
            if key > node.left.key:
                node.left = self._rotate_left(node.left)
            return self._rotate_right(node)

        # Right heavy. A key that went into the right child's left side
        # (Right-Left case) needs a preliminary right rotation.
        if key < node.right.key:
            node.right = self._rotate_right(node.right)
        return self._rotate_left(node)

    def insert(self, key: int, value: object = None) -> None:
        """Store ``value`` under ``key``; an existing key has its value replaced."""
        self._root = self._insert(self._root, key, value)

    def _insert(self, node: Optional[AVLNode], key: int, value: object) -> AVLNode:
        """Insert below ``node`` and return the root of the rebalanced subtree."""
        if node is None:
            return AVLNode(key, value)

        if key < node.key:
            node.left = self._insert(node.left, key, value)
        elif key > node.key:
            node.right = self._insert(node.right, key, value)
        else:
            # Overwriting a value leaves the shape untouched: nothing to rebalance here.
            node.value = value
            return node

        return self._rebalance(node, key)

BST Operations Complexity#

1
2
3
4
5
6
7
8
9
10
11
12
13
Operation    | BST (avg) | BST (worst) | AVL Tree | Red-Black Tree
-------------|-----------|-------------|----------|---------------
Search       | O(log n)  | O(n)        | O(log n) | O(log n)
Insert       | O(log n)  | O(n)        | O(log n) | O(log n)
Delete       | O(log n)  | O(n)        | O(log n) | O(log n)
Min/Max      | O(log n)  | O(n)        | O(log n) | O(log n)
In-order     | O(n)      | O(n)        | O(n)     | O(n)
Space        | O(n)      | O(n)        | O(n)     | O(n)

BST worst case: sorted insertion (linked list)
AVL and Red-Black: guaranteed O(log n) via rebalancing
Red-Black: fewer rotations than AVL, better for insert-heavy workloads
AVL: stricter balance, better for read-heavy workloads

Range Queries#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def range_query(
    root: Optional[BSTNode],
    low: int,
    high: int,
) -> list[tuple[int, object]]:
    """Return all (key, value) pairs where low <= key <= high.

    Pairs are returned in ascending key order.  Subtrees that cannot contain
    a key in the range are skipped.  The walk uses an explicit stack rather
    than recursion so that a very deep (degenerate) tree does not raise
    ``RecursionError``.

    Args:
        root: Root node of the tree to search, or ``None`` for an empty tree.
        low: Inclusive lower bound.
        high: Inclusive upper bound.

    Returns:
        A list of ``(key, value)`` tuples; empty when nothing matches.
    """
    result: list[tuple[int, object]] = []
    pending = []  # nodes whose left side has been handled but which are not yet visited
    node = root

    while pending or node is not None:
        while node is not None:
            pending.append(node)
            # Keys to the left are smaller than node.key, so they can only be
            # in range when node.key itself is above the lower bound.
            node = node.left if low < node.key else None

        node = pending.pop()
        if node.key > high:
            # Nodes are visited in ascending order: nothing later can match.
            break
        if low <= node.key:
            result.append((node.key, node.value))
        # Keys to the right are larger than node.key; skip them at the upper bound.
        node = node.right if node.key < high else None

    return result

# Example: build a small tree, then list every entry whose key lies in [6, 14].
bst = BST()
for k in (10, 5, 15, 3, 7, 12, 20):
    bst.insert(k, f"val_{k}")

print(range_query(root=bst._root, low=6, high=14))
# Prints: [(7, 'val_7'), (10, 'val_10'), (12, 'val_12')]

Floor and Ceiling#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def floor(root: Optional[BSTNode], key: int) -> Optional[int]:
    """Return the largest stored key that is <= ``key``, or ``None`` if there is none."""
    best = None
    cursor = root
    while cursor:
        candidate = cursor.key
        if key == candidate:
            return candidate  # exact match cannot be improved on
        if key < candidate:
            # Too large to be a floor; only smaller keys on the left can qualify.
            cursor = cursor.left
            continue
        # candidate < key: remember it and look right for something closer.
        best = candidate
        cursor = cursor.right
    return best

def ceiling(root: Optional[BSTNode], key: int) -> Optional[int]:
    """Return the smallest stored key that is >= ``key``, or ``None`` if there is none."""
    best = None
    cursor = root
    while cursor:
        candidate = cursor.key
        if key == candidate:
            return candidate  # exact match cannot be improved on
        if key > candidate:
            # Too small to be a ceiling; only larger keys on the right can qualify.
            cursor = cursor.right
            continue
        # candidate > key: remember it and look left for something closer.
        best = candidate
        cursor = cursor.left
    return best

When BSTs Are Used in Practice#

1
2
3
4
5
6
7
8
9
10
Python sortedcontainers.SortedList (third-party): offers the same ordered operations, but is implemented as a list of sorted sublists rather than a tree
Java TreeMap / TreeSet: Red-Black tree
C++ std::map / std::set: Red-Black tree
PostgreSQL B-tree index: multi-way search tree (a generalization of the BST idea with many keys per node) designed for disk storage
In-memory symbol tables, priority schedulers

When NOT to use BSTs:
- Need O(1) lookup: use hash table
- Range queries on disk: use B-tree (better cache locality)
- Sorted array with no inserts: binary search is simpler

Conclusion#

BSTs provide efficient ordered operations — search, insert, delete, range queries, floor/ceiling — in O(log n) when balanced. Unbalanced BSTs degenerate to O(n). Use AVL trees for read-heavy workloads requiring strict balance, and Red-Black trees for insert-heavy workloads. In most application code, reach for a well-tested ordered-collection implementation (the third-party sortedcontainers package in Python, the standard library’s TreeMap in Java) rather than implementing your own. Understanding BSTs provides the foundation for B-trees, segment trees, and other ordered data structures.

Contents