Algorithm: Splitting a Binary Search Tree by Value — A Pointer-Based Approach

Problem

Given a Binary Search Tree (BST) and a target value T, split the tree into two BSTs:

  • Left tree: all nodes with val <= T
  • Right tree: all nodes with val > T

Constraints:

  • Do not create new nodes
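  • Preserve the original BST structure as much as possible: both results must be valid BSTs, and a node that ends up in the same tree as its original parent should stay attached to that parent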
  • Each original node must belong to exactly one tree
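The constraints can be checked mechanically. The sketch below is a hypothetical test helper (`collect`, `edges`, and `check_split` are illustrative names, not part of any library). It reads "preserve structure" as "a parent/child pair that ends up in the same tree must still be linked", which is an assumption about the intended meaning.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def collect(root: Optional[TreeNode]) -> List[TreeNode]:
    """Return every node reachable from root; raise if a node is reachable twice."""
    out, seen, stack = [], set(), [root]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        if id(node) in seen:
            raise ValueError("node reachable twice: not a tree")
        seen.add(id(node))
        out.append(node)
        stack.extend((node.right, node.left))
    return out


def edges(root: Optional[TreeNode]) -> List[Tuple[TreeNode, TreeNode]]:
    """Record the original (parent, child) links. Call this BEFORE splitting."""
    return [(p, c) for p in collect(root) for c in (p.left, p.right) if c is not None]


def check_split(original: List[TreeNode], old_edges, a, b, t: int) -> bool:
    try:
        a_nodes, b_nodes = collect(a), collect(b)
    except ValueError:
        return False
    a_ids = {id(n) for n in a_nodes}
    b_ids = {id(n) for n in b_nodes}
    # Each original node lands in exactly one tree, and no new nodes appear.
    if a_ids & b_ids or (a_ids | b_ids) != {id(n) for n in original}:
        return False
    # Values fall on the correct side of t.
    if any(n.val > t for n in a_nodes) or any(n.val <= t for n in b_nodes):
        return False
    # Structure (assumed reading): same-tree parent/child pairs stay linked.
    for p, c in old_edges:
        pair = {id(p), id(c)}
        if (pair <= a_ids or pair <= b_ids) and c is not p.left and c is not p.right:
            return False
    return True


# Hand-made split of the subtree 2 -> (1, 3) at t = 2.
n1, n3 = TreeNode(1), TreeNode(3)
n2 = TreeNode(2, n1, n3)
original, old = collect(n2), edges(n2)
n2.right = None  # detach 3: it is the only node > 2
print(check_split(original, old, n2, n3, 2))  # True
```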

Key Insight

This problem is not about deleting nodes.

It is about rewiring pointers using the BST property:

For any node x:

  • All nodes in x.left are < x.val
  • All nodes in x.right are > x.val

So once x.val has been compared with T, one subtree is guaranteed to lie entirely on the same side of T as x and can be kept as-is; only the other subtree may need further splitting.
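A read-only sketch of this insight: walk down from the root and, at each node, record which child would need splitting; the other child is the safe one. `split_path` is a hypothetical helper, and the tree is the example from the next section.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_path(root: Optional[TreeNode], t: int) -> List[Tuple[int, str]]:
    """List the nodes a split at t would visit and which child it recurses into.

    No pointers are changed; the child NOT listed is the safe subtree.
    """
    path = []
    node = root
    while node is not None:
        if node.val <= t:
            # node.left holds values < node.val <= t: safe, keep as-is.
            path.append((node.val, "split right"))
            node = node.right
        else:
            # node.right holds values > node.val > t: safe, keep as-is.
            path.append((node.val, "split left"))
            node = node.left
    return path


# Example tree: 4 -> (2 -> (1, 3), 6 -> (5, 7))
root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(6, TreeNode(5), TreeNode(7)))
print(split_path(root, 2))
# [(4, 'split left'), (2, 'split right'), (3, 'split left')]
```

Only one child is followed at each node, so the walk visits at most one node per level of the tree.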


A Simple Example

Consider this BST:

        4
       / \
      2   6
     / \ / \
    1  3 5  7

Target: T = 2

Expected result:

Left (<= 2):           Right (> 2):

        2                       4
       /                       / \
      1                       3   6
                                 / \
                                5   7
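Because an in-order traversal of a BST lists values in ascending order, the contents of the two expected trees are simply a prefix and a suffix of that sequence. The sketch below is a brute-force cross-check of the drawing, not the pointer-rewiring algorithm (it builds new nodes for the expected trees); `inorder` is an illustrative helper.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> List[int]:
    """In-order traversal: returns BST values in ascending order."""
    if node is None:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(6, TreeNode(5), TreeNode(7)))
T = 2
values = inorder(root)
expected_left = [v for v in values if v <= T]
expected_right = [v for v in values if v > T]
print(expected_left, expected_right)  # [1, 2] [3, 4, 5, 6, 7]

# The expected trees drawn above, rebuilt by hand, have matching in-order sequences.
left_tree = TreeNode(2, TreeNode(1))
right_tree = TreeNode(4, TreeNode(3), TreeNode(6, TreeNode(5), TreeNode(7)))
print(inorder(left_tree) == expected_left, inorder(right_tree) == expected_right)  # True True
```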

Case 1: node.val <= T

Consider a node that satisfies this case:

node = 2, T = 2
        2
       / \
      1   3

Observations

  • node.left: all values are < node.val <= T, so it is entirely safe and stays under node
  • node.right: all values are > node.val, but they may be <= T or > T, so it needs splitting

Strategy

Split only the right subtree:

left_part, right_part = split(node.right)

In the example, node.right is the single node 3, which splits into:

<= T :   (empty)
>  T :    3

Then rewire pointers:

node.right = left_part   # <= T
return (node, right_part)

Result for this subtree:

Left tree root  = 2
Right tree root = 3
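The Case 1 rewiring above can be run as a single step. `case1_step` is a hypothetical helper that assumes the split of node.right has already been computed and is passed in as (left_part, right_part); the recursion that produces it is left out on purpose.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


Pair = Tuple[Optional[TreeNode], Optional[TreeNode]]


def case1_step(node: TreeNode, right_split: Pair) -> Pair:
    """Case 1 (node.val <= T): combine node with the split of node.right."""
    left_part, right_part = right_split
    node.right = left_part  # keep only the <= T part under node
    return (node, right_part)


n3 = TreeNode(3)
n2 = TreeNode(2, TreeNode(1), n3)
# Splitting the one-node subtree {3} at T = 2 gives (empty, 3).
a, b = case1_step(n2, (None, n3))
print(a.val, a.left.val, a.right, b.val)  # 2 1 None 3
```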

Case 2: node.val > T

Now consider the root of our example, where the recursion starts:

node = 4, T = 2
        4
       / \
      2   6

Observations

  • node.right: all values are > node.val > T, so it is entirely safe and stays under node
  • node.left: all values are < node.val, but they may be <= T or > T, so it needs splitting

Strategy

Split only the left subtree:

left_part, right_part = split(node.left)

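Then rewire:

node.left = right_part
return (left_part, node)

In the example, split(2) falls under Case 1 and returns (2, 3), so 4.left becomes 3 and the root call returns (2, 4), which matches the expected result.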
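To see how the two cases combine over the whole example, the sketch below adds a trace to the recursion: every call logs its node, the case that applied, and the roots it returned, in the order the calls finish. `split_traced` and `_val` are illustrative names, and the trace list exists only for inspection.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def _val(node: Optional[TreeNode]) -> Optional[int]:
    return None if node is None else node.val


def split_traced(node: Optional[TreeNode], t: int,
                 trace: list) -> Tuple[Optional[TreeNode], Optional[TreeNode]]:
    """Same recursion as the two cases, plus one log entry per call on return."""
    if node is None:
        return (None, None)
    if node.val <= t:  # Case 1: split the right subtree
        left_part, right_part = split_traced(node.right, t, trace)
        node.right = left_part
        result, case = (node, right_part), "val <= T"
    else:  # Case 2: split the left subtree
        left_part, right_part = split_traced(node.left, t, trace)
        node.left = right_part
        result, case = (left_part, node), "val > T"
    trace.append((node.val, case, _val(result[0]), _val(result[1])))
    return result


root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(6, TreeNode(5), TreeNode(7)))
trace: List[tuple] = []
a, b = split_traced(root, 2, trace)
for entry in trace:
    print(entry)
# (3, 'val > T', None, 3)
# (2, 'val <= T', 2, 3)
# (4, 'val > T', 2, 4)
```

Calls finish bottom-up, so the root's entry comes last and its (2, 4) pair is the final answer.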

The Core Invariant

For every recursive call:

split(node, T) returns (A, B)

Where:

  • A contains only nodes <= T
  • B contains only nodes > T
  • A and B are each valid BSTs
  • Every node appears exactly once
  • No new nodes are created
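The invariant is easy to check empirically. The sketch below is a randomized test that assumes distinct integer keys: it builds random BSTs by plain insertion, splits them with a standalone copy of the recursion, and asserts each bullet plus BST validity of both results. `insert`, `split`, `nodes`, `is_bst`, and `check_invariant` are hypothetical helpers.

```python
import random
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Plain iterative BST insertion, used only to build test trees."""
    new = TreeNode(val)
    if root is None:
        return new
    node = root
    while True:
        if val < node.val:
            if node.left is None:
                node.left = new
                return root
            node = node.left
        else:
            if node.right is None:
                node.right = new
                return root
            node = node.right


def split(node: Optional[TreeNode], t: int) -> Tuple[Optional[TreeNode], Optional[TreeNode]]:
    """Standalone copy of the recursion: returns (A, B) = (<= t, > t)."""
    if node is None:
        return (None, None)
    if node.val <= t:
        a, b = split(node.right, t)
        node.right = a
        return (node, b)
    a, b = split(node.left, t)
    node.left = b
    return (a, node)


def nodes(root: Optional[TreeNode]) -> List[TreeNode]:
    """All nodes reachable from root (iterative DFS)."""
    out, stack = [], [root]
    while stack:
        n = stack.pop()
        if n is not None:
            out.append(n)
            stack.extend((n.left, n.right))
    return out


def is_bst(node: Optional[TreeNode], lo: float = float("-inf"), hi: float = float("inf")) -> bool:
    """Strict bounds check; assumes distinct keys."""
    if node is None:
        return True
    return (lo < node.val < hi
            and is_bst(node.left, lo, node.val)
            and is_bst(node.right, node.val, hi))


def check_invariant(trials: int = 200, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        vals = rng.sample(range(100), rng.randint(0, 30))  # distinct keys
        root = None
        for v in vals:
            root = insert(root, v)
        before = nodes(root)  # keeps references alive, so ids stay valid
        t = rng.randint(-5, 105)
        a, b = split(root, t)
        a_nodes, b_nodes = nodes(a), nodes(b)
        ids_a = {id(n) for n in a_nodes}
        ids_b = {id(n) for n in b_nodes}
        assert all(n.val <= t for n in a_nodes)  # A: only <= t
        assert all(n.val > t for n in b_nodes)   # B: only > t
        assert is_bst(a) and is_bst(b)           # both still BSTs
        assert not (ids_a & ids_b)               # no node in both trees
        assert (ids_a | ids_b) == {id(n) for n in before}  # nothing lost or new
        assert len(a_nodes) + len(b_nodes) == len(before)
    return True


print(check_invariant())  # True
```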

Why Nothing “Moves Too Far”

A common concern:

“Won’t nodes with val == T accidentally move to the right tree?”

No.

Because:

  • val == T always satisfies <= T
  • The branching condition ensures such nodes never enter the > T branch
  • Only the ambiguous subtree is split; the safe subtree is untouched

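This makes the algorithm both correct and minimal: it visits only the nodes on one root-to-leaf path and changes at most one child pointer per visited node.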
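A quick check of the equality case: for every T equal to a value in the example tree, the node holding T should end up in A and be absent from B. The sketch uses a standalone copy of the recursion; `build_example` and `contains` are illustrative helpers.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split(node: Optional[TreeNode], t: int) -> Tuple[Optional[TreeNode], Optional[TreeNode]]:
    if node is None:
        return (None, None)
    if node.val <= t:  # val == t takes this branch, so it stays in A
        a, b = split(node.right, t)
        node.right = a
        return (node, b)
    a, b = split(node.left, t)
    node.left = b
    return (a, node)


def build_example() -> TreeNode:
    """Fresh copy of the example tree: 4 -> (2 -> (1, 3), 6 -> (5, 7))."""
    return TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)),
                    TreeNode(6, TreeNode(5), TreeNode(7)))


def contains(root: Optional[TreeNode], val: int) -> bool:
    """Ordinary BST search (valid because both split results are BSTs)."""
    node = root
    while node is not None and node.val != val:
        node = node.left if val < node.val else node.right
    return node is not None


results = []
for t in range(1, 8):  # T equal to each value in the tree, including the root
    a, b = split(build_example(), t)  # fresh tree each time: split rewires in place
    results.append((t, contains(a, t), contains(b, t)))
print(results)
# [(1, True, False), (2, True, False), (3, True, False), (4, True, False),
#  (5, True, False), (6, True, False), (7, True, False)]
```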

Final Algorithm (Code)

from typing import Optional, List

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right


class Solution:
    def splitBST(
        self, root: Optional["TreeNode"], target: int
    ) -> List[Optional["TreeNode"]]:
        """
        Returns [<= target tree, > target tree].
        Rewires pointers in-place (no new nodes).
        Time: O(H), Space: O(H) recursion stack (H = height).
        """
        left_tree, right_tree = self._split(root, target)
        return [left_tree, right_tree]

    def _split(
        self, node: Optional["TreeNode"], target: int
    ) -> tuple[Optional["TreeNode"], Optional["TreeNode"]]:
        if node is None:
            return (None, None)

        if node.val <= target:
            # node belongs to the left (<= target) tree.
            # Split the right subtree because it may contain both sides.
            left_part, right_part = self._split(node.right, target)

            # Rewire: keep the <= part as node.right
            node.right = left_part

            # node is the root of the left tree, right_part is the root of the right tree
            return (node, right_part)

        else:
            # node belongs to the right (> target) tree.
            # Split the left subtree because it may contain both sides.
            left_part, right_part = self._split(node.left, target)

            # Rewire: keep the > part as node.left
            node.left = right_part

            # left_part is the root of the left tree, node is the root of the right tree
            return (left_part, node)
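To try the code above outside an online judge, you need a `TreeNode` class plus a way to build and print trees. The sketch below assumes LeetCode-style level-order lists (`None` marks a missing child); `from_level_order`, `to_level_order`, and `split_bst` are illustrative names, and `split_bst` is a condensed standalone restatement of `Solution._split` that returns a list like `splitBST`.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def from_level_order(vals: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not vals or vals[0] is None:
        return None
    it = iter(vals)
    root = TreeNode(next(it))
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for side in ("left", "right"):
            v = next(it, None)  # an exhausted list also means "no child"
            if v is not None:
                child = TreeNode(v)
                setattr(node, side, child)
                queue.append(child)
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Serialize a tree to a level-order list, trimming trailing Nones."""
    out, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


def split_bst(node: Optional[TreeNode], target: int) -> List[Optional[TreeNode]]:
    """Condensed restatement of Solution._split, returning [<= target, > target]."""
    if node is None:
        return [None, None]
    if node.val <= target:
        left_part, right_part = split_bst(node.right, target)
        node.right = left_part
        return [node, right_part]
    left_part, right_part = split_bst(node.left, target)
    node.left = right_part
    return [left_part, node]


root = from_level_order([4, 2, 6, 1, 3, 5, 7])
left, right = split_bst(root, 2)
print(to_level_order(left))   # [2, 1]
print(to_level_order(right))  # [4, 3, 6, None, None, 5, 7]
```

Level-order lists make it easy to compare results against the drawings in the example section.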

Complexity

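  • Time: O(H), where H is the tree height: each call recurses into only one child, so only the nodes on a single root-to-leaf path are visited

  • Space: O(H) for the recursion stack

  • H itself depends on the shape of the tree:

    • Balanced BST: H = O(log N)
    • Worst case (skewed BST): H = O(N)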
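The recursion stack can be removed. The sketch below is an iterative variant (a swapped-in technique, not the code above): it follows the same single root-to-leaf path and should produce the same two trees as the recursive `_split`, using O(1) extra space. `split_iterative` and `shape` are illustrative names.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_iterative(root: Optional[TreeNode],
                    target: int) -> Tuple[Optional[TreeNode], Optional[TreeNode]]:
    """Walk one root-to-leaf path, appending each node to one of two growing trees.

    left_tail is the last node placed in the <= target tree; every later such node
    lies in its right subtree, so it hangs off left_tail.right. Symmetrically for
    right_tail.left. No new nodes (not even sentinels) are created.
    """
    left_root = right_root = None
    left_tail = right_tail = None
    node = root
    while node is not None:
        if node.val <= target:
            if left_tail is None:
                left_root = node
            else:
                left_tail.right = node
            left_tail = node
            node = node.right  # node.left is entirely <= target: keep it
        else:
            if right_tail is None:
                right_root = node
            else:
                right_tail.left = node
            right_tail = node
            node = node.left  # node.right is entirely > target: keep it
    # Cut the last links that still point across the split.
    if left_tail is not None:
        left_tail.right = None
    if right_tail is not None:
        right_tail.left = None
    return left_root, right_root


def shape(node: Optional[TreeNode]):
    """Nested (val, left, right) tuples for comparing tree structure."""
    if node is None:
        return None
    return (node.val, shape(node.left), shape(node.right))


root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(6, TreeNode(5), TreeNode(7)))
a, b = split_iterative(root, 2)
print(shape(a))  # (2, (1, None, None), None)
print(shape(b))  # (4, (3, None, None), (6, (5, None, None), (7, None, None)))

# A right-skewed chain 0 -> 1 -> ... -> 4999, deeper than CPython's default
# recursion limit of 1000; the loop above does not need a deep stack.
chain = None
for v in range(4999, -1, -1):
    chain = TreeNode(v, None, chain)
low, high = split_iterative(chain, 2499)
print(low.val, high.val)  # 0 2500
```

The trade-off is readability: the recursive version states the invariant directly, while the iterative one needs the two tail pointers and the final cuts.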

Takeaway

This problem is a great example of:

  • Using BST properties instead of deletion logic
  • Solving structural problems with local pointer rewiring
  • Designing algorithms around invariants rather than around individual node moves

Once you see it, the solution is surprisingly simple — but getting there requires the right mental model.