Elixir Recursion with Cheatsheet

In functional languages like Elixir, recursion replaces loops. Understanding recursion isn’t just academic—it’s essential for thinking functionally and writing efficient, elegant code.

Why Recursion Matters in Elixir

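Elixir has no traditional imperative loops: there is no while or do-while, and for is a comprehension that builds a new collection rather than a counter-driven loop. Instead, you use: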

  • Recursion for custom iteration
  • Enum module for common patterns
  • Stream module for lazy evaluation

Mastering recursion unlocks the full power of functional programming.

Basic Recursion Patterns

The Foundation: Base Case + Recursive Case

defmodule BasicRecursion do
  @moduledoc """
  Introductory recursive functions. Each one pairs a base case with a
  recursive case that moves one step closer to that base case.
  """

  # Kernel.length/1 is auto-imported into every module. Without this exclusion
  # the unqualified recursive call `length(tail)` below conflicts with the
  # import and the module does not compile.
  import Kernel, except: [length: 1]

  @doc """
  Returns the factorial of a non-negative integer `n`.

  Raises `FunctionClauseError` for negative or non-integer input.
  """
  def factorial(0), do: 1                              # Base case
  def factorial(n) when is_integer(n) and n > 0 do     # Recursive case
    n * factorial(n - 1)
  end

  @doc """
  Returns the number of elements in `list`.
  """
  def length([]), do: 0                      # Base case: empty list
  def length([_head | tail]) do              # Recursive case
    1 + length(tail)
  end

  @doc """
  Returns the sum of the numbers in `list`; the empty list sums to 0.
  """
  def sum([]), do: 0
  def sum([head | tail]), do: head + sum(tail)
end

Every recursive function follows this pattern:

  1. Base case: When to stop recursing
  2. Recursive case: How to break down the problem
  3. Progress: Each call must move toward the base case

Head vs Tail Recursion

This is crucial for performance in Elixir.

Head Recursion (Stack Grows with Input Size)

# ❌ Head recursion - builds up stack
# An empty list contributes nothing to the total.
def sum([]) do
  0
end

# The addition can only run once sum(rest) has returned, so every element
# leaves a pending `+` on the stack.
def sum([first | rest]), do: first + sum(rest)

# Call stack grows: 
# sum([1,2,3]) -> 1 + sum([2,3]) -> 1 + (2 + sum([3])) -> 1 + (2 + (3 + 0))

Tail Recursion (Stack Safe)

# ✅ Tail recursion with accumulator
# Public entry point: starts the running total at 0.
def sum(list) do
  sum(list, 0)
end

# The running total travels in `total`, so the recursive call is the last
# thing the clause does and no work is left pending on the stack.
defp sum([number | remaining], total), do: sum(remaining, total + number)
defp sum([], total), do: total

# Call sequence:
# sum([1,2,3], 0) -> sum([2,3], 1) -> sum([3], 3) -> sum([], 6) -> 6

The BEAM VM performs last-call optimization: a call in tail position reuses the current stack frame, so tail recursion runs in constant stack space, just like a loop.

Essential Recursion Patterns

1. Accumulator Pattern

The most important pattern for tail recursion:

defmodule AccumulatorExamples do
  @moduledoc """
  Tail-recursive list functions. Each public function delegates to a private
  helper that carries the partial result in an accumulator.
  """

  @doc "Returns `list` with its elements in reverse order."
  def reverse(list) do
    reverse(list, [])
  end

  defp reverse([], reversed), do: reversed

  defp reverse([item | rest], reversed) do
    reverse(rest, [item | reversed])
  end

  @doc "Keeps the elements of `list` for which `fun` returns a truthy value."
  def filter(list, fun) do
    filter(list, fun, [])
  end

  # The accumulator is built by prepending, so it is reversed once at the end.
  defp filter([], _fun, kept), do: Enum.reverse(kept)

  defp filter([item | rest], fun, kept) do
    kept = if fun.(item), do: [item | kept], else: kept
    filter(rest, fun, kept)
  end

  @doc "Applies `fun` to every element of `list`, preserving order."
  def map(list, fun) do
    map(list, fun, [])
  end

  defp map([], _fun, mapped), do: Enum.reverse(mapped)

  defp map([item | rest], fun, mapped) do
    result = fun.(item)
    map(rest, fun, [result | mapped])
  end
end

2. Tree Recursion

For processing nested structures:

defmodule TreeRecursion do
  @moduledoc """
  Recursive operations over a binary tree in which a missing child is `nil`.
  """

  defmodule Tree do
    @moduledoc "A binary tree node; `left` and `right` are `Tree` structs or `nil`."
    defstruct value: nil, left: nil, right: nil
  end

  @doc "Returns the number of levels in the tree; an empty tree has depth 0."
  def depth(nil), do: 0

  def depth(%Tree{left: left, right: right}) do
    left_depth = depth(left)
    right_depth = depth(right)
    1 + max(left_depth, right_depth)
  end

  @doc "Returns the sum of every node value; an empty tree sums to 0."
  def sum(nil), do: 0

  def sum(%Tree{value: value, left: left, right: right}) do
    left_total = sum(left)
    right_total = sum(right)
    value + left_total + right_total
  end

  @doc "Returns `true` when some node's value is exactly `target`."
  def find(nil, _target), do: false

  def find(%Tree{value: value, left: left, right: right}, target) do
    # `===` is the same strict comparison a repeated pattern variable performs.
    value === target or find(left, target) or find(right, target)
  end
end

3. Nested Data Recursion

For processing complex data structures:

defmodule NestedRecursion do
  @moduledoc """
  Recursive helpers for arbitrarily nested lists and maps.
  """

  @doc """
  Flattens an arbitrarily nested list into a single flat list, preserving
  element order.
  """
  def flatten(list), do: list |> flatten([]) |> Enum.reverse()

  # `acc` holds the elements seen so far in reverse order. A nested list is
  # flattened straight into the same accumulator, so nothing is reversed or
  # concatenated until the single Enum.reverse/1 in flatten/1.
  defp flatten([], acc), do: acc
  defp flatten([head | tail], acc) when is_list(head) do
    flatten(tail, flatten(head, acc))
  end
  defp flatten([head | tail], acc) do
    flatten(tail, [head | acc])
  end

  @doc """
  Applies `fun` to every key of `map`, recursing into nested plain maps.

  Structs (for example `Date` or `DateTime` values) and all non-map values are
  returned unchanged.
  """
  # Structs also match `%{}` but are not enumerable, so Map.new/2 would raise
  # on them. This clause must come before the plain-map clause.
  def transform_keys(%{__struct__: _} = struct, _fun), do: struct
  def transform_keys(%{} = map, fun) do
    Map.new(map, fn {key, value} ->
      {fun.(key), transform_keys(value, fun)}
    end)
  end
  def transform_keys(value, _fun), do: value
end

Advanced Recursion Techniques

Mutual Recursion

Functions calling each other:

defmodule MutualRecursion do
  @moduledoc """
  Parity checks that bounce between two functions, plus a small recursive
  evaluator for tagged arithmetic expressions.
  """

  @doc "Returns `true` when the non-negative `number` is even."
  def is_even(0), do: true

  def is_even(number) when number > 0 do
    is_odd(number - 1)
  end

  @doc "Returns `true` when the non-negative `number` is odd."
  def is_odd(0), do: false

  def is_odd(number) when number > 0 do
    is_even(number - 1)
  end

  @doc """
  Evaluates an expression built from `{:number, n}`, `{:add, l, r}`,
  `{:multiply, l, r}` and `{:negate, e}` tuples.
  """
  def evaluate({:number, literal}), do: literal

  def evaluate({:add, lhs, rhs}) do
    left_value = evaluate(lhs)
    right_value = evaluate(rhs)
    left_value + right_value
  end

  def evaluate({:multiply, lhs, rhs}) do
    left_value = evaluate(lhs)
    right_value = evaluate(rhs)
    left_value * right_value
  end

  def evaluate({:negate, inner}) do
    value = evaluate(inner)
    -value
  end
end

Continuation-Passing Style (CPS)

For complex control flow:

defmodule CPS do
  @moduledoc """
  Continuation-passing style: instead of returning a result, each step hands
  it to a function that represents the rest of the computation.
  """

  @doc "Computes `n!` for a non-negative `n` using continuations."
  def factorial_cps(n) do
    factorial_cps(n, fn value -> value end)
  end

  # Base case: feed 1 into the chain of pending multiplications.
  defp factorial_cps(0, continue), do: continue.(1)

  defp factorial_cps(n, continue) when n > 0 do
    multiply_then_continue = fn partial -> continue.(n * partial) end
    factorial_cps(n - 1, multiply_then_continue)
  end

  @doc "Maps `fun` over `list`, building the result through continuations."
  def map_cps(list, fun) do
    map_cps(list, fun, fn value -> value end)
  end

  defp map_cps([], _fun, continue), do: continue.([])

  defp map_cps([item | remaining], fun, continue) do
    # `fun` is called inside the continuation, i.e. only once the mapped tail
    # is available.
    prepend_then_continue = fn mapped_tail -> continue.([fun.(item) | mapped_tail]) end
    map_cps(remaining, fun, prepend_then_continue)
  end
end

Recursion Cheatsheet

Basic Patterns

# Base + Recursive case
def func([]) do
  base_case
end

# Combine the head with the result of recursing on the tail.
def func([head | tail]) do
  combine(head, func(tail))
end

# Tail recursion with accumulator  
def func(list) do
  func_helper(list, initial_acc)
end

# Fold each element into the accumulator; return it once the list is empty.
defp func_helper([head | tail], acc) do
  func_helper(tail, update_acc(head, acc))
end

defp func_helper([], acc), do: acc

# Counting down
# Prints the current number, then recurses on the next smaller one.
def countdown(remaining) when remaining > 0 do
  IO.puts(remaining)
  countdown(remaining - 1)
end

# Reaching zero ends the countdown.
def countdown(0) do
  :done
end

# Processing pairs
# Consume two elements per step.
def process_pairs([first, second | remaining]) do
  [process(first, second) | process_pairs(remaining)]
end

# Odd number of elements: the leftover item is passed through unchanged.
def process_pairs([leftover]), do: [leftover]

def process_pairs([]), do: []

Tree/Nested Patterns

# Binary tree
def tree_func(nil) do
  base_case
end

# Recurse into both children, then merge their results with this node's value.
def tree_func(%{value: value, left: left, right: right}) do
  left_result = tree_func(left)
  right_result = tree_func(right)
  combine(value, left_result, right_result)
end

# Nested maps/lists
# Structs also match `%{}` but are not enumerable, so Map.new/2 would raise on
# them. Treat a struct as a leaf value; this clause must come before the
# plain-map clause.
def deep_process(%{__struct__: _} = struct), do: transform(struct)
# Plain maps: keep the keys, recurse into every value.
def deep_process(%{} = map) do
  Map.new(map, fn {k, v} -> {k, deep_process(v)} end)
end
# Lists: recurse element by element.
def deep_process([]), do: []
def deep_process([h | t]), do: [deep_process(h) | deep_process(t)]
# Anything else is a leaf.
def deep_process(value), do: transform(value)

Performance Patterns

# ✅ Tail recursive with accumulator
def tail_recursive(list) do
  helper(list, [])
end

# Results are prepended, so the accumulator ends up in reverse order.
defp helper([head | tail], acc) do
  helper(tail, [process(head) | acc])
end

defp helper([], acc), do: Enum.reverse(acc)  # If order matters

# ✅ Early termination
# Nothing left to inspect: no element satisfied the predicate.
def find_first([], _pred) do
  nil
end

# Stop at the first element for which `pred` returns a truthy value.
def find_first([candidate | others], pred) do
  cond do
    pred.(candidate) -> candidate
    true -> find_first(others, pred)
  end
end

# ✅ Multiple accumulators
# Returns {count, sum, max} in a single pass.
# The running maximum is seeded from the first element rather than 0, so a
# list containing only negative numbers reports its true maximum.
# An empty list still yields {0, 0, 0}.
def stats([]), do: {0, 0, 0}
def stats([first | rest]), do: stats_helper(rest, 1, first, first)

defp stats_helper([], count, sum, max_so_far), do: {count, sum, max_so_far}
defp stats_helper([h | t], count, sum, max_so_far) do
  stats_helper(t, count + 1, sum + h, max(h, max_so_far))
end

Real-World Example: JSON Parser

defmodule SimpleJsonParser do
  @moduledoc """
  A small recursive-descent parser for a subset of JSON: objects, arrays,
  strings with the simple backslash escapes, numbers, `true`, `false` and
  `null`. `\\uXXXX` escapes are not supported.
  """

  @doc """
  Parses the first JSON value found in `input`.

  Returns `{value, rest}`, where `rest` is the unparsed remainder of the
  trimmed input, or `{:error, message}` when the input is not valid.
  """
  def parse(input) when is_binary(input) do
    case input |> String.trim() |> parse_value() do
      {:ok, value, rest} -> {value, rest}
      {:error, _message} = error -> error
    end
  end

  # Every private parser returns {:ok, value, rest} or {:error, message}.
  # The explicit :ok tag keeps an error tuple from being mistaken for a
  # two-element {value, rest} success inside the `with` chains below.

  # Dispatch on the first character(s) of the value.
  defp parse_value("{" <> rest), do: parse_object(String.trim_leading(rest), %{})
  defp parse_value("[" <> rest), do: parse_array(String.trim_leading(rest), [])
  defp parse_value("\"" <> rest), do: parse_string(rest, "")
  defp parse_value("true" <> rest), do: {:ok, true, rest}
  defp parse_value("false" <> rest), do: {:ok, false, rest}
  defp parse_value("null" <> rest), do: {:ok, nil, rest}
  defp parse_value(input), do: parse_number(input)

  # Integers are tried first; a following ".", "e" or "E" means the literal is
  # a float and is re-parsed as one from the start.
  defp parse_number(input) do
    case Integer.parse(input) do
      {_int, <<next, _::binary>>} when next in [?., ?e, ?E] -> parse_float(input)
      {int, rest} -> {:ok, int, rest}
      :error -> {:error, "Invalid JSON"}
    end
  end

  defp parse_float(input) do
    case Float.parse(input) do
      {float, rest} -> {:ok, float, rest}
      :error -> {:error, "Invalid JSON"}
    end
  end

  # Object body. The input has already had its leading whitespace removed, so
  # "{}" and "{ }" both reach the first clause.
  defp parse_object("}" <> rest, acc), do: {:ok, acc, rest}
  defp parse_object(input, acc), do: parse_members(input, acc)

  # One `"key": value` member, then either the closing brace or a comma
  # followed by another member (so a trailing comma is rejected).
  defp parse_members("\"" <> rest, acc) do
    with {:ok, key, rest1} <- parse_string(rest, ""),
         {:ok, rest2} <- expect_colon(String.trim_leading(rest1)),
         {:ok, value, rest3} <- parse_value(String.trim_leading(rest2)) do
      new_acc = Map.put(acc, key, value)

      case String.trim_leading(rest3) do
        "}" <> rest4 -> {:ok, new_acc, rest4}
        "," <> rest4 -> parse_members(String.trim_leading(rest4), new_acc)
        _ -> {:error, "Expected ',' or '}'"}
      end
    end
  end
  defp parse_members(_input, _acc), do: {:error, "Expected string key"}

  defp expect_colon(":" <> rest), do: {:ok, rest}
  defp expect_colon(_input), do: {:error, "Expected ':'"}

  # Array body; same structure as objects.
  defp parse_array("]" <> rest, acc), do: {:ok, Enum.reverse(acc), rest}
  defp parse_array(input, acc), do: parse_elements(input, acc)

  defp parse_elements(input, acc) do
    with {:ok, value, rest1} <- parse_value(input) do
      new_acc = [value | acc]

      case String.trim_leading(rest1) do
        "]" <> rest2 -> {:ok, Enum.reverse(new_acc), rest2}
        "," <> rest2 -> parse_elements(String.trim_leading(rest2), new_acc)
        _ -> {:error, "Expected ',' or ']'"}
      end
    end
  end

  # String body: the opening quote has already been consumed by the caller.
  # Bytes are copied into `acc` until the closing quote.
  defp parse_string("\"" <> rest, acc), do: {:ok, acc, rest}
  defp parse_string("\\" <> rest, acc), do: parse_escape(rest, acc)
  defp parse_string(<<char, rest::binary>>, acc) do
    parse_string(rest, acc <> <<char>>)
  end
  defp parse_string("", _acc), do: {:error, "Unterminated string"}

  # The character after a backslash; mutually recursive with parse_string/2.
  defp parse_escape("\"" <> rest, acc), do: parse_string(rest, acc <> "\"")
  defp parse_escape("\\" <> rest, acc), do: parse_string(rest, acc <> "\\")
  defp parse_escape("/" <> rest, acc), do: parse_string(rest, acc <> "/")
  defp parse_escape("b" <> rest, acc), do: parse_string(rest, acc <> "\b")
  defp parse_escape("f" <> rest, acc), do: parse_string(rest, acc <> "\f")
  defp parse_escape("n" <> rest, acc), do: parse_string(rest, acc <> "\n")
  defp parse_escape("r" <> rest, acc), do: parse_string(rest, acc <> "\r")
  defp parse_escape("t" <> rest, acc), do: parse_string(rest, acc <> "\t")
  defp parse_escape("", _acc), do: {:error, "Unterminated string"}
  defp parse_escape(_other, _acc), do: {:error, "Invalid escape sequence"}
end

Performance Guidelines

Memory Usage

# ❌ Stack-hungry head recursion
def bad_sum([]) do
  0
end

def bad_sum([first | rest]) do
  # The recursive call must return before the addition can run, so each
  # element keeps a frame alive on the call stack.
  rest_total = bad_sum(rest)
  first + rest_total
end

# ✅ Tail recursive (constant stack)
def good_sum(list) do
  sum_acc(list, 0)
end

# The total so far rides along in `total`; the recursive call is the final
# operation of the clause.
defp sum_acc([number | remaining], total) do
  sum_acc(remaining, number + total)
end

defp sum_acc([], total), do: total

When to Use Enum Instead

# ✅ Use Enum for common patterns
# When the iteration is a plain map, reduce or filter, call the Enum function
# instead of hand-writing the recursion.
Enum.map(list, &process/1)           # Instead of custom map recursion
Enum.reduce(list, acc, &combine/2)   # Instead of custom accumulator
Enum.filter(list, &predicate/1)      # Instead of custom filter

# ✅ Use recursion for custom logic
# Splits `list` into consecutive chunks of at most `size` elements.
# `groups` is the accumulator of chunks collected so far (newest first);
# pass [] on the initial call.
def custom_grouper([], _size, groups), do: Enum.reverse(groups)
# The guard guarantees progress: with size <= 0, Enum.split/2 can hand back a
# `rest` that never shrinks and the recursion would never terminate.
def custom_grouper(list, size, groups) when is_integer(size) and size > 0 do
  {chunk, rest} = Enum.split(list, size)
  custom_grouper(rest, size, [chunk | groups])
end

Debugging Recursion

Add Tracing

# Entry point: accumulator starts at 1, trace depth at 0.
def factorial(n) do
  factorial(n, 1, 0)
end

# Each step prints its arguments, indented two spaces per level of recursion.
defp factorial(n, acc, depth) when n > 0 do
  pad = String.duplicate("  ", depth)
  IO.puts("#{pad}factorial(#{n}, #{acc})")
  factorial(n - 1, n * acc, depth + 1)
end

# Base case: report the accumulated product and return it.
defp factorial(0, acc, depth) do
  pad = String.duplicate("  ", depth)
  IO.puts("#{pad}Base case: #{acc}")
  acc
end

Common Pitfalls

# ❌ Forgetting base case
# Despite its name, this does not loop forever: once the list is exhausted no
# clause matches [] and the call raises FunctionClauseError.
def infinite_loop([_h | t]), do: infinite_loop(t)  # Missing [] case!

# ❌ Not making progress
# The argument passed to the recursive call is the same `n`, so the guard
# stays true on every call.
def stuck(n) when n > 0, do: stuck(n)  # Never decreases n!

# ❌ Wrong pattern order
# NOTE: as written, the `when n > 0` guard rejects 0, so the second clause is
# still reached. The pitfall appears when the general clause has no guard
# (e.g. `def wrong_order(n)`): it then matches 0 as well and shadows the
# base case below it.
def wrong_order(n) when n > 0, do: n * wrong_order(n - 1)
def wrong_order(0), do: 1  # Reachable only because the guard above excludes 0

Next Steps

Master these recursion patterns and you’ll be able to:

  • Implement any algorithm functionally
  • Process nested data structures elegantly
  • Write memory-efficient code
  • Think in recursive terms naturally

Recursion is the foundation of functional programming—once it clicks, everything else becomes simpler.


Coming next: Piping basics and how to chain operations elegantly in Elixir.