Much attention has been given to analysing the forward pass FLOPs in GPT models (1, 2, 3), whereas there has been a lack of similar detailed accounting for the backward pass, which is typically quoted as having ~2x the FLOPs. Here, we’ll untangle this estimate by deriving step-by-step the gradient expressions for each operation in the network.

Model Architecture Overview

A GPT-2 style model accepts as input a batch of $b$ sequences of $s$ tokens and embeds each token using a lookup into a matrix of dimension $(V, h)$, where $V$ is the vocabulary size and $h$ is the embedding dimension. These token embeddings are summed with positional embeddings and then passed through a stack of $n_{layer}$ transformer layers. Each layer contains a multi-head attention module and a feed-forward network, both of which are preceded by a layer normalization and followed by a residual connection. The output of the final transformer layer undergoes an additional layer normalization and a linear transformation to produce logits for each token in the vocabulary. Finally, the loss $L$ is computed using the cross-entropy between the logits and the true token distribution.

Differentiable Functions and Chain Rule

A transformer is a composition of differentiable functions $f_1, f_2, …, f_n$ (like linear projections, softmax, layer normalization), meaning the entire model can be expressed as $L = f_1 \circ f_2 \circ … \circ f_n$. This compositional structure of differentiable functions means that the gradient of $L$ with respect to any intermediate value is well-defined, making the chain rule applicable. Thus, during backpropagation, to find $\frac{\partial L}{\partial x}$ where $x$ is the input to some function $f_i$, we can multiply $\frac{\partial L}{\partial y}$ (where $y$ is $f_i$’s output) by $\frac{\partial y}{\partial x}$ (the gradient of $f_i$ with respect to its input).

It’s important to note that these functions operate on tensors. For any tensor in the network, its gradient tensor has the same shape, since we need one partial derivative of the loss with respect to each element.

Operation Classification

Data Movement is All You Need groups the operations within a transformer by compute intensity into the following buckets:

  • Tensor contractions: Matrix multiplications which manifest in the attention and MLP blocks, as well as the language modelling head
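  • Statistical normalizations: Softmax, log-softmax and layer normalization, which involve reduction operations along one dimension of the input (the key positions for softmax, the vocabulary for log-softmax and the hidden dimension for layer normalization)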
  • Element-wise operators: Activation functions, attention scaling and residual connections

Tensor Contractions

Matrix multiplication accounts for most of the FLOPs, so let’s begin by differentiating it. Consider the general vector-matrix multiplication $z = xW$ where $z \in \mathbb{R}^{1 \times P}, x \in \mathbb{R}^{1 \times N}, W \in \mathbb{R}^{N \times P}$:

\[\begin{bmatrix} z_1 & z_2 & \cdots & z_P \end{bmatrix} = \begin{bmatrix} x_1 & x_2 & \cdots & x_N \end{bmatrix} \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1P} \\ w_{21} & w_{22} & \cdots & w_{2P} \\ \vdots & \vdots & \ddots & \vdots \\ w_{N1} & w_{N2} & \cdots & w_{NP} \end{bmatrix}\]

The $i$-th element in the output vector is the dot product of the input vector with the $i$-th column in the weight matrix:

\[z_i = x_1w_{1i} + x_2w_{2i} + ... + x_Nw_{Ni} = \sum_{r=1}^N x_rw_{ri}\]

If we consider the first element $x_1$ in the input vector, we can enumerate its contributions to each element in the output vector:

\[\begin{align*} z_1 &= x_1w_{11} + ... \\ z_2 &= x_1w_{12} + ... \\ &\vdots \\ z_P &= x_1w_{1P} + ... \end{align*}\]

The loss $L$ is constructed through some transformation of the output vector $z$. Since $x_1$ contributes to every element of $z$, we must sum across all paths through which $x_1$ influences $L$ via $z$. By the chain rule, the total gradient for $x_1$ is:

\[\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial z_1}\frac{\partial z_1}{\partial x_1} + \frac{\partial L}{\partial z_2}\frac{\partial z_2}{\partial x_1} + ... + \frac{\partial L}{\partial z_P}\frac{\partial z_P}{\partial x_1} = \frac{\partial L}{\partial z_1}w_{11} + \frac{\partial L}{\partial z_2}w_{12} + ... + \frac{\partial L}{\partial z_P}w_{1P}\]

This can be generalized for the $i$-th element in $x$ such that:

\[\frac{\partial L}{\partial x_i} = \begin{bmatrix} \frac{\partial L}{\partial z_1} & \frac{\partial L}{\partial z_2} & \cdots & \frac{\partial L}{\partial z_P} \end{bmatrix} \begin{bmatrix} w_{i1} \\ w_{i2} \\ \vdots \\ w_{iP} \end{bmatrix}\]

We can then stack these individual dot products as a single vector-matrix multiply:

\[\begin{bmatrix} \frac{\partial L}{\partial x_1} & \frac{\partial L}{\partial x_2} & \cdots & \frac{\partial L}{\partial x_N} \end{bmatrix} = \begin{bmatrix} \frac{\partial L}{\partial z_1} & \frac{\partial L}{\partial z_2} & \cdots & \frac{\partial L}{\partial z_P} \end{bmatrix} \begin{bmatrix} w_{11} & w_{21} & \cdots & w_{N1} \\ w_{12} & w_{22} & \cdots & w_{N2} \\ \vdots & \vdots & \ddots & \vdots \\ w_{1P} & w_{2P} & \cdots & w_{NP} \end{bmatrix}\]

Therefore:

\[\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z}W^T\]

Now, let’s consider the gradient with respect to the matrix $W$. If we examine the first column of $W$, we can see that it only contributes to the computation of $z_1$. Therefore, the contribution of each element in the first column to the loss $L$ must be mediated through $z_1$, giving:

\[\begin{align*} \frac{\partial L}{\partial w_{j1}} &= \frac{\partial L}{\partial z_1}\frac{\partial z_1}{\partial w_{j1}} \\ &= \frac{\partial L}{\partial z_1}x_j \end{align*}\]

Or as a vector multiplication between the column vector of inputs and the scalar gradient of loss with respect to $z_1$:

\[\begin{bmatrix} \frac{\partial L}{\partial w_{11}} \\ \frac{\partial L}{\partial w_{21}} \\ \vdots \\ \frac{\partial L}{\partial w_{N1}} \end{bmatrix} = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_N \end{bmatrix} \frac{\partial L}{\partial z_1}\]

This pattern holds for all columns in $W$, where the $i$-th column’s gradients are computed by multiplying the input vector by the corresponding output gradient $\frac{\partial L}{\partial z_i}$. We can thus collapse these individual vector-scalar multiplies into an outer product:

\[\begin{bmatrix} \frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{12}} & \cdots & \frac{\partial L}{\partial w_{1P}} \\ \frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{22}} & \cdots & \frac{\partial L}{\partial w_{2P}} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial w_{N1}} & \frac{\partial L}{\partial w_{N2}} & \cdots & \frac{\partial L}{\partial w_{NP}} \end{bmatrix} = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_N \end{bmatrix} \begin{bmatrix} \frac{\partial L}{\partial z_1} & \frac{\partial L}{\partial z_2} & \cdots & \frac{\partial L}{\partial z_P} \end{bmatrix}\]

Therefore:

\[\frac{\partial L}{\partial W} = x^T\frac{\partial L}{\partial z}\]

The extension to matrix-matrix multiplication follows naturally. Consider now the product $Z = XW$ where $Z \in \mathbb{R}^{M \times P}, X \in \mathbb{R}^{M \times N}, W \in \mathbb{R}^{N \times P}$:

\[\begin{bmatrix} z_{11} & z_{12} & \cdots & z_{1P} \\ z_{21} & z_{22} & \cdots & z_{2P} \\ \vdots & \vdots & \ddots & \vdots \\ z_{M1} & z_{M2} & \cdots & z_{MP} \end{bmatrix} = \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1N} \\ x_{21} & x_{22} & \cdots & x_{2N} \\ \vdots & \vdots & \ddots & \vdots \\ x_{M1} & x_{M2} & \cdots & x_{MN} \end{bmatrix} \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1P} \\ w_{21} & w_{22} & \cdots & w_{2P} \\ \vdots & \vdots & \ddots & \vdots \\ w_{N1} & w_{N2} & \cdots & w_{NP} \end{bmatrix}\]

We can view this as a batch of $M$ vector-matrix multiplications, where each row in $X$ is multiplied by $W$ to produce a corresponding row in $Z$. For any row $i$ in the input matrix $X$, we can apply our previous result:

\[\begin{bmatrix} \frac{\partial L}{\partial x_{i1}} & \frac{\partial L}{\partial x_{i2}} & \cdots & \frac{\partial L}{\partial x_{iN}} \end{bmatrix} = \begin{bmatrix} \frac{\partial L}{\partial z_{i1}} & \frac{\partial L}{\partial z_{i2}} & \cdots & \frac{\partial L}{\partial z_{iP}} \end{bmatrix} W^T\]

Since each row’s gradient computation is independent, we can stack these equations for all rows $i \in \lbrace 1, \dots, M \rbrace$. Therefore, we can express the gradient with respect to the input matrix $X$ as:

\[\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Z}W^T\]

We also saw that for any row $i$ of $X$, its contribution to $\frac{\partial L}{\partial W}$ is:

\[\left.\frac{\partial L}{\partial W}\right|_i = \text{row}_i(X)^T\begin{bmatrix}\frac{\partial L}{\partial z_{i1}} & \frac{\partial L}{\partial z_{i2}} & \cdots & \frac{\partial L}{\partial z_{iP}}\end{bmatrix}\]

Since each weight participates in $M$ such computations (one for each row of $Z$), the total gradient is the sum of these contributions:

\[\frac{\partial L}{\partial W} = \sum_{i=1}^M \left.\frac{\partial L}{\partial W}\right|_i\]

This summation can be performed as a single matrix multiplication:

\[\frac{\partial L}{\partial W} = X^T\frac{\partial L}{\partial Z}\]
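The two results above can be sanity-checked numerically. The following is a minimal sketch in plain Python (lists of rows, no external libraries) that compares $\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Z}W^T$ and $\frac{\partial L}{\partial W} = X^T\frac{\partial L}{\partial Z}$ against central finite differences. The toy loss $L = \sum_{ij} C_{ij} Z_{ij}$, the matrix `C`, and all helper names are assumptions made for illustration only.

```python
import random

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def matmul_backward(X, W, dZ):
    """Gradients of L with respect to X and W for Z = X @ W, given dZ = dL/dZ."""
    dX = matmul(dZ, transpose(W))   # (M, P) @ (P, N) -> (M, N)
    dW = matmul(transpose(X), dZ)   # (N, M) @ (M, P) -> (N, P)
    return dX, dW

random.seed(0)
M, N, P = 3, 4, 2
X = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(M)]
W = [[random.uniform(-1, 1) for _ in range(P)] for _ in range(N)]
C = [[random.uniform(-1, 1) for _ in range(P)] for _ in range(M)]  # hypothetical loss weights

def loss():
    # Toy scalar loss L = sum(C * Z), chosen so that dL/dZ = C exactly.
    Z = matmul(X, W)
    return sum(c * z for c_row, z_row in zip(C, Z) for c, z in zip(c_row, z_row))

def numeric_grad(A, f, eps=1e-6):
    """Central finite differences of the scalar f() for each entry of A."""
    G = [[0.0] * len(A[0]) for _ in A]
    for i in range(len(A)):
        for j in range(len(A[0])):
            old = A[i][j]
            A[i][j] = old + eps
            up = f()
            A[i][j] = old - eps
            down = f()
            A[i][j] = old  # restore the entry before moving on
            G[i][j] = (up - down) / (2 * eps)
    return G

dX, dW = matmul_backward(X, W, C)
num_dX = numeric_grad(X, loss)
num_dW = numeric_grad(W, loss)
max_err = max(
    abs(a - n)
    for G, H in ((dX, num_dX), (dW, num_dW))
    for g_row, h_row in zip(G, H)
    for a, n in zip(g_row, h_row)
)
print(f"max abs difference: {max_err:.2e}")
```

Because the toy loss is linear in $Z$, the finite-difference estimate differs from the analytic gradient only by rounding error.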

Extension to Tensor Operations in Transformers

In transformer architectures, we encounter two key patterns when performing matrix multiplications with tensors:

Weight-to-Activation Multiplications

These operations involve multiplying a 3D activation tensor with a 2D weight matrix. The operations that fit this profile are:

  • the query, key, value projections
  • the output projection after attention
  • the up and down projections in the MLP
  • the linear projection to logits in the language modelling head

We can decompose each of these operations into a series of matrix multiplications between each 2D slice of the activation tensor and the weight matrix.

Let’s take the language modelling head with input $X \in \mathbb{R}^{b \times s \times h}$, weight $W \in \mathbb{R}^{h \times V}$ and output $Z \in \mathbb{R}^{b \times s \times V}$ as an example. Since the forward pass is $b$ matrix multiplications, where each slice $X_i \in \mathbb{R}^{s \times h}$ is multiplied with $W$ to produce output slice $Z_i \in \mathbb{R}^{s \times V}$, during the backward pass, the gradient of the loss with respect to $X$ is computed slice by slice:

\[\frac{\partial L}{\partial X_i} = \frac{\partial L}{\partial Z_i}W^T\]

The gradients for all $b$ slices are then stacked to construct the input gradient tensor $\frac{\partial L}{\partial X} \in \mathbb{R}^{b \times s \times h}$.

For the weight gradients, since $W$ participates in all $b$ matrix multiplications, its gradient accumulates updates from all slices:

\[\frac{\partial L}{\partial W} = \sum_{i=1}^b X_i^T\frac{\partial L}{\partial Z_i}\]

This same approach applies to all weight-to-activation multiplications, with the only differences being the input and output dimensions of the tensors involved.
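As a small illustration of the weight-gradient accumulation, the sketch below computes $\sum_i X_i^T\frac{\partial L}{\partial Z_i}$ slice by slice and compares it with a single product over activations flattened to shape $(bs, h)$. The tensor sizes and random values are assumptions chosen only to keep the example tiny.

```python
import random

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

random.seed(0)
b, s, h, V = 2, 3, 4, 5  # illustrative sizes
X = [[[random.uniform(-1, 1) for _ in range(h)] for _ in range(s)] for _ in range(b)]
dZ = [[[random.uniform(-1, 1) for _ in range(V)] for _ in range(s)] for _ in range(b)]

# Per-slice accumulation: dW = sum_i X_i^T @ dZ_i
dW_sum = [[0.0] * V for _ in range(h)]
for X_i, dZ_i in zip(X, dZ):
    contrib = matmul(transpose(X_i), dZ_i)      # (h, s) @ (s, V) -> (h, V)
    for r in range(h):
        for c in range(V):
            dW_sum[r][c] += contrib[r][c]

# Flattened form: stack the b slices into (b*s, h) and (b*s, V), then one product.
X_flat = [row for X_i in X for row in X_i]
dZ_flat = [row for dZ_i in dZ for row in dZ_i]
dW_flat = matmul(transpose(X_flat), dZ_flat)    # (h, b*s) @ (b*s, V) -> (h, V)

max_diff = max(abs(dW_sum[r][c] - dW_flat[r][c]) for r in range(h) for c in range(V))
print(f"max abs difference: {max_diff:.2e}")
```

Both routes sum the same $bs$ outer products, so they agree up to floating-point rounding.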

Activation-to-Activation Multiplications

These occur in multi-head attention, specifically during the query @ key and attn @ value operations.

We can view the latter operation, between $\text{attn} \in \mathbb{R}^{b \times n_h \times s \times s}$ and $\text{value} \in \mathbb{R}^{b \times n_h \times s \times d}$, as $b \times n_h$ independent matrix multiplications, each involving an $s \times s$ attention matrix and an $s \times d$ value matrix. Unlike weight-to-activation multiplications where the weight matrix is used across the entire batch and so we must accumulate across the batch dimension for $\frac{\partial L}{\partial W}$, here each multiplication involves unique matrix slices. The gradients are thus computed independently for each slice and then stacked across all heads and examples in the batch to construct the final gradient tensors.

FLOPs Analysis

For a matrix multiplication between matrices of dimensions $(m,n)$ and $(n,p)$, each element in the output matrix requires $n$ multiplications and $n-1$ additions. The $mp$ output elements thus require $mnp$ multiplications and $mp(n-1)$ additions, totaling ~$2mnp$ FLOPs.

Note: For weight gradient computations, instead of performing $b$ separate matrix multiplications and summing their products element-wise, we reshape the tensors to perform a single, larger matrix multiplication.

| Operation | Required Tensors | Backward Pass Operations | Total FLOPs |
|---|---|---|---|
| lm_head | upstream_grad: $(b,s,V)$; fwd_input: $(b,s,h)$; weight: $(h,V)$ | 1. $b \times [(s,V)@(V,h)]$; 2. $(h,bs)@(bs,V)$ | $4bshV$ |
| mlp_down | upstream_grad: $(b,s,h)$; fwd_input: $(b,s,4h)$; weight: $(4h,h)$ | 1. $b \times [(s,h)@(h,4h)]$; 2. $(4h,bs)@(bs,h)$ | $16bsh^2$ |
| mlp_up | upstream_grad: $(b,s,4h)$; fwd_input: $(b,s,h)$; weight: $(h,4h)$ | 1. $b \times [(s,4h)@(4h,h)]$; 2. $(h,bs)@(bs,4h)$ | $16bsh^2$ |
| attn_out | upstream_grad: $(b,s,h)$; fwd_input: $(b,s,h)$; weight: $(h,h)$ | 1. $b \times [(s,h)@(h,h)]$; 2. $(h,bs)@(bs,h)$ | $4bsh^2$ |
| attn@value | upstream_grad: $(b,n_h,s,d)$; attn: $(b,n_h,s,s)$; value: $(b,n_h,s,d)$ | 1. $bn_h \times [(s,d)@(d,s)]$; 2. $bn_h \times [(s,s)@(s,d)]$ | $4bn_hs^2d$ |
| query@key | upstream_grad: $(b,n_h,s,s)$; query: $(b,n_h,s,d)$; key: $(b,n_h,s,d)$ | 1. $bn_h \times [(s,s)@(s,d)]$; 2. $bn_h \times [(d,s)@(s,s)]$ | $4bn_hs^2d$ |
| qkv_proj | upstream_grad: $(b,s,3h)$; fwd_input: $(b,s,h)$; weight: $(h,3h)$ | 1. $b \times [(s,3h)@(3h,h)]$; 2. $(h,bs)@(bs,3h)$ | $12bsh^2$ |

where:

  • $n_h$: number of attention heads
  • $d$: head dimension ($h/n_h$)
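The per-operation totals can be reproduced from the $2mnp$ rule with a few lines of Python. This is a sketch rather than a profiler: the function names are hypothetical, and the configuration at the bottom uses GPT-2-small-like sizes purely as an illustrative assumption.

```python
def matmul_flops(m, n, p):
    """Approximate FLOPs of an (m, n) @ (n, p) product: one multiply and one add per term."""
    return 2 * m * n * p

def matmul_backward_flops(b, s, h, V, n_h):
    """Backward-pass matmul FLOPs per operation, one entry per row of the table."""
    d = h // n_h  # head dimension
    return {
        "lm_head": b * matmul_flops(s, V, h) + matmul_flops(h, b * s, V),
        "mlp_down": b * matmul_flops(s, h, 4 * h) + matmul_flops(4 * h, b * s, h),
        "mlp_up": b * matmul_flops(s, 4 * h, h) + matmul_flops(h, b * s, 4 * h),
        "attn_out": b * matmul_flops(s, h, h) + matmul_flops(h, b * s, h),
        "attn@value": b * n_h * (matmul_flops(s, d, s) + matmul_flops(s, s, d)),
        "query@key": b * n_h * (matmul_flops(s, s, d) + matmul_flops(d, s, s)),
        "qkv_proj": b * matmul_flops(s, 3 * h, h) + matmul_flops(h, b * s, 3 * h),
    }

# Illustrative configuration (assumed, not taken from the text).
b, s, h, V, n_h = 8, 1024, 768, 50257, 12
flops = matmul_backward_flops(b, s, h, V, n_h)
for name, value in flops.items():
    print(f"{name:>10}: {value:.3e} FLOPs")
```

Each entry is the sum of the two backward products listed for that operation, so the results match the closed forms in the last column.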

Statistical Normalizations

Softmax

The standard softmax function for $x \in \mathbb{R}^{1 \times N}$ is defined as:

\[z_i = \text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}}\]

Examining $x_1$, we can see it has two distinct pathways through which it influences the vector $z$ and by extension the loss $L$:

  1. Direct contribution to corresponding output $z_1$ through both the numerator $e^{x_1}$ and the normalization term in the denominator $\sum_{j=1}^N e^{x_j}$

  2. Indirect contribution to all other outputs $z_2, z_3, \dots, z_N$ via the shared denominator $\sum_{j=1}^N e^{x_j}$

Thus, when computing the gradient of $L$ with respect to any $x_i$ we must account for both pathways as follows:

\[\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial z_i}\frac{\partial z_i}{\partial x_i} + \sum_{j \neq i} \frac{\partial L}{\partial z_j}\frac{\partial z_j}{\partial x_i}\]

To find these partial derivatives explicitly, we’ll utilize the quotient rule. For convenience, let $S = \sum_j e^{x_j}$. Then:

Case 1: Self term:

\[\begin{align*} \frac{\partial z_i}{\partial x_i} &= \frac{\partial}{\partial x_i} \left(\frac{e^{x_i}}{S}\right) \\ &= \frac{e^{x_i} \cdot S - e^{x_i} \cdot e^{x_i}}{S^2} \\ &= \frac{e^{x_i}}{S} - \frac{(e^{x_i})^2}{S^2} \\ &= z_i - z_i^2 \end{align*}\]

Case 2: Cross terms ($j \neq i$):

\[\begin{align*} \frac{\partial z_j}{\partial x_i} &= \frac{\partial}{\partial x_i} \left(\frac{e^{x_j}}{S}\right) \\ &= \frac{0 \cdot S - e^{x_j} \cdot e^{x_i}}{S^2} \\ &= -\frac{e^{x_j}}{S} \cdot \frac{e^{x_i}}{S} \\ &= -z_j z_i \end{align*}\]

Plugging these expressions back into our formula for $\frac{\partial L}{\partial x_i}$:

\[\begin{align*} \frac{\partial L}{\partial x_i} &= \frac{\partial L}{\partial z_i}[z_i(1-z_i)] + \sum_{j \neq i} \frac{\partial L}{\partial z_j}(-z_jz_i) \\ &= \frac{\partial L}{\partial z_i}z_i - \frac{\partial L}{\partial z_i}(z_i^2) - \sum_{j \neq i} \frac{\partial L}{\partial z_j}(z_jz_i) \\ &= \frac{\partial L}{\partial z_i}z_i - z_i\left(\frac{\partial L}{\partial z_i}z_i + \sum_{j \neq i} \frac{\partial L}{\partial z_j}z_j\right) \\ &= z_i\left(\frac{\partial L}{\partial z_i} - \sum_j \frac{\partial L}{\partial z_j}z_j\right) \end{align*}\]

For matrix input $X \in \mathbb{R}^{M \times N}$, softmax is applied independently to each row vector. Since there are no cross-row interactions in the computation, we can directly apply our vector gradient formula to each row.

For higher dimensional tensors, we can collapse the leading dimensions by stacking them row-wise into a 2D matrix. A tensor of shape $(K, M, N)$ can be viewed as a matrix of shape $(KM, N)$, allowing us to apply the same row-wise gradient computation.
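The row-wise formula $\frac{\partial L}{\partial x_i} = z_i\left(\frac{\partial L}{\partial z_i} - \sum_j \frac{\partial L}{\partial z_j}z_j\right)$ can be checked against finite differences. The sketch below does so for a single row in plain Python; the toy loss $L = \sum_i c_i z_i$ and the vector `c` are assumptions introduced only for the check.

```python
import math
import random

def softmax(x):
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_backward(z, dz):
    """dL/dx_i = z_i * (dL/dz_i - sum_j dL/dz_j * z_j) for one row."""
    weighted = sum(g * p for g, p in zip(dz, z))
    return [p * (g - weighted) for p, g in zip(z, dz)]

random.seed(0)
N = 5
x = [random.uniform(-2, 2) for _ in range(N)]
c = [random.uniform(-1, 1) for _ in range(N)]  # hypothetical loss weights, so dL/dz = c

def loss(x_in):
    return sum(ci * zi for ci, zi in zip(c, softmax(x_in)))

analytic = softmax_backward(softmax(x), c)

eps = 1e-6
numeric = []
for i in range(N):
    up = x[:i] + [x[i] + eps] + x[i + 1:]
    down = x[:i] + [x[i] - eps] + x[i + 1:]
    numeric.append((loss(up) - loss(down)) / (2 * eps))

max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
print(f"max abs difference: {max_err:.2e}")
```

A useful side check is that the input gradients of a row sum to zero, which follows from the formula because the $z_i$ sum to one.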

FLOPs Analysis

The backward pass requires both the upstream gradient and the forward pass output (the attention probabilities produced by the softmax), each of shape $(b, n_h, s, s)$. For each of the $b \times n_h \times s$ row vectors of dimension $s$, we must:

  1. Compute the weighted sum $\sum_j \frac{\partial L}{\partial z_j}z_j$: $2s$ FLOPs

  2. Subtract this sum from each incoming gradient: $s$ FLOPs

  3. Multiply element-wise with the softmax outputs: $s$ FLOPs

This yields $4bn_hs^2$ FLOPs in total.

Note on Numerical Stability
In practice, to prevent floating-point overflows and underflows, the numerically stable version of softmax is used:

\[z_i = \text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}\]

This blog post does a great job of explaining the intuition behind this numerical stability trick. Subtracting $\max(x)$ from every input multiplies the numerator and the denominator by the same factor $e^{-\max(x)}$, so the output is unchanged as a function of $x$. The gradients we derived therefore remain correct, and there is no need to account for the dependence of $\max(x)$ on $x_i$.
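To make the point concrete, here is a small sketch contrasting a naive softmax with the shifted version. The function names and input values are illustrative assumptions. With float64, `math.exp(1000.0)` is out of range, so the naive form fails where the shifted form does not.

```python
import math

def softmax_naive(x):
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(x):
    m = max(x)  # the largest exponent becomes exp(0) = 1
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

# On moderate inputs the two versions agree up to rounding.
small = [1.0, 2.0, 3.0]
diff = max(abs(a - s) for a, s in zip(softmax_naive(small), softmax_stable(small)))

# On large inputs the naive version overflows.
large = [1000.0, 1001.0, 1002.0]
try:
    softmax_naive(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

# The stable version still works, and gives the same result as for `small`
# because the two inputs differ only by a constant shift.
stable_large = softmax_stable(large)
shift_diff = max(abs(a - s) for a, s in zip(stable_large, softmax_stable(small)))
print(naive_overflowed, stable_large)
```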

Log-softmax Gradients

The log-softmax function for $x \in \mathbb{R}^{1 \times N}$, which is defined as

\[\text{logSoftmax}(x)_i = x_i - \log\sum_j e^{x_j}\]

is analytically equivalent to the composition of the log and softmax functions, serving as a practical workaround for numerical stability (see the blog post referenced earlier). Thus:

\[p_i = \text{logSoftmax}(x)_i = \log(\text{softmax}(x)_i)\]

Once again, when computing the gradient of $L$ with respect to any input $x_i$, we must consider both the direct effect through $p_i$ and the indirect effects through all other $p_j$:

\[\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial p_i}\frac{\partial p_i}{\partial x_i} + \sum_{j \neq i}\frac{\partial L}{\partial p_j}\frac{\partial p_j}{\partial x_i}\]

We can further decompose $\frac{\partial p_i}{\partial x_i}$ using the chain rule. Let $z_i = \text{softmax}(x)_i$, then:

\[\frac{\partial p_i}{\partial x_i} = \frac{\partial p_i}{\partial z_i}\frac{\partial z_i}{\partial x_i}\]

From our softmax gradient derivation, we know that $\frac{\partial z_i}{\partial x_i} = z_i(1-z_i)$ and $\frac{\partial z_j}{\partial x_i} = -z_jz_i$. Since $p_i = \log(z_i)$, we also have $\frac{\partial p_i}{\partial z_i} = \frac{1}{z_i}$. Therefore:

\[\begin{align*} \frac{\partial L}{\partial x_i} &= \frac{\partial L}{\partial p_i}\frac{\partial p_i}{\partial z_i}\frac{\partial z_i}{\partial x_i} + \sum_{j \neq i}\frac{\partial L}{\partial p_j}\frac{\partial p_j}{\partial z_j}\frac{\partial z_j}{\partial x_i} \\ &= \frac{\partial L}{\partial p_i}\frac{1}{z_i}z_i(1-z_i) + \sum_{j \neq i}\frac{\partial L}{\partial p_j}\frac{1}{z_j}(-z_jz_i) \\ &= \frac{\partial L}{\partial p_i}(1-z_i) - z_i\sum_{j \neq i}\frac{\partial L}{\partial p_j} \\ &= \frac{\partial L}{\partial p_i} - z_i\sum_{j}\frac{\partial L}{\partial p_j} \\ &= \frac{\partial L}{\partial p_i} - e^{p_i}\sum_{j}\frac{\partial L}{\partial p_j} \end{align*}\]

Due to the similar lack of cross-row dependencies, the extension to matrices and tensors follows the same principles as described for softmax.
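The final expression, $\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial p_i} - e^{p_i}\sum_j \frac{\partial L}{\partial p_j}$, can be verified the same way. The sketch below is a plain-Python check for one row. The toy loss weights `c` and the target index are assumptions. The second half applies the formula to the negative log-likelihood $L = -p_t$, where it reduces to the familiar "probabilities minus one-hot" gradient.

```python
import math
import random

def log_softmax(x):
    m = max(x)  # shift by the maximum for numerical stability
    log_sum = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - log_sum for v in x]

def log_softmax_backward(p, dp):
    """dL/dx_i = dL/dp_i - exp(p_i) * sum_j dL/dp_j for one row."""
    total = sum(dp)
    return [g - math.exp(lp) * total for g, lp in zip(dp, p)]

random.seed(0)
N = 5
x = [random.uniform(-2, 2) for _ in range(N)]
c = [random.uniform(-1, 1) for _ in range(N)]  # hypothetical loss weights, so dL/dp = c

def loss(x_in):
    return sum(ci * lp for ci, lp in zip(c, log_softmax(x_in)))

analytic = log_softmax_backward(log_softmax(x), c)

eps = 1e-6
numeric = []
for i in range(N):
    up = x[:i] + [x[i] + eps] + x[i + 1:]
    down = x[:i] + [x[i] - eps] + x[i + 1:]
    numeric.append((loss(up) - loss(down)) / (2 * eps))
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))

# Special case: negative log-likelihood of a target class, L = -p_target.
target = 2  # illustrative 0-based class index
dp_nll = [-1.0 if j == target else 0.0 for j in range(N)]
dx_nll = log_softmax_backward(log_softmax(x), dp_nll)
probs = [math.exp(lp) for lp in log_softmax(x)]
expected = [probs[j] - (1.0 if j == target else 0.0) for j in range(N)]
nll_err = max(abs(a - e) for a, e in zip(dx_nll, expected))
print(f"{max_err:.2e} {nll_err:.2e}")
```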

FLOPs Analysis

The backward pass requires the upstream gradient and forward pass output (the log probabilities), both of shape $(b, s, V)$. For each of the $b \times s$ vectors of dimension $V$, we require:

  1. Sum of upstream gradient components: $V$ FLOPs
  2. Exponentiating each log-probability: $V$ FLOPs
  3. Multiplying each exponentiated value with the sum: $V$ FLOPs
  4. Subtracting from the incoming gradient: $V$ FLOPs

This totals $4bsV$ FLOPs for the entire operation.

LayerNorm

Layer normalization operates on a vector $x \in \mathbb{R}^{1 \times N}$ with parameters $\gamma, \beta \in \mathbb{R}^{1 \times N}$. For each element $i$, the output is defined as:

\[y_i = \gamma_i \hat{x}_i + \beta_i\]

where $\hat{x}_i$ is the normalized input:

\[\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\]

with $\mu$ being the mean and $\sigma^2$ being the variance of the inputs:

\[\begin{align*} &\mu = \frac{1}{N}\sum_{j=1}^N x_j \\[5pt] &\sigma^2 = \frac{1}{N}\sum_{j=1}^N (x_j - \mu)^2 \end{align*}\]

We’ll begin backpropagation by first obtaining the gradients for the learnable parameters $\gamma$ and $\beta$, as well as the intermediate normalized inputs $\hat{x}$. Since each output $y_i$ only depends on its corresponding $\hat{x}_i$, $\gamma_i$, and $\beta_i$, these gradients are straightforward applications of the chain rule:

\[\begin{align*} \frac{\partial L}{\partial \hat{x}_i} &= \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\gamma_i \\[10pt] \frac{\partial L}{\partial \gamma_i} &= \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \gamma_i} = \frac{\partial L}{\partial y_i}\hat{x}_i \\[10pt] \frac{\partial L}{\partial \beta_i} &= \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \beta_i} = \frac{\partial L}{\partial y_i} \end{align*}\]

When backpropagating through $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$, we now encounter similar reduction dynamics to softmax. To compute the gradient with respect to any $x_i$, we must consider both its direct effect on $\hat{x}_i$ and its indirect effects through the $\mu$ and $\sigma^2$ terms that appear in all other $\hat{x}_j$. This gives us:

\[\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial x_i} + \sum_{j \neq i} \frac{\partial L}{\partial \hat{x}_j}\frac{\partial \hat{x}_j}{\partial x_i}\]

By the quotient rule:

\[\begin{align*} \frac{\partial \hat{x}_i}{\partial x_i} &= \frac{(1 - \frac{\partial \mu}{\partial x_i})\sqrt{\sigma^2 + \epsilon} - (x_i - \mu)\frac{1}{2\sqrt{\sigma^2 + \epsilon}}\frac{\partial \sigma^2}{\partial x_i}}{(\sqrt{\sigma^2 + \epsilon})^2} \\ &= \frac{1}{\sqrt{\sigma^2 + \epsilon}}\left(1 - \frac{\partial \mu}{\partial x_i} - \frac{x_i - \mu}{2(\sigma^2 + \epsilon)}\frac{\partial \sigma^2}{\partial x_i}\right) \\ &= \frac{1}{\sqrt{\sigma^2 + \epsilon}}\left(1 - \frac{\partial \mu}{\partial x_i} - \frac{\hat{x}_i}{2\sqrt{\sigma^2 + \epsilon}}\frac{\partial \sigma^2}{\partial x_i}\right) \\[20pt] \frac{\partial \hat{x}_j}{\partial x_i} &= \frac{(- \frac{\partial \mu}{\partial x_i})\sqrt{\sigma^2 + \epsilon} - (x_j - \mu)\frac{1}{2\sqrt{\sigma^2 + \epsilon}}\frac{\partial \sigma^2}{\partial x_i}}{(\sqrt{\sigma^2 + \epsilon})^2} \\ &= \frac{1}{\sqrt{\sigma^2 + \epsilon}}\left(- \frac{\partial \mu}{\partial x_i} - \frac{x_j - \mu}{2(\sigma^2 + \epsilon)}\frac{\partial \sigma^2}{\partial x_i}\right) \\ &= \frac{1}{\sqrt{\sigma^2 + \epsilon}}\left(- \frac{\partial \mu}{\partial x_i} - \frac{\hat{x}_j}{2\sqrt{\sigma^2 + \epsilon}}\frac{\partial \sigma^2}{\partial x_i}\right) \end{align*}\]

For $\mu$, since we’re taking the derivative with respect to $x_i$, all other terms in the sum go to zero:

\[\frac{\partial \mu}{\partial x_i} = \frac{1}{N}\]

For $\sigma^2$, we split the sum to examine how $x_i$ flows through each term:

\[\begin{align*} \frac{\partial \sigma^2}{\partial x_i} &= \frac{1}{N} \frac{\partial}{\partial x_i} \left[\sum_j (x_j - \mu)^2\right] \\ &= \frac{1}{N} \frac{\partial}{\partial x_i} \left[(x_i - \mu)^2 + \sum_{j\neq i} (x_j - \mu)^2\right] \\ &= \frac{1}{N} \left[2\left(1 - \frac{\partial \mu}{\partial x_i}\right)(x_i - \mu) + 2\sum_{j\neq i} \left(-\frac{\partial \mu}{\partial x_i}\right)(x_j - \mu)\right] \\ &= \frac{2}{N} \left[\left(1-\frac{1}{N}\right)(x_i - \mu) + \sum_{j\neq i} \left(-\frac{1}{N}\right)(x_j - \mu)\right] \\ &= \frac{2}{N} \left[(x_i - \mu) - \frac{1}{N} \sum_j (x_j - \mu)\right] \\ &= \frac{2(x_i - \mu)}{N} \end{align*}\]

Substituting these back in and noting that $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$:

\[\begin{align*} \frac{\partial \hat{x}_i}{\partial x_i} &= \frac{1}{\sqrt{\sigma^2 + \epsilon}}\left(1-\frac{1}{N} - \frac{\hat{x}_i^2}{N}\right) \\[10pt] \frac{\partial \hat{x}_j}{\partial x_i} &= \frac{1}{\sqrt{\sigma^2 + \epsilon}}\left(-\frac{1}{N} - \frac{\hat{x}_j\hat{x}_i}{N}\right) \end{align*}\]

Finally, putting it all together:

\[\begin{align*} \frac{\partial L}{\partial x_i} &= \frac{\partial L}{\partial \hat{x}_i}\left[\frac{1}{\sqrt{\sigma^2+\epsilon}}\left(1-\frac{1}{N} - \frac{\hat{x}_i^2}{N}\right)\right] + \sum_{j\neq i} \frac{\partial L}{\partial \hat{x}_j}\left[\frac{1}{\sqrt{\sigma^2+\epsilon}}\left(-\frac{1}{N} - \frac{\hat{x}_i\hat{x}_j}{N}\right)\right] \\ &= \frac{1}{\sqrt{\sigma^2+\epsilon}}\left(\frac{\partial L}{\partial \hat{x}_i} - \frac{1}{N}\frac{\partial L}{\partial \hat{x}_i} - \frac{\hat{x}_i^2}{N}\frac{\partial L}{\partial \hat{x}_i} - \sum_{j\neq i} \left[\frac{1}{N}\frac{\partial L}{\partial \hat{x}_j} + \frac{\hat{x}_i\hat{x}_j}{N}\frac{\partial L}{\partial \hat{x}_j}\right]\right) \\ &= \frac{1}{\sqrt{\sigma^2+\epsilon}}\left(\frac{\partial L}{\partial \hat{x}_i} - \frac{1}{N}\sum_j \frac{\partial L}{\partial \hat{x}_j} - \frac{\hat{x}_i}{N}\sum_j \frac{\partial L}{\partial \hat{x}_j}\hat{x}_j\right) \\ &= \frac{1}{\sqrt{\sigma^2+\epsilon}}\left(\frac{\partial L}{\partial y_i}\gamma_i - \frac{1}{N}\sum_j \frac{\partial L}{\partial y_j}\gamma_j - \frac{\hat{x}_i}{N}\sum_j \frac{\partial L}{\partial y_j}\gamma_j\hat{x}_j\right) \end{align*}\]

For matrix input $X \in \mathbb{R}^{M \times N}$, layer normalization is applied independently to each row vector. However, unlike softmax, the parameters $\gamma, \beta \in \mathbb{R}^{1 \times N}$ are shared across all rows. This means that while the gradient computation $\frac{\partial L}{\partial x_i}$ can be performed independently for each row (following the vector formula above), the gradients for $\gamma$ and $\beta$ must sum contributions from all rows:

\[\begin{align*} &\frac{\partial L}{\partial \gamma_i} = \sum_{r=1}^M \frac{\partial L}{\partial y_{ri}}\hat{x}_{ri} \\[5pt] &\frac{\partial L}{\partial \beta_i} = \sum_{r=1}^M \frac{\partial L}{\partial y_{ri}} \end{align*}\]

where $r$ indexes the rows of the input matrix.

As with softmax, tensors of shape $(K, M, N)$ can be viewed as matrices of shape $(KM, N)$, with the same principle of parameter sharing across all $KM$ rows.
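The closed-form input gradient is easy to get wrong by a sign, so a numerical check is worthwhile. The following sketch implements the single-row forward and backward passes in plain Python and compares $\frac{\partial L}{\partial x}$ with central finite differences. The value of `EPS`, the toy loss $L = \sum_i c_i y_i$, and all function names are assumptions made for illustration.

```python
import math
import random

EPS = 1e-5  # assumed value of epsilon; the text does not fix one

def layernorm_forward(x, gamma, beta):
    n = len(x)
    mu = sum(x) / n
    var = sum((v - mu) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + EPS)
    x_hat = [(v - mu) * rstd for v in x]
    y = [g * xh + bt for g, xh, bt in zip(gamma, x_hat, beta)]
    return y, x_hat, rstd

def layernorm_backward(dy, x_hat, rstd, gamma):
    """Single-row gradients for x, gamma and beta, following the derivation above."""
    n = len(dy)
    v = [d * g for d, g in zip(dy, gamma)]            # dL/dx_hat
    s1 = sum(v)
    s2 = sum(vi * xh for vi, xh in zip(v, x_hat))
    dx = [rstd * (vi - s1 / n - xh * s2 / n) for vi, xh in zip(v, x_hat)]
    dgamma = [d * xh for d, xh in zip(dy, x_hat)]
    dbeta = list(dy)
    return dx, dgamma, dbeta

random.seed(0)
N = 6
x = [random.uniform(-2, 2) for _ in range(N)]
gamma = [random.uniform(0.5, 1.5) for _ in range(N)]
beta = [random.uniform(-1, 1) for _ in range(N)]
c = [random.uniform(-1, 1) for _ in range(N)]  # hypothetical loss weights, so dL/dy = c

def loss(x_in):
    y, _, _ = layernorm_forward(x_in, gamma, beta)
    return sum(ci * yi for ci, yi in zip(c, y))

_, x_hat, rstd = layernorm_forward(x, gamma, beta)
dx, dgamma, dbeta = layernorm_backward(c, x_hat, rstd, gamma)

step = 1e-6
numeric_dx = []
for i in range(N):
    up = x[:i] + [x[i] + step] + x[i + 1:]
    down = x[:i] + [x[i] - step] + x[i + 1:]
    numeric_dx.append((loss(up) - loss(down)) / (2 * step))

max_err = max(abs(a - n) for a, n in zip(dx, numeric_dx))
print(f"max abs difference: {max_err:.2e}")
```

For a matrix or tensor input, `dx` would be computed row by row in the same way, while `dgamma` and `dbeta` would be summed over all rows.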

FLOPs Analysis

LayerNorm’s backward pass requires several tensors: the upstream gradient $(b, s, h)$, the forward input $(b, s, h)$, the computed means $(b, s)$ and variances $(b, s)$, and the scale parameter $\gamma$ $(h,)$. The bias $\beta$ does not appear in any of the gradient expressions above, so its value is not needed. The computation breaks down into three main components:

Input Gradient
For each of the $b \times s$ vectors:

  1. Precompute intermediate values:
    • Evaluate $v_j = \frac{\partial L}{\partial y_j}\gamma_j$: $h$ FLOPs
    • Evaluate $S_1 = \sum_j v_j$: $h$ FLOPs
    • Evaluate $S_2 = \sum_j v_j\hat{x}_j$: $2h$ FLOPs
  2. Obtain each element gradient using $\frac{\partial L}{\partial x_i} = \frac{1}{\sqrt{\sigma^2 + \epsilon}}\left(v_i - \frac{S_1}{h} - \hat{x}_i \frac{S_2}{h}\right)$ (here the vector length is $N = h$): $4h$ FLOPs
  3. Total: $8bsh$ FLOPs

Scale Parameter Gradient

  1. Element-wise multiplication: $bsh$ FLOPs
  2. Reduction across batch and sequence dimensions: $bsh$ FLOPs

Bias Parameter Gradient

  1. Reduction across batch and sequence dimensions: $bsh$ FLOPs

The total cost is $11bsh$ FLOPs.
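For reference, the per-call backward costs of the three normalizations can be collected in one helper. This is only a restatement of the counts derived in these sections. The function name and the example configuration are illustrative assumptions.

```python
def normalization_backward_flops(b, s, h, V, n_h):
    """Backward FLOPs for one call of each normalization, using the counts derived above."""
    return {
        "softmax": 4 * b * n_h * s * s,        # 4s per row, b * n_h * s rows
        "log_softmax": 4 * b * s * V,          # 4V per row, b * s rows
        "layernorm": (8 + 2 + 1) * b * s * h,  # input grad + scale grad + bias grad
    }

# Illustrative configuration (assumed, not taken from the text).
b, s, h, V, n_h = 8, 1024, 768, 50257, 12
norm_flops = normalization_backward_flops(b, s, h, V, n_h)
for name, value in norm_flops.items():
    print(f"{name:>12}: {value:.3e} FLOPs")
```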

Element-wise Operations

These operations are applied independently to each element of the input tensor, meaning that each input element influences only the corresponding output element, with no interdependencies across elements. Each element’s gradients can thus be calculated independently, simplifying backpropagation greatly.

GELU

For a given element $x \in \mathbb{R}$, the GELU activation function is approximated as:

\[\text{GELU}(x) = \frac{1}{2}x\left(1 + \tanh\left[a(x+bx^3)\right]\right)\]

where $a = \sqrt{\frac{2}{\pi}}$, $b = 0.044715$.

Let $y = \text{GELU}(x)$. Then:

\[\begin{aligned} \frac{\partial L}{\partial x} &= \frac{\partial L}{\partial y}\frac{\partial y}{\partial x} \\ &= \frac{\partial L}{\partial y}\left[\frac{1}{2}\left(1 + \tanh\left[a(x+bx^3)\right]\right) + \frac{1}{2}x\left(\frac{\partial}{\partial x}\tanh\left[a(x+bx^3)\right]\right)\right] \\ &= \frac{\partial L}{\partial y}\left[\frac{1}{2}\left(1 + \tanh\left[a(x+bx^3)\right]\right) + \frac{1}{2}ax(1+3bx^2)\left(\text{sech}^2\left[a(x+bx^3)\right]\right)\right] \\ &= \frac{\partial L}{\partial y}\left[\frac{1}{2}\left(1 + \tanh\left[a(x+bx^3)\right]\right) + \frac{1}{2}ax(1+3bx^2)\left(1-\tanh^2\left[a(x+bx^3)\right]\right)\right] \\ \end{aligned}\]

FLOPs Analysis

We’ll analyze the cost based on PyTorch’s gelu_backward implementation. The backward pass requires both the upstream gradient and the forward pass input, both of shape $(b, s, 4h)$. For each element:

from math import pi, sqrt, tanh

a = sqrt(2 / pi)                       # constant, not counted
b = 0.044715                           # constant, not counted
x, incoming_gradient = 0.5, 1.0        # example values for one element

x_sq = x * x                           # 1 FLOP
x_cube = x_sq * x                      # 1 FLOP
inner = a * (x + b * x_cube)           # 3 FLOPs
tanh_inner = tanh(inner)               # 1 FLOP

left = 0.5 * x                         # 1 FLOP
right = 1 + tanh_inner                 # 1 FLOP

left_derivative = 0.5 * right          # 1 FLOP
tanh_derivative = 1 - tanh_inner * tanh_inner   # 2 FLOPs
inner_derivative = a * (1 + 3 * b * x_sq) # 4 FLOPs
right_derivative = left * tanh_derivative * inner_derivative # 2 FLOPs

grad = incoming_gradient * (left_derivative + right_derivative) # 2 FLOPs

Each element therefore costs 19 FLOPs. For the complete tensor of $4bsh$ elements, this amounts to $76bsh$ FLOPs.
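As a check on the derivative, the sketch below compares the analytic expression with central finite differences at a few sample points, and tallies the per-element FLOP count from the listing. It is a plain-Python illustration, not PyTorch's kernel. The constants are named `A` and `B` here (an assumption of this sketch) to avoid clashing with the batch size $b$.

```python
import math

A = math.sqrt(2.0 / math.pi)
B = 0.044715

def gelu(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(A * (x + B * x ** 3)))

def gelu_backward(x, upstream):
    """dL/dx for y = gelu(x), given upstream = dL/dy."""
    t = math.tanh(A * (x + B * x ** 3))
    left_derivative = 0.5 * (1.0 + t)
    right_derivative = 0.5 * x * (1.0 - t * t) * A * (1.0 + 3.0 * B * x * x)
    return upstream * (left_derivative + right_derivative)

eps = 1e-6
max_err = 0.0
for x in [-3.0, -1.0, -0.1, 0.0, 0.5, 2.0]:
    numeric = (gelu(x + eps) - gelu(x - eps)) / (2 * eps)
    max_err = max(max_err, abs(gelu_backward(x, 1.0) - numeric))
print(f"max abs difference: {max_err:.2e}")

# Per-element FLOPs, one term per line of the listing above.
FLOPS_PER_ELEMENT = 1 + 1 + 3 + 1 + 1 + 1 + 1 + 2 + 4 + 2 + 2

def gelu_backward_flops(b, s, h):
    """Total backward FLOPs for an activation tensor of shape (b, s, 4h)."""
    return FLOPS_PER_ELEMENT * b * s * 4 * h
```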

Attention Scaling

To prevent attention scores from growing too large in magnitude, they are scaled by $\frac{1}{\sqrt{d}}$. This avoids overly sharp softmax outputs, helping maintain stable gradients and ensuring smooth updates during training.

Given an attention score $x \in \mathbb{R}$, the scaled output is:

\[y = \frac{x}{\sqrt{d}}\]

The gradient is simply:

\[\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\frac{1}{\sqrt{d}}\]

FLOPs Analysis

For a tensor of shape $(b, n_h, s, s)$, the backward pass involves one multiplication per element, leading to $bn_hs^2$ FLOPs.

Residual Connections

Residual connections add the input tensor to the resulting output of a module (e.g. attention or MLP). At the element level:

\[y = x + F(x)\]

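During backpropagation, the upstream gradient is routed unchanged to both branches of the sum: the skip path and the module output $F(x)$. The gradient that later returns through $F$ is accounted for under the module’s own operations.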

\[\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}, \quad \frac{\partial L}{\partial F(x)} = \frac{\partial L}{\partial y}\]

FLOPs Analysis

Backpropagating through a residual connection incurs no additional arithmetic cost, as the upstream gradient is assigned directly to the input gradients.

Other Operations

Embeddings

The embedding layer of a transformer maps token indices $[t_1, \dots, t_s]$, where $t_i \in \lbrace 1, \dots, V \rbrace$, to token embeddings $X \in \mathbb{R}^{s \times h}$ via the embedding matrix $E \in \mathbb{R}^{V \times h}$. Position indices $[1, \dots, s]$ are likewise mapped to position embeddings $P_{emb} \in \mathbb{R}^{s \times h}$ via the matrix $P \in \mathbb{R}^{s \times h}$. The two are summed to produce the final embeddings $Y = X + P_{emb}$.

During backpropagation, the upstream gradient $\frac{\partial L}{\partial Y} \in \mathbb{R}^{s \times h}$ flows unchanged to both embeddings:

\[\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}, \quad \frac{\partial L}{\partial P_{emb}} = \frac{\partial L}{\partial Y}\]

Token Embeddings

To examine how gradients are routed through the token embedding operation, let’s begin with a concrete example using embedding matrix $E \in \mathbb{R}^{V \times h}$ and a sequence of three tokens $[2, 5, 2]$. During the forward pass, we construct matrix $X \in \mathbb{R}^{3 \times h}$ by looking up these tokens in $E$:

\[X = \begin{bmatrix} E_{21} & E_{22} & \cdots & E_{2h} \\ E_{51} & E_{52} & \cdots & E_{5h} \\ E_{21} & E_{22} & \cdots & E_{2h} \end{bmatrix}\]

During the backward pass, we receive incoming gradients of the same shape as $X$:

\[\frac{\partial L}{\partial X} = \begin{bmatrix} \frac{\partial L}{\partial X_{11}} & \frac{\partial L}{\partial X_{12}} & \cdots & \frac{\partial L}{\partial X_{1h}} \\ \frac{\partial L}{\partial X_{21}} & \frac{\partial L}{\partial X_{22}} & \cdots & \frac{\partial L}{\partial X_{2h}} \\ \frac{\partial L}{\partial X_{31}} & \frac{\partial L}{\partial X_{32}} & \cdots & \frac{\partial L}{\partial X_{3h}} \end{bmatrix}\]

To compute the gradient of the loss with respect to the embedding matrix $E$, we need to pass these gradients back to the rows they came from. Since row 2 of $E$ was used twice (for positions 1 and 3), its gradient accumulates both $\text{row}_1(\frac{\partial L}{\partial X})$ and $\text{row}_3(\frac{\partial L}{\partial X})$. Row 5 was used once, so it receives $\text{row}_2(\frac{\partial L}{\partial X})$. All other rows in $E$ weren’t used, so they receive zero gradient:

\[\frac{\partial L}{\partial E} = \begin{bmatrix} 0 & 0 & \cdots & 0 \\ \frac{\partial L}{\partial X_{11}} + \frac{\partial L}{\partial X_{31}} & \frac{\partial L}{\partial X_{12}} + \frac{\partial L}{\partial X_{32}} & \cdots & \frac{\partial L}{\partial X_{1h}} + \frac{\partial L}{\partial X_{3h}} \\ 0 & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 \\ \frac{\partial L}{\partial X_{21}} & \frac{\partial L}{\partial X_{22}} & \cdots & \frac{\partial L}{\partial X_{2h}} \\ \vdots & \vdots & \ddots & \vdots \end{bmatrix}\]

This generalizes to any sequence of tokens $[t_1, t_2, …, t_s]$, where the forward pass constructs matrix $X \in \mathbb{R}^{s \times h}$:

\[X = \begin{bmatrix} E_{t_1, 1} & E_{t_1, 2} & \cdots & E_{t_1, h} \\ E_{t_2, 1} & E_{t_2, 2} & \cdots & E_{t_2, h} \\ E_{t_3, 1} & E_{t_3, 2} & \cdots & E_{t_3, h} \\ \vdots & \vdots & \ddots & \vdots \\ E_{t_s, 1} & E_{t_s, 2} & \cdots & E_{t_s, h} \end{bmatrix}\]

During the backward pass, the $i$-th row in $\frac{\partial L}{\partial E}$ accumulates gradients from all positions where that token appeared in the sequence:

\[\text{row}_i\left(\frac{\partial L}{\partial E}\right) = \sum_{j \,:\, t_j = i} \text{row}_j\left(\frac{\partial L}{\partial X}\right)\]
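The accumulation rule above amounts to a scatter-add. A minimal NumPy sketch follows, assuming 0-indexed tokens (so the example sequence $[2, 5, 2]$ becomes `[1, 4, 1]`); the function name `token_embedding_backward` and the sizes are hypothetical.

```python
import numpy as np

def token_embedding_backward(grad_X: np.ndarray, tokens: np.ndarray, V: int) -> np.ndarray:
    """Scatter-add the rows of dL/dX into dL/dE at the token indices."""
    h = grad_X.shape[-1]
    grad_E = np.zeros((V, h), dtype=grad_X.dtype)
    # np.add.at accumulates correctly when an index repeats, whereas
    # grad_E[tokens] += grad_X would keep only one write per repeated row.
    np.add.at(grad_E, tokens, grad_X)
    return grad_E

V, h = 6, 4
tokens = np.array([1, 4, 1])  # 0-indexed version of the sequence [2, 5, 2]
grad_X = np.arange(3 * h, dtype=np.float64).reshape(3, h)  # upstream dL/dX
grad_E = token_embedding_backward(grad_X, tokens, V)
```

Row 1 of `grad_E` holds the sum of the first and third rows of `grad_X`, row 4 holds the second row, and every other row stays zero.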

Position Embeddings

Unlike token embeddings, where indices can repeat and rows of $E$ may therefore be used multiple times, each position has a unique index, so each row of the position embedding matrix is used exactly once per sequence. The upstream gradients $\frac{\partial L}{\partial P_{emb}} \in \mathbb{R}^{s \times h}$ therefore map directly to $\frac{\partial L}{\partial P} \in \mathbb{R}^{s \times h}$.

Batched Sequences

When processing a batch of sequences, the sum is still $Y = X + P_{emb}$, but now $X \in \mathbb{R}^{b \times s \times h}$ while $P_{emb}$ remains in $\mathbb{R}^{s \times h}$, so $P_{emb}$ is broadcast across the batch dimension. Consequently, computing $\frac{\partial L}{\partial P_{emb}}$ requires summing the upstream gradient $\frac{\partial L}{\partial Y}$ over the batch dimension, which takes $(b-1)sh$ additions.
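The batch reduction for the position embedding gradient can be sketched in NumPy as follows. The sizes and random upstream gradient are illustrative assumptions.

```python
import numpy as np

# Hypothetical small sizes: batch, sequence length, embedding dimension.
b, s, h = 4, 3, 2
rng = np.random.default_rng(0)
grad_Y = rng.standard_normal((b, s, h))  # upstream gradient dL/dY

# P_emb was broadcast over the batch in the forward pass, so its
# gradient is the sum of dL/dY over the batch dimension.
grad_P_emb = grad_Y.sum(axis=0)          # shape (s, h)

# Summing b values into each of the s*h outputs takes b - 1 additions.
additions = (b - 1) * s * h
print(grad_P_emb.shape, additions)
```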

For token embeddings, we now accumulate gradients over all occurrences of each token across the entire batch rather than just within a single sequence. The backward pass consists primarily of indexing operations, with $O(h)$ additions per token occurrence (at most $bsh$ in total). This cost is negligible compared to other operations in the network.

Negative Log Likelihood Loss

PyTorch implements F.cross_entropy(logits, targets) as a composition of logprobs = F.log_softmax(logits) and F.nll_loss(logprobs, targets).

The forward pass for negative log-likelihood takes a log-probabilities tensor of shape $(b, s, V)$ and a target labels tensor of shape $(b, s)$. At each position, it picks out the log-probability corresponding to the target label, negates it, and averages the resulting values over all $b \times s$ positions. The backward pass constructs a sparse gradient tensor with zeros everywhere except at the target indices, where the value is $-\frac{1}{b \times s}$. Since this boils down to indexing and assignment, no FLOPs are performed.
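A minimal NumPy sketch of this backward pass is shown below. It assumes mean reduction over all $b \times s$ positions (the default for `F.nll_loss`); the function name `nll_loss_backward` and the example targets are hypothetical.

```python
import numpy as np

def nll_loss_backward(shape: tuple, targets: np.ndarray) -> np.ndarray:
    """Gradient of mean NLL loss with respect to the log-probabilities."""
    b, s, V = shape
    grad = np.zeros((b, s, V))
    # Index grids that pair every (batch, position) with its target label.
    batch_idx, pos_idx = np.indices((b, s))
    # Pure indexing and assignment: no arithmetic on the tensor itself.
    grad[batch_idx, pos_idx, targets] = -1.0 / (b * s)
    return grad

b, s, V = 2, 3, 5
targets = np.array([[0, 2, 4], [1, 1, 3]])  # one label per position
grad_logprobs = nll_loss_backward((b, s, V), targets)
```

Exactly one entry per position is non-zero, so the gradient tensor has $b \times s$ non-zero values out of $b \times s \times V$.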

Final Tally

Excluding token embedding gradients, which depend on input token frequencies, and position embedding gradients, which contribute negligibly to the total computation, the FLOPs for a single transformer layer can be expressed as:

\[\begin{align*} \text{Layer FLOPs} &= \text{pre_attn_ln} + \text{qkv_proj} + \text{query@key} + \text{softmax} + \text{scaling} \\ &\quad\; + \; \text{attn@value} + \text{attn_out} + \text{pre_mlp_ln} + \text{mlp_up} + \text{gelu} + \text{mlp_down} \\ &= 11bsh + 12bsh^2 + 4bn_hs^2d + 4bn_hs^2 + bn_hs^2 \\ &\quad\; + \; 4bn_hs^2d + 4bsh^2 + 11bsh + 16bsh^2 + 76bsh + 16bsh^2 \\ &= 48bsh^2 + 8bn_hs^2d + 5bn_hs^2 + 98bsh \end{align*}\]

The total backward pass FLOPs for the full network is then:

\[\begin{align*} \text{Total FLOPs} &= n_{layer} \times \text{Layer FLOPs} + \text{final_ln} + \text{lm_head} + \text{cross_entropy} \\ &= n_{layer}(48bsh^2 + 8bn_hs^2d + 5bn_hs^2 + 98bsh) + 11bsh + 4bshV \end{align*}\]
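The tally can be turned into a small calculator. The sketch below transcribes the per-operation terms from the expressions above and assumes the head dimension is $d = h / n_h$; the function names are hypothetical, and the GPT-2 small style hyperparameters are used only as an example.

```python
def layer_backward_flops(b: int, s: int, h: int, n_h: int) -> int:
    """Backward FLOPs for one transformer layer, term by term."""
    d = h // n_h  # per-head dimension, assuming h is divisible by n_h
    terms = {
        "pre_attn_ln": 11 * b * s * h,
        "qkv_proj": 12 * b * s * h**2,
        "query@key": 4 * b * n_h * s**2 * d,
        "softmax": 4 * b * n_h * s**2,
        "scaling": b * n_h * s**2,
        "attn@value": 4 * b * n_h * s**2 * d,
        "attn_out": 4 * b * s * h**2,
        "pre_mlp_ln": 11 * b * s * h,
        "mlp_up": 16 * b * s * h**2,
        "gelu": 76 * b * s * h,
        "mlp_down": 16 * b * s * h**2,
    }
    return sum(terms.values())

def total_backward_flops(b: int, s: int, h: int, n_h: int, n_layer: int, V: int) -> int:
    """Backward FLOPs for the full network, following the total above."""
    final_ln = 11 * b * s * h
    lm_head = 4 * b * s * h * V
    return n_layer * layer_backward_flops(b, s, h, n_h) + final_ln + lm_head

# Example hyperparameters in the style of GPT-2 small (an assumption for illustration).
config = dict(b=1, s=1024, h=768, n_h=12, n_layer=12, V=50257)
total = total_backward_flops(**config)
print(f"{total:.3e} backward FLOPs")
```

Keeping the terms in a dictionary makes it easy to inspect which operations dominate for a given configuration.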