Coding Tips: Simplify your code with generalized methods

Can you simplify this code?

Here is an example of a doubly linked list in Python.

Can you simplify this code with "generalized methods"?

from dataclasses import dataclass
from typing import Any, Self


@dataclass
class Node:
    value: Any
    next: Self | None
    prev: Self | None


class DoublyLinkedList:
    """Doubly linked list with separate code paths for head, tail and middle inserts."""

    def __init__(self):
        self._head: Node | None = None
        self._tail: Node | None = None
        self._size = 0

    def append(self, value):
        """Add ``value`` after the current tail."""
        new = Node(value, None, None)
        if self._head is None and self._tail is None:
            # Empty list: the new node is both head and tail.
            self._head = new
            self._tail = new
        else:
            tail = self._tail
            tail.next = new
            new.prev = tail
            self._tail = new
        self._size += 1

    def prepend(self, value):
        """Add ``value`` before the current head."""
        new = Node(value, None, None)
        if self._head is None and self._tail is None:
            # Empty list: the new node is both head and tail.
            self._head = new
            self._tail = new
        else:
            head = self._head
            head.prev = new
            new.next = head
            self._head = new
        self._size += 1

    def insert(self, index, value):
        """Insert ``value`` so that it ends up at zero-based position ``index``.

        ``index == 0`` prepends and ``index == len`` appends.

        Raises:
            IndexError: if ``index`` is negative or greater than the list size.
        """
        if index > self._size or index < 0:
            raise IndexError("insert index out of range")

        if index == 0:
            # Head position (also covers the empty list).
            self.prepend(value)
        elif index == self._size:
            # Tail position; this must not be folded into the head case,
            # even for a single-node list.
            self.append(value)
        else:
            # Strictly between two existing nodes, so head and tail stay unchanged.
            new = Node(value, None, None)
            node = self._head
            for _ in range(index):
                node = node.next

            # ``node`` currently occupies ``index``; link ``new`` just before it.
            node.prev.next = new
            new.prev = node.prev

            node.prev = new
            new.next = node

            self._size += 1

    def pop_first(self):
        """Remove and return the first value.

        Raises:
            IndexError: if the list is empty.
        """
        if self._size == 0:
            raise IndexError("pop from empty list")
        popped = self._head
        if self._size == 1:
            self._head = None
            self._tail = None
        else:
            new_head = popped.next
            new_head.prev = None
            self._head = new_head

        self._size -= 1
        return popped.value

    def pop_last(self):
        """Remove and return the last value.

        Raises:
            IndexError: if the list is empty.
        """
        if self._size == 0:
            raise IndexError("pop from empty list")
        popped = self._tail
        if self._size == 1:
            self._head = None
            self._tail = None
        else:
            new_tail = popped.prev
            new_tail.next = None
            self._tail = new_tail

        self._size -= 1
        return popped.value

You can generalize "insert"

Consider how a new node is inserted into the doubly linked list. In the current implementation, insertion lives in three methods: append inserts a new node at the tail, prepend inserts at the head, and insert inserts anywhere. With a doubly linked list you need to be careful when updating the head and tail pointers. Inserting at the head means updating the head pointer, even when the list has only one node. Inserting at the tail means updating the tail pointer. Inserting into an empty list means pointing both the head and the tail at the new node.

How do we generalize this method?

We want to insert a new node anywhere; prepend and append are just the special cases of inserting at the head and at the tail. In a doubly linked list, every node has prev and next pointers (None at the head or tail). So any insertion comes down to linking the new node between its previous and next neighbors. When one of those neighbors is None, we also point the head or tail at the new node. We can generalize it like this:

def _insert_between(self, prev: Node | None, new: Node, next: Node | None):
    """Link ``new`` between the adjacent nodes ``prev`` and ``next``.

    ``prev`` is None when inserting at the head and ``next`` is None when
    inserting at the tail; in those cases the head/tail pointer is moved to
    ``new``. The list size is incremented.
    """
    # Always set both links on the new node so it never carries stale pointers.
    new.prev = prev
    new.next = next
    if prev is not None:
        prev.next = new
    else:
        self._head = new
    if next is not None:
        next.prev = new
    else:
        self._tail = new
    self._size += 1

We can simplify the insert methods with _insert_between like this:

def append(self, value):
    """Add ``value`` at the end of the list (after the tail, with no successor)."""
    self._insert_between(self._tail, Node(value, None, None), None)

def prepend(self, value):
    """Add ``value`` at the front of the list (before the head, with no predecessor)."""
    self._insert_between(None, Node(value, None, None), self._head)

def insert(self, index, value):
    """Insert ``value`` at zero-based position ``index``.

    ``index == 0`` prepends and ``index == self._size`` appends.

    Raises:
        IndexError: if ``index`` is negative or greater than the list size.
    """
    # Reject out-of-range positions; otherwise the walk below would stop at
    # the tail and silently append the value.
    if index < 0 or index > self._size:
        raise IndexError("insert index out of range")
    new = Node(value, None, None)

    # Advance ``index`` steps so the new node lands between ``prev`` and ``following``.
    prev = None
    following = self._head
    for _ in range(index):
        prev = following
        following = following.next
    self._insert_between(prev, new, following)

The code is much shorter. Updating _size is also hidden inside _insert_between, so we no longer need to handle it in every method as the previous version did.

You can generalize "remove"

The pop_first and pop_last methods share the same remove operation, so we can extract it into a single _remove method that unlinks any node from the doubly linked list:

def _remove(self, node: Node | None) -> Node:
    """Unlink ``node`` from the list and return it.

    The head/tail pointers are moved when ``node`` is at either end, and the
    list size is decremented.

    Raises:
        IndexError: if ``node`` is None (for example, the head or tail of an
            empty list).
    """
    if node is None:
        raise IndexError("no node to remove")
    if node.prev is not None:
        node.prev.next = node.next
    else:
        self._head = node.next
    if node.next is not None:
        node.next.prev = node.prev
    else:
        self._tail = node.prev
    # Detach the removed node so it no longer references its former neighbours.
    node.prev = None
    node.next = None
    self._size -= 1
    return node

Final code

"Insert" and "remove" operations are generalized.

from dataclasses import dataclass
from typing import Any, Self


@dataclass
class Node:
    value: Any
    next: Self | None
    prev: Self | None


class DoublyLinkedList:
    """Doubly linked list built on two primitives: ``_insert_between`` and ``_remove``."""

    def __init__(self):
        self._head: Node | None = None
        self._tail: Node | None = None
        self._size = 0

    def _insert_between(self, prev: Node | None, new_node: Node, next: Node | None):
        """Link ``new_node`` between the adjacent nodes ``prev`` and ``next``.

        ``prev`` is None when inserting at the head and ``next`` is None when
        inserting at the tail; in those cases the head/tail pointer is moved
        to ``new_node``. The list size is incremented.
        """
        # Always set both links on the new node so it never carries stale pointers.
        new_node.prev = prev
        new_node.next = next
        if prev is not None:
            prev.next = new_node
        else:
            self._head = new_node
        if next is not None:
            next.prev = new_node
        else:
            self._tail = new_node
        self._size += 1

    def append(self, value):
        """Add ``value`` at the end of the list."""
        self._insert_between(self._tail, Node(value, None, None), None)

    def prepend(self, value):
        """Add ``value`` at the front of the list."""
        self._insert_between(None, Node(value, None, None), self._head)

    def insert(self, index, value):
        """Insert ``value`` at zero-based position ``index``.

        ``index == 0`` prepends and ``index == self._size`` appends.

        Raises:
            IndexError: if ``index`` is negative or greater than the list size.
        """
        if index > self._size or index < 0:
            raise IndexError("insert index out of range")

        # Advance ``index`` steps so the new node lands between ``prev`` and ``following``.
        prev = None
        following = self._head
        for _ in range(index):
            prev = following
            following = following.next
        self._insert_between(prev, Node(value, None, None), following)

    def _remove(self, node: Node | None) -> Node:
        """Unlink ``node`` from the list and return it.

        The head/tail pointers are moved when ``node`` is at either end, and
        the list size is decremented.

        Raises:
            IndexError: if ``node`` is None (for example, the head or tail of
                an empty list).
        """
        if node is None:
            raise IndexError("no node to remove")

        if node.prev is not None:
            node.prev.next = node.next
        else:
            self._head = node.next
        if node.next is not None:
            node.next.prev = node.prev
        else:
            self._tail = node.prev
        # Detach the removed node so it no longer references its former neighbours.
        node.prev = None
        node.next = None
        self._size -= 1
        return node

    def pop_first(self):
        """Remove and return the first value.

        Raises:
            IndexError: if the list is empty.
        """
        if self._size == 0:
            raise IndexError("pop from empty list")
        popped = self._remove(self._head)
        return popped.value

    def pop_last(self):
        """Remove and return the last value.

        Raises:
            IndexError: if the list is empty.
        """
        if self._size == 0:
            raise IndexError("pop from empty list")
        popped = self._remove(self._tail)
        return popped.value

Conclusion

In this article, we extracted generalized methods for the insert and remove operations of a doubly linked list. These generalized methods are the fundamental, abstract operations of the data structure. They are useful in many situations, reduce duplicated code, and lead to fewer bugs.