
Predicting text: from bigram bag of words to transformer

This is my write-up of what I learned following Andrej Karpathy's excellent Makemore series on learning to build large language models. I take tangents to explore some aspects in more detail, but I don't provide full explanations for everything, nor the full code, which can be found here. I wrote my code with JAX rather than PyTorch, and show snippets below. As we progress through the article we build bigger and better models for text completion.

Bigram bag of words

A bag of words model is an extremely simple model which assigns each token a probability equal to the fraction of tokens in the training set that are that token. A bigram bag of words goes a small step further: it predicts the next token from the current one, using the proportion of bigrams (adjacent token pairs) in which token A is followed by token B. Consider a character-level (tokens are single characters) bigram bag of words model trained on the following text, focussing on the letter o.

The quick brown fox jumped over the lazy dog.

The letter o occurs four times and is followed by w, x, v, and g. A bag of words model would assign next character probabilities of $P(w \mid o) = P(x \mid o) = P(v \mid o) = P(g \mid o) = 1/4$ and 0 otherwise.
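As a quick check, here is a minimal sketch (plain Python; the variable names are just for illustration) that recovers these probabilities from the sentence above:

from collections import Counter

sentence = "The quick brown fox jumped over the lazy dog.".lower()
# Count the characters that follow 'o' in the sentence.
successor_counts = Counter(b for a, b in zip(sentence, sentence[1:]) if a == "o")
total = sum(successor_counts.values())
print({char: count / total for char, count in successor_counts.items()})
# {'w': 0.25, 'x': 0.25, 'v': 0.25, 'g': 0.25}

The snippet below does the same kind of counting at scale, over a list of names.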

bigram_to_count: dict[tuple[str, str], int] = {}
# words is list of common names: https://github.com/karpathy/makemore/blob/master/names.txt
for word in words:
    # Designate . as a special start/end character.
    word = ["."] + list(word) + ["."]
    for char1, char2 in zip(word, word[1:]):
        bigram = (char1, char2)
        bigram_to_count[bigram] = bigram_to_count.get(bigram, 0) + 1

import numpy as np

# Map each character to an index: '.' is 0, a-z are 1-26.
chars = ["."] + [chr(ord("a") + i) for i in range(26)]
char_to_index = {char: index for index, char in enumerate(chars)}

bigram_matrix = np.zeros((27, 27), dtype=int)
for (char1, char2), count in bigram_to_count.items():
    i, j = char_to_index[char1], char_to_index[char2]
    bigram_matrix[i, j] = count
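
One way to use this table is to row-normalise the counts into conditional next-character probabilities and sample from them. The snippet below is a minimal sketch: index_to_char (the inverse of char_to_index) is introduced here purely for illustration, and it assumes every character appears at least once as the first element of some bigram.

index_to_char = {index: char for char, index in char_to_index.items()}

# Row-normalise the counts: row i holds P(next char | current char i).
probabilities = bigram_matrix / bigram_matrix.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)

def sample_name() -> str:
    index, name = char_to_index["."], []
    while True:
        index = rng.choice(27, p=probabilities[index])
        if index == char_to_index["."]:
            return "".join(name)
        name.append(index_to_char[index])

print(sample_name())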

Maximum likelihood estimation

Assuming that bigrams are independent and identically distributed random variables, the bigram bag of words model is the probability distribution that maximises the likelihood of the training data. To show this, we will find the probability distribution $\hat{p}$ such that

$$\hat{p} = \arg\max_{p} \; P(\text{training data} \mid p)$$

by maximising the likelihood

$$L(p) = P((x_1,x_2), (x_2,x_3), \dots, (x_N,x_{N+1}) \mid p) = \prod_{i=1}^{N} p(x_i,x_{i+1})$$

$$= \prod_{(u,v) \in \mathcal{B}} p(u,v)^{c_{uv}}$$

where in the second line identical bigrams are grouped together: $\mathcal{B}$ is the set of distinct bigrams and $c_{uv}$ is the number of times the bigram $(u,v)$ occurs, e.g.

$$P(\text{banana} \mid p) = p(\text{ba}) p(\text{an}) p(\text{na}) p(\text{an}) p(\text{na})$$

$$= p(\text{ba}) p(\text{an})^2 p(\text{na})^2 .$$

Maximising the log-likelihood is easier

$$\log L(p) = \sum_{(u,v)\in\mathcal{B}} c_{uv}\log p(u,v)$$

and must be done following certain constraints

$$\sum_{(u,v)\in\mathcal{B}} p(u,v)=1 \quad\text{and}\quad p(u,v)\ge 0$$

in order to have a valid probability distribution. To maximise $\log L(p)$ subject to the normalisation constraint (non-negativity turns out to hold automatically at the solution), we form the Lagrangian

$$\mathcal{L}(p,\lambda) = \sum_{(u,v)} c_{uv}\log p(u,v) + \lambda\left(1 - \sum_{(u,v)} p(u,v)\right),$$

and optimise $\mathcal{L}(p,\lambda)$ such that

$$\frac{\partial \mathcal{L}}{\partial p(u,v)} = \frac{\partial \mathcal{L}}{\partial \lambda} = 0.$$

From the first of these conditions we find

$$\frac{\partial \mathcal{L}}{\partial p(u,v)} = \frac{c_{uv}}{p(u,v)} - \lambda = 0,$$

and

$$\frac{\partial \mathcal{L}}{\partial \lambda} = 1 - \sum_{(u,v)} p(u,v) = 0 \quad\Rightarrow\quad \sum_{(u,v)} p(u,v) = 1.$$

Rearranging the first condition gives $c_{uv} = \lambda\, p(u,v)$; summing over all bigrams and applying the constraint yields

$$\lambda \sum_{(u,v)} p(u,v) = \sum_{(u,v)} c_{uv} \Rightarrow \lambda = N,$$

where $N = \sum_{(u,v)} c_{uv}$ is the total number of bigrams.

Substituting $\lambda = N$ back into the first condition, the maximum likelihood estimate (MLE) is

$$\hat{p}(u,v) = \frac{c_{uv}}{N}.$$

So, under our naive independence assumptions, the bigram bag of words model is the maximum likelihood predictor.
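
As a quick numerical sanity check (not part of the derivation), we can confirm that the count-based estimate $\hat{p}(u,v) = c_{uv}/N$ assigns the training data a higher log-likelihood than some other valid distribution, e.g. a uniform one:

counts = bigram_matrix.astype(float)
N = counts.sum()

def log_likelihood(p):
    # Sum of c_uv * log p(u, v) over observed bigrams.
    observed = counts > 0
    return float((counts[observed] * np.log(p[observed])).sum())

mle = counts / N                                   # \hat{p}(u, v) = c_uv / N
uniform = np.full_like(counts, 1.0 / counts.size)  # another valid distribution

assert log_likelihood(mle) > log_likelihood(uniform)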

Single-layer neural network for bigrams

The code snippet below trains a single-layer neural network to predict bigrams. The model is a single weights matrix which predicts the second token in each bigram pair in the training set. By iteratively adjusting the weights through backpropagation to minimise the negative log-likelihood, the neural network recovers essentially the same bigram probabilities as the bag of words model (the small regularisation term in the loss smooths them slightly). Doing so is far less efficient than the counting approach of the previous section, but the methodology can be modified slightly in order to do much better than bag of words.

prev_char_idx, next_char_idx = [], []
for word in words:
    word = ["."] + list(word) + ["."]
    for char1, char2 in zip(word, word[1:]):
        i, j = char_to_index[char1], char_to_index[char2]
        prev_char_idx.append(i)
        next_char_idx.append(j)

import jax
from jax import numpy as jnp
# Gradients come from jax.grad (below), which differentiates functions of these arrays.
# Arrays of indices which form the training set.
# e.g. next character corresponding to prev_char_idx[10] is next_char_idx[10]
prev_char_idx = jnp.array(prev_char_idx)
next_char_idx = jnp.array(next_char_idx)

# Encodes token as vector of zeros with a 1 at location corresponding to the token
# e.g. a=Array([0, 1, 0, ...]), b=Array([0, 0, 1, ...])
# Array dims = number of bigrams x size of character set (a-z + .).
prev_char_encoded = jax.nn.one_hot(prev_char_idx, 27, dtype=float)
next_char_encoded = jax.nn.one_hot(next_char_idx, 27, dtype=float)

def loss_fn(W, x_encoded):
    logits = x_encoded @ W
    counts = jnp.exp(logits)
    probabilities = counts / counts.sum(1, keepdims=True)
    # Negative log-likelihood with regularisation term.
    loss = -jnp.log(probabilities[jnp.arange(len(logits)), next_char_idx]).mean() + 0.01 * (W**2).mean()
    return loss

# Randomly initialise weights.
key = jax.random.key(0)
W = jax.random.normal(key, (27, 27))

# Training loop.
learning_rate = 50
for _ in range(200):
    grad = jax.grad(loss_fn)(W, prev_char_encoded)
    W += -learning_rate * grad
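
To check the claim that training recovers the bag of words probabilities, we can compare the softmax of the learned weights with the row-normalised count matrix from earlier; a minimal sketch (the residual difference comes from the regularisation term and finite training):

# One row of softmax(W) per previous character, i.e. the network's P(next | previous).
nn_probabilities = jax.nn.softmax(W, axis=1)

# Row-normalised counts from the bag of words section.
count_probabilities = bigram_matrix / bigram_matrix.sum(axis=1, keepdims=True)

# The largest disagreement should be small.
print(jnp.abs(nn_probabilities - jnp.array(count_probabilities)).max())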

Multilayer perceptron

In the previous examples our weights matrix had 729 tunable parameters (27 x 27), and in this section we use a multilayer perceptron with over an order of magnitude more. Each character is represented by a 5D embedding, and next characters are predicted from 3 characters of context. This steps up the contextual awareness of the model, and means a 15-dimensional vector is fed into the neural network in order to get a next character prediction. The hidden layer consists of a 15x300 weights matrix (plus the bias), which means this layer 'eats' a 15-dimensional vector and spits out a 300-dimensional vector. The output layer projects this down to 27 dimensions, which is used to predict a probability distribution over the character set.

Our MLP has 13,062 parameters in total, and JAX makes tracking gradients and optimising the loss over 228,146 examples of names simultaneously light work on my MacBook (although I had to pin my JAX version in order for it to run on Apple Silicon). The contextual awareness of the model has increased, along with its ability to learn non-linear relationships between inputs. However, its outputs still aren't very good: our 'make more' machine makes things (e.g. names) which are still quite different from the examples it has been trained on.


import flax
import jax
from jax import numpy as jnp
import optax

block_size = 3
contexts, targets = [], []
for word in words:
    context = [0] * block_size
    for char in word + ".":
        index = char_to_index[char]
        contexts.append(context)
        targets.append(index)
        context = context[1:] + [index]

# Each "context" is 3 characters represented as 3 integer character codes.
# emma -> ..., ..e, .em, emm, mma
contexts = jnp.array(contexts) # shape=(N, 3)
# Each next "target" character is a single integer code.
# emma ->   e,   m,   m,   a,   .
targets = jnp.array(targets) # shape=(N,)

@flax.struct.dataclass
class Params:
    character_embeddings: jax.Array
    weights1: jax.Array
    bias1: jax.Array
    weights2: jax.Array
    bias2: jax.Array

key = jax.random.key(0)
key, key_char_emb, key_weights1, key_bias1, key_weights2, key_bias2 = jax.random.split(key, 6)
params = Params(
    # Learn a 5D representation for each of the 27 character codes
    character_embeddings=jax.random.normal(key_char_emb, shape=(27, 5)),
    # Context vectors are of size 15 (3 chars, each with 5D representation)
    # Hidden layer
    weights1=jax.random.normal(key_weights1, (15, 300)),
    bias1=jax.random.normal(key_bias1, (300,)),
    # Output layer
    weights2=jax.random.normal(key_weights2, (300, 27)),
    bias2=jax.random.normal(key_bias2, (27,)),
)
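
# Sanity check of the parameter count quoted above:
# 27*5 + 15*300 + 300 + 300*27 + 27 = 13,062.
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(n_params)  # 13062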

def forward(params: Params, contexts: jax.Array) -> jax.Array:
    # Replace each character code in each context with its 5D representation.
    context_embeddings = params.character_embeddings[contexts]
    # context_embeddings.reshape(-1, 15) @ W1 gives 300 dim hidden representation for each of N contexts.
    hidden = jnp.tanh(
        context_embeddings.reshape(-1, 15) @ params.weights1 + params.bias1
    )
    logits = hidden @ params.weights2 + params.bias2  # (N, 27)
    return logits

def loss_fn(params: Params, contexts: jax.Array, targets: jax.Array) -> jax.Array:
    logits = forward(params, contexts)
    loss = optax.losses.softmax_cross_entropy_with_integer_labels(
        logits, targets
    ).mean()
    return loss

@jax.jit
def train_step(
    params: Params, opt_state: optax.OptState, contexts: jax.Array, targets: jax.Array
):
    loss, grads = jax.value_and_grad(loss_fn)(params, contexts, targets)
    updates, opt_state = optimizer.update(grads, opt_state, params=params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Simple Stochastic Gradient Descent optimizer.
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)

# Training loop.
for _ in range(100):
    params, opt_state, training_loss = train_step(params, opt_state, contexts, targets)
    print(training_loss)
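
Finally, to actually 'make more' names from the trained MLP, here is a minimal sampling sketch. It assumes the char_to_index / index_to_char mappings from earlier, with the end-of-word character '.' at index 0:

def sample_name(params: Params, key: jax.Array) -> str:
    context, name = [0] * block_size, []
    while True:
        logits = forward(params, jnp.array([context]))  # shape (1, 27)
        key, subkey = jax.random.split(key)
        index = int(jax.random.categorical(subkey, logits[0]))
        if index == 0:  # sampled the '.' end-of-word character
            return "".join(name)
        name.append(index_to_char[index])
        context = context[1:] + [index]

key, key_sample = jax.random.split(key)
print(sample_name(params, key_sample))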

Batching

Written 24/12/25