Algorithm: String to Integer (atoi) optimization: if-else vs state machine

Quiz

What's the space complexity of the following code?

import string


class Solution:
    def myAtoi(self, s: str) -> int:
        def consume_whitespaces(pos: int) -> int:
            while pos < len(s) and s[pos] in string.whitespace:
                pos += 1
            return pos

        def consume_sign(pos: int) -> tuple[int, int]:
            if pos >= len(s):
                return pos, 1
            if s[pos] == "-":
                pos += 1
                return pos, -1
            elif s[pos] == "+":
                pos += 1
                return pos, 1
            else:
                return pos, 1

        def consume_zeros(pos: int) -> int:
            while pos < len(s) and s[pos] == "0":
                pos += 1
            return pos

        def consume_number(pos: int) -> int:  # returns only the parsed number
            num = 0
            if pos >= len(s):
                return num

            if s[pos] == "0":
                raise ValueError(
                    'Wrong position to start consuming number, but got "0".'
                )

            start = pos
            while pos < len(s) and s[pos] in string.digits:
                pos += 1
            number_str = s[start:pos]
            if number_str == "":
                return 0
            else:
                return int(number_str)

        def round_number(num: int) -> int:
            if num > 0:
                return min(num, 2**31 - 1)
            else:
                return max(num, -(2**31))

        pos = 0
        pos = consume_whitespaces(pos)
        pos, sign = consume_sign(pos)
        pos = consume_zeros(pos)
        number = consume_number(pos)  # stops at the first non-digit character
        number_int = round_number(sign * number)

        return number_int

Answer

O(S), where S is the length of the input. Slicing a string creates a copy of the sliced range, which needs O(S) space in the worst case:

number_str = s[start:pos]

We can avoid the copy by accumulating the value digit by digit as we scan:

def consume_number(pos: int) -> int:
    num = 0
    if pos >= len(s):  # we must check this first.
        return num

    if s[pos] == "0":
        raise ValueError(
            'Wrong position to start consuming number, but got "0".'
        )

    while pos < len(s) and s[pos] in string.digits:  # check pos < len(s) first.
        num = num * 10 + int(s[pos])
        pos += 1
    return num

Now the slice is gone, so the auxiliary space is O(1), apart from num itself, which Python stores as an arbitrary-precision integer.
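Since num can still grow with the number of digits before round_number clamps it, one option is to clamp while accumulating. The following is a minimal sketch, not part of the original solution: the function name parse_digits_clamped, its sign parameter, and the INT_MAX / INT_MIN constants are assumptions introduced for illustration.

```python
INT_MAX = 2**31 - 1
INT_MIN = -(2**31)


def parse_digits_clamped(s: str, pos: int, sign: int) -> int:
    """Accumulate ASCII digits from s[pos:], clamping to the 32-bit range.

    Hypothetical helper: `sign` is +1 or -1, as returned by consume_sign.
    """
    # Largest magnitude allowed for this sign (2**31 - 1 or 2**31).
    limit = INT_MAX if sign > 0 else -INT_MIN
    num = 0
    while pos < len(s) and "0" <= s[pos] <= "9":
        num = num * 10 + (ord(s[pos]) - ord("0"))
        if num >= limit:
            # Further digits cannot change the clamped result, so stop early.
            return sign * limit
        pos += 1
    return sign * num


print(parse_digits_clamped("42", 0, 1))
print(parse_digits_clamped("91283472332", 0, -1))
```

With the early return, num never exceeds roughly ten times the limit, so its size no longer depends on the input length.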

State machine version

When we convert an ASCII string to an integer, the parser moves through these states:

  1. START: nothing read yet (skipping leading spaces)
  2. SIGNED: a + or - has been consumed
  3. NUMBER: reading digits
  4. END: stop

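Why a state machine?

  • readability: fewer nested if/else branches
  • fewer bugs: each state checks only the inputs it accepts and leaves every other pattern to the other states
  • extensibility: a new step becomes a new state with its own transitions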

Each state defines which inputs it accepts and which state comes next. That's it. If a character is not acceptable in the current state, we stop processing immediately.
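To make "acceptable input plus transitions" concrete, the states can be written as a lookup table. This is only a sketch under assumptions: the classify helper, the TRANSITIONS dictionary, and trace_states are hypothetical names that do not appear in the solution, and only the space character is treated as whitespace.

```python
START, SIGNED, NUMBER, END = range(4)


def classify(c: str) -> str:
    """Map a character to an input class (hypothetical helper)."""
    if c == " ":
        return "space"
    if c in ("+", "-"):
        return "sign"
    if "0" <= c <= "9":
        return "digit"
    return "other"


# TRANSITIONS[state][input class] -> next state.
# Any input class that is not listed for a state ends parsing.
TRANSITIONS = {
    START: {"space": START, "sign": SIGNED, "digit": NUMBER},
    SIGNED: {"digit": NUMBER},
    NUMBER: {"digit": NUMBER},
}


def trace_states(s: str) -> list[int]:
    """Return the state after each character, stopping once END is reached."""
    state = START
    visited = []
    for c in s:
        state = TRANSITIONS[state].get(classify(c), END)
        visited.append(state)
        if state == END:
            break
    return visited


print(trace_states("  -42x"))  # [0, 0, 1, 2, 2, 3]
print(trace_states("+-1"))     # [1, 3]
```

Reading the table row by row shows which inputs each state accepts; everything else falls through to END.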

The first version calls helper functions and threads the position through them:

pos = 0
pos = consume_whitespaces(pos)
pos, sign = consume_sign(pos)
pos = consume_zeros(pos)
number = consume_number(pos)  # stops at the first non-digit character
number_int = round_number(sign * number)

This works, but it is unclear where a new step should go, which hurts extensibility and readability. Each helper must also verify that the position is still in range. And because consume_number raises an exception on unexpected input, the caller has to handle it.

Final code

class Solution:
    def myAtoi(self, s: str) -> int:
        START, SIGNED, NUMBER, END = range(4)
        state = START
        sign = 1
        num = 0

        for c in s:
            if state == START:
                if c == " ":
                    continue
                elif c in "+-":
                    sign = -1 if c == "-" else 1
                    state = SIGNED
                elif c.isdigit():
                    num = int(c)
                    state = NUMBER
                else:
                    break

            elif state == SIGNED:
                if c.isdigit():
                    num = int(c)
                    state = NUMBER
                else:
                    break

            elif state == NUMBER:
                if c.isdigit():
                    num = num * 10 + int(c)
                else:
                    break

        num *= sign
        return max(min(num, 2**31 - 1), -(2**31))

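The logic is much simpler. The state flow is START -> SIGNED -> NUMBER -> END, where SIGNED is skipped when the input has no sign. END is implicit: an unacceptable character breaks out of the loop, so the END constant is never actually assigned. In each state we handle only the current character: if it is acceptable, we consume it and possibly change state; otherwise, we stop. We no longer need to track the current position explicitly.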
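A few quick checks can confirm this behaviour. The sketch below restates the loop as a standalone function so it can run on its own; the name my_atoi is an assumption (the original is a method on Solution), the digit branches are merged, and END is assigned explicitly rather than left implicit.

```python
def my_atoi(s: str) -> int:
    """Standalone variant of the state-machine loop with an explicit END state."""
    START, SIGNED, NUMBER, END = range(4)
    state, sign, num = START, 1, 0

    for c in s:
        is_digit = "0" <= c <= "9"
        if state == START and c == " ":
            continue  # still skipping leading spaces
        if state == START and c in ("+", "-"):
            sign = -1 if c == "-" else 1
            state = SIGNED
        elif is_digit:
            # Valid in START, SIGNED, and NUMBER; num is 0 before the first digit.
            num = num * 10 + int(c)
            state = NUMBER
        else:
            state = END
        if state == END:
            break

    # Clamp to the signed 32-bit range.
    return max(min(sign * num, 2**31 - 1), -(2**31))


print(my_atoi("   -42"))           # -42
print(my_atoi("4193 with words"))  # 4193
print(my_atoi("words and 987"))    # 0
print(my_atoi("-91283472332"))     # -2147483648
```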