Predicting text: from bigram bag of words to transformer

This is my write-up of what I learned following Andrej Karpathy's excellent Makemore series on learning to build large language models. I take tangents to explore some aspects in more detail, but I don't provide full explanations for everything, nor the full code, which can be found here. I wrote my code with JAX rather than PyTorch, and show snippets below. As we progress through the article we build bigger and better models for text completion.

Bigram bag of words

A bag of words model is an extremely simple model which assigns each token a probability equal to the fraction of tokens in the training set which are that token. A bigram bag of words goes a small step further: it predicts the next token based on the proportion of token pairs in which token A is followed by token B. Consider a character-level (tokens are single characters) bigram bag of words model trained on the following text, focussing on the letter o.

The quick brown fox jumped over the lazy dog.

The letter o occurs four times and is followed by w, x, v, and g. A bag of words model would assign next character probabilities of $P(w \mid o) = P(x \mid o) = P(v \mid o) = P(g \mid o) = 1/4$ and 0 otherwise.
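As a quick sanity check of those numbers, a throwaway snippet (the variable names are my own):

text = "The quick brown fox jumped over the lazy dog."
followers = [next_char for char, next_char in zip(text, text[1:]) if char == "o"]
# followers == ['w', 'x', 'v', 'g'], so each gets probability 1/4.
print({char: followers.count(char) / len(followers) for char in set(followers)})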

# words is a list of common names: https://github.com/karpathy/makemore/blob/master/names.txt
words = open("names.txt").read().splitlines()

# Map characters to indices: "." -> 0, "a" -> 1, ..., "z" -> 26.
char_to_index = {char: index + 1 for index, char in enumerate(sorted(set("".join(words))))}
char_to_index["."] = 0

bigram_to_count: dict[tuple[str, str], int] = {}
for word in words:
    # Designate . as a special start/end character.
    word = ["."] + list(word) + ["."]
    for char1, char2 in zip(word, word[1:]):
        bigram = (char1, char2)
        bigram_to_count[bigram] = bigram_to_count.get(bigram, 0) + 1

import numpy as np
bigram_matrix = np.zeros((27, 27), dtype=int)
for (char1, char2), count in bigram_to_count.items():
    i, j = char_to_index[char1], char_to_index[char2]
    bigram_matrix[i, j] = count

Maximum likelihood estimation

Assuming that bigrams are independent and identically distributed random variables, the probability distribution with the maximum likelihood for the training data is the bigram bag of words model. To show this, we will find the probability distribution $p$ such that

$$\hat{p} = \arg\max_{p} \; P(\text{training data} \mid p)$$

by maximising the likelihood

$$L(p) = P((x_1,x_2), \dots, (x_{N-1},x_N) \mid p) = \prod_{i=1}^{N-1} p(x_i,x_{i+1})$$

$$= \prod_{(u,v) \in \mathcal{B}} p(u,v)^{c_{uv}}$$

where in the previous line bigrams $(u,v)$ in the bigram set $\mathcal{B}$ are grouped together, e.g.

$$P(\text{banana} \mid p) = p(\text{ba}) p(\text{an}) p(\text{na}) p(\text{an}) p(\text{na})$$

$$= p(\text{ba}) p(\text{an})^2 p(\text{na})^2 .$$

Maximising the log-likelihood is easier

$$\log L(p) = \sum_{(u,v)\in\mathcal{B}} c_{uv}\log p(u,v)$$

and must be done following certain constraints

$$\sum_{(u,v)\in\mathcal{B}} p(u,v)=1 \quad\text{and}\quad p(u,v)\ge 0$$

in order to have a valid probability distribution. To maximise $\log L(p)$ while simultaneously satisfying the constraints we form the Lagrangian

$$\mathcal{L}(p,\lambda) = \sum_{(u,v)} c_{uv}\log p(u,v) + \lambda \left(1 - \sum_{(u,v)} p(u,v)\right),$$

and optimise $\mathcal{L}(p,\lambda)$ such that

$$\frac{\partial \mathcal{L}}{\partial p(u,v)} = \frac{\partial \mathcal{L}}{\partial \lambda} = 0.$$

From which we find

$$\frac{\partial \mathcal{L}}{\partial p(u,v)} = \frac{c_{uv}}{p(u,v)} - \lambda = 0,$$

and

$$\frac{\partial \mathcal{L}}{\partial \lambda} = 1 - \sum_{(u,v)} p(u,v) = 0 \quad\Rightarrow\quad \sum_{(u,v)} p(u,v) = 1.$$

so, rearranging the first equation to $c_{uv} = \lambda\, p(u,v)$, summing over all bigrams, and using the constraint,

$$\lambda \sum_{(u,v)} p(u,v) = \sum_{(u,v)} c_{uv} \Rightarrow \lambda = N,$$

where $N = \sum_{(u,v)} c_{uv}$ is the total number of bigrams.

Finally, substituting $\lambda = N$ into $p(u,v) = c_{uv}/\lambda$ from the first stationarity condition, the maximum likelihood estimate (MLE) is

$$\hat{p}(u,v) = \frac{c_{uv}}{N}.$$

So, under our naive i.i.d. assumption, the bigram bag of words counts give the maximum likelihood model.
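Note that for predicting the next character we want the conditional distribution $c_{uv} / \sum_w c_{uw}$, i.e. the counts normalised within each row of the bigram matrix, rather than the joint $\hat{p}(u,v)$. Below is a minimal sketch that turns the counts from the earlier snippet into probabilities and samples a name from them; the +1 smoothing and the random seed are my own choices.

# Row-normalise the counts: row i is P(next char | char i).
# +1 smoothing avoids zero probabilities.
probabilities = (bigram_matrix + 1) / (bigram_matrix + 1).sum(axis=1, keepdims=True)

index_to_char = {index: char for char, index in char_to_index.items()}
rng = np.random.default_rng(0)

# Sample characters starting from ".", stopping when "." is drawn again.
index, name = 0, []
while True:
    index = rng.choice(27, p=probabilities[index])
    if index == 0:
        break
    name.append(index_to_char[index])
print("".join(name))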

Single-layer neural network for bigrams

The code snippet below trains a single-layer neural network to predict bigrams. The model is a single weights matrix which predicts the second token of each bigram pair in the training set. By iteratively adjusting the weights through backpropagation to minimise the negative log-likelihood, the neural network recovers (up to the small regularisation term) the same bigram probabilities as the bag of words model above. Doing so is far less efficient than in the previous section, but the methodology can be modified slightly in order to do much better than bag of words.

# JAX computes gradients by transforming functions (jax.grad) rather than by
# tracking them on arrays.
import jax
from jax import numpy as jnp

# Lists of indices which form the training set,
# e.g. the next character corresponding to prev_char_idx[10] is next_char_idx[10].
prev_char_idx, next_char_idx = [], []
for word in words:
    word = ["."] + list(word) + ["."]
    for char1, char2 in zip(word, word[1:]):
        i, j = char_to_index[char1], char_to_index[char2]
        prev_char_idx.append(i)
        next_char_idx.append(j)

prev_char_idx = jnp.array(prev_char_idx)
next_char_idx = jnp.array(next_char_idx)

# Encode each token as a vector of zeros with a 1 at the location corresponding to the token,
# e.g. a=Array([0, 1, 0, ...]), b=Array([0, 0, 1, ...]).
# Array dims = number of bigrams x size of character set (. + a-z).
prev_char_encoded = jax.nn.one_hot(prev_char_idx, 27, dtype=float)

def loss_fn(W, x_encoded):
    logits = x_encoded @ W
    counts = jnp.exp(logits)
    probabilities = counts / counts.sum(1, keepdims=True)
    # Negative log-likelihood with regularisation term.
    loss = -jnp.log(probabilities[jnp.arange(len(logits)), next_char_idx]).mean() + 0.01 * (W**2).mean()
    return loss

# Randomly initialise weights.
key = jax.random.key(0)
W = jax.random.normal(key, (27, 27))

# Training loop: full-batch gradient descent.
learning_rate = 50
for _ in range(200):
    grad = jax.grad(loss_fn)(W, prev_char_encoded)
    W += -learning_rate * grad
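As a quick check (a sketch assuming the trained W above and the bigram_matrix from the first section), the rows of softmax(W) should approximately match the row-normalised bigram counts, apart from the smoothing introduced by the regularisation term:

# Each row i of softmax(W) is the model's P(next char | char i).
learned_probabilities = jax.nn.softmax(W, axis=1)
# Every row of bigram_matrix is non-zero for names.txt, so this division is safe.
count_probabilities = bigram_matrix / bigram_matrix.sum(axis=1, keepdims=True)
print(jnp.abs(learned_probabilities - jnp.array(count_probabilities)).max())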

Multilayer perceptron

In the previous examples our weights matrix had 729 tunable parameters (27 x 27), and in this section we use a multilayer perceptron with over an order of magnitude more. Each character is represented by a 5D embedding, and next characters are predicted from 3 characters of context. This steps up the contextual awareness of the model, and means a 15-dimensional vector (3 x 5, flattened) is fed into the neural network in order to get a next character prediction. The hidden layer consists of a 15 x 300 weights matrix plus a bias, followed by a tanh non-linear activation function, which means this layer 'eats' a 15-dimensional vector and spits out a 300-dimensional vector. The output layer projects this down to 27 dimensions, which is used to predict a probability distribution over the character set.

Our MLP has 13,062 parameters in total, and JAX makes tracking gradients and optimising the loss over 228,146 examples of names light work on my MacBook. The contextual awareness of the model has increased, but its outputs still aren't very good: our 'make more' machine makes things (such as names) which are still quite different from the examples it has been trained on.


import flax
import jax
from jax import numpy as jnp
import optax

block_size = 3
contexts, targets = [], []
for word in words:
    context = [0] * block_size
    for char in word + ".":
        index = char_to_index[char]
        contexts.append(context)
        targets.append(index)
        context = context[1:] + [index]

# Each "context" is 3 characters represented as 3 integer character codes.
# emma -> ..., ..e, .em, emm, mma
contexts = jnp.array(contexts) # shape=(N, 3)
# Each next "target" character is a single integer code.
# emma ->   e,   m,   m,   a,   .
targets = jnp.array(targets) # shape=(N,)

@flax.struct.dataclass
class Params:
    character_embeddings: jax.Array
    weights1: jax.Array
    bias1: jax.Array
    weights2: jax.Array
    bias2: jax.Array

key = jax.random.key(0)
key, key_char_emb, key_weights1, key_bias1, key_weights2, key_bias2 = jax.random.split(key, 6)
params = Params(
    # Learn a 5D representation for each of the 27 character codes
    character_embeddings=jax.random.normal(key_char_emb, shape=(27, 5)),
    # Context vectors are of size 15 (3 chars, each with 5D representation)
    # Hidden layer
    weights1=jax.random.normal(key_weights1, (15, 300)),
    bias1=jax.random.normal(key_bias1, (300,)),
    # Output layer
    weights2=jax.random.normal(key_weights2, (300, 27)),
    bias2=jax.random.normal(key_bias2, (27,)),
)

def forward(params: Params, contexts: jax.Array) -> jax.Array:
    # Replace each character code in each context with its 5D representation.
    context_embeddings = params.character_embeddings[contexts]
    # context_embeddings.reshape(-1, 15) @ W1 gives 300 dim hidden representation for each of N contexts.
    hidden = jnp.tanh(
        context_embeddings.reshape(-1, 15) @ params.weights1 + params.bias1
    )
    logits = hidden @ params.weights2 + params.bias2  # (N, 27)
    return logits

def loss_fn(params: Params, contexts: jax.Array, targets: jax.Array) -> jax.Array:
    logits = forward(params, contexts)
    loss = optax.losses.softmax_cross_entropy_with_integer_labels(
        logits, targets
    ).mean()
    return loss

@jax.jit
def train_step(
    params: Params, opt_state: optax.OptState, contexts: jax.Array, targets: jax.Array
):
    loss, grads = jax.value_and_grad(loss_fn)(params, contexts, targets)
    updates, opt_state = optimizer.update(grads, opt_state, params=params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Simple Stochastic Gradient Descent optimizer.
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)

# Training loop.
for _ in range(100):
    params, opt_state, training_loss = train_step(params, opt_state, contexts, targets)
    print(training_loss)
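Once trained, we can sample from the model by repeatedly feeding it the last three characters and drawing the next one from the predicted distribution. A minimal sketch, assuming the params, forward and char_to_index defined above; the index_to_char inverse mapping and the seed are mine:

key = jax.random.key(1)
index_to_char = {index: char for char, index in char_to_index.items()}

for _ in range(5):
    context = [0] * block_size  # start with "..."
    name = []
    while True:
        logits = forward(params, jnp.array([context]))  # shape (1, 27)
        key, sample_key = jax.random.split(key)
        # Sample the next character index from the predicted distribution.
        index = int(jax.random.categorical(sample_key, logits[0]))
        if index == 0:  # "." marks the end of a name
            break
        name.append(index_to_char[index])
        context = context[1:] + [index]
    print("".join(name))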

Training neural networks in practice: batching, initialisation and batch normalisation

An MLP with non-linear activation functions and enough parameters is a universal approximator of continuous functions on compact domains. This assures us that a model with arbitrarily small loss exists somewhere in the space of models, but tells us nothing about how to find it. Large models also introduce hurdles that require us to modify our strategy so far.

Batching

Our modelling paradigm thus far is to construct an average loss function $L$ over every example in the training data that is differentiable with respect to the model weights.

$$L = \frac{1}{N} \sum_{i=1}^N L_i$$

The cost of computing $L$ and its derivatives increases linearly with the number of training examples $N$. In practice, neural networks are almost always trained using a technique called mini-batching. Random subsets of training examples called batches are sampled from the full set and used to compute approximations of the true gradient. Fortunately for us, this approximation is an unbiased estimator, meaning the gradient we compute will be probabilistically scattered around the true value, with a scatter that decreases as the batch size increases.
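To see that scatter, here is a throwaway sketch comparing mini-batch gradients against the full-batch gradient, assuming the loss_fn, params, contexts and targets from the MLP section; the batch sizes and seed are arbitrary.

key = jax.random.key(3)
full_grad = jax.grad(loss_fn)(params, contexts, targets)

for batch_size in (32, 256, 2048):
    key, batch_key = jax.random.split(key)
    batch_indices = jax.random.randint(batch_key, (batch_size,), 0, len(contexts))
    batch_grad = jax.grad(loss_fn)(params, contexts[batch_indices], targets[batch_indices])
    # Mean absolute deviation from the full-batch gradient shrinks as the batch grows.
    print(batch_size, jnp.abs(batch_grad.weights1 - full_grad.weights1).mean())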

Suppose we have $N$ training examples and a batch size of $M$. The different training regimes then go by different names: $M = N$ is (full-)batch gradient descent, $1 < M < N$ is mini-batch gradient descent, and $M = 1$ is stochastic gradient descent (a term often used loosely for the mini-batch case too).
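As a concrete example, a sketch of mini-batching applied to the earlier MLP training loop, assuming the contexts, targets, params, opt_state and train_step defined above; the batch size and number of steps are arbitrary choices.

batch_size = 32
key = jax.random.key(2)

for step in range(10_000):
    # Sample a random mini-batch of training examples (with replacement).
    key, batch_key = jax.random.split(key)
    batch_indices = jax.random.randint(batch_key, (batch_size,), 0, len(contexts))
    params, opt_state, training_loss = train_step(
        params, opt_state, contexts[batch_indices], targets[batch_indices]
    )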

Initialisation

Suppose the parameters of an MLP are each drawn from a Gaussian with variance $\sigma^2$. The pre-activations $f_k$, at layer $k$, are

$$f_k = \beta_k + \Omega_k \mathbf a [f_{k-1}]$$

with bias $\beta$, weights matrix $\Omega$ and activation function $\mathbf a$, where $\Omega_k \mathbf a[f_{k-1}]$ is a matrix-vector product. Let's consider the variance of the pre-activations:

$$\text{var}(f_k) = \text{var}(\beta_k) + \text{var}(\Omega_k \mathbf a[f_{k-1}])$$

How much the activation function changes the variance of the previous layer's pre-activations $f_{k-1}$ depends on the activation function: each one squishes its inputs differently, e.g. ReLU halves the variance of a zero-mean input. Multiplying by the weights matrix then contributes a variance of $\sigma^2$ for each of the $n$ inputs to the neuron, where $n$ is known as the fan-in. Thus, the variance becomes

$$ \mathrm{var}(f_k) = \mathrm{var}(\beta_k) + g \, n \, \sigma^2 \, \mathrm{var}(f_{k-1}), $$

where $g$ is the gain introduced by $\mathbf a$. Bad things can happen if the variance grows or shrinks as it passes through the layers: activations can saturate (killing the gradients of tanh-like non-linearities), and gradients can explode or vanish, making training unstable or glacially slow.

There are a number of initialisation techniques which combat these problems. A simple, effective, and popular one is Kaiming-He initialisation: set the biases to zero and draw the weights with variance $\sigma^2 = 1/(g\,n)$, so that $g\,n\,\sigma^2 = 1$ and the variance is preserved from layer to layer. For ReLU ($g = 1/2$) this gives the familiar $\sigma^2 = 2/n$. An example is sketched below.
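A sketch of Kaiming-style scaling applied to the earlier MLP parameters; the 5/3 gain commonly used for tanh and the zero biases are my choices rather than something prescribed above.

key = jax.random.key(0)
key, key_char_emb, key_weights1, key_weights2 = jax.random.split(key, 4)

tanh_gain = 5 / 3  # a commonly used gain for tanh non-linearities
params = Params(
    character_embeddings=jax.random.normal(key_char_emb, (27, 5)),
    # Scale by gain / sqrt(fan_in) so the pre-activation variance is roughly preserved.
    weights1=jax.random.normal(key_weights1, (15, 300)) * tanh_gain / jnp.sqrt(15),
    bias1=jnp.zeros(300),
    weights2=jax.random.normal(key_weights2, (300, 27)) / jnp.sqrt(300),
    bias2=jnp.zeros(27),
)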

Batch normalisation

Re-centering and re-scaling is a differentiable transformation. As in, for some transformation

$$y = \frac{x - \mu}{\sigma},$$

we can compute $\mathrm{d}y/\mathrm{d}\mu$ and $\mathrm{d}y/\mathrm{d}\sigma$, so the transformation can live inside the computational graph. If we think our activations should have certain statistics, batch normalisation says: why don't we just transform them so that they do?

Batch normalisation introduces one or more pseudo layers. I'm calling them pseudo layers because they're quite different to the layers in, say, an MLP. Each batch normalisation layer standardises a distribution of activations, and then performs an affine transformation on those activations where the scale and shift are explicit model parameters. To explain this further, let's first consider the simpler case of full-batch normalisation.

Imagine a tensor of activations $\mathbf a$ computed over the entire training dataset is flowing into a batch normalisation layer. During training, the layer standardises the activations according to the mean and standard deviation of each feature over all of the training examples.

$$ \mathbf a_\text{standard} = \frac{\mathbf a_\text{in} - \mu}{\sigma} $$

It will then perform the transformation

$$ \mathbf a_\text{out} = \gamma \mathbf a_\text{standard} + \beta, $$

where $\gamma$ and $\beta$ are trainable parameters (they are part of our computational graph, and we compute $\mathrm{d}L/\mathrm{d}\gamma$ and $\mathrm{d}L/\mathrm{d}\beta$ during backpropagation). That's different to $\mu$ and $\sigma$, which are computed directly from the activations rather than learned by gradient descent. Some additional complication comes in with mini-batching, where during training our batch norm layer computes

$$ \mathbf a_\text{standard} = \frac{\mathbf a_\text{in} - \mu_B}{\sigma_B}. $$

Some non-deterministic weirdness has been introduced: the mean $\mu_B$ and standard deviation $\sigma_B$ are those of the current batch, so the predictions for the examples in a batch become interdependent, and we are no longer computing an unbiased estimator of the true loss gradient. In practice, however, all of this tends to be minor or actually beneficial for learning. Note that we still need $\mu$ and $\sigma$ at inference: in the full-batch case we would use the values computed over the whole training set, but with mini-batching we instead keep running estimates $\hat \mu$ and $\hat \sigma$ of the batch statistics during training and use those at inference.
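A minimal sketch of such a layer in the style of the JAX code above; the function name, argument names and momentum value are mine, and a real implementation would thread this through the model's forward pass.

def batch_norm(activations, gamma, beta, running_mean, running_std,
               momentum=0.001, training=True):
    if training:
        # Statistics of the current (mini-)batch, one value per feature.
        mean = activations.mean(axis=0)
        std = activations.std(axis=0) + 1e-5
        # Running estimates, used in place of batch statistics at inference.
        running_mean = (1 - momentum) * running_mean + momentum * mean
        running_std = (1 - momentum) * running_std + momentum * std
    else:
        mean, std = running_mean, running_std
    standardised = (activations - mean) / std
    # gamma and beta are the trainable scale and shift parameters.
    return gamma * standardised + beta, running_mean, running_std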

Backpropagation is a leaky abstraction

Written 24/12/25