Transformers in a nutshell (pre-normalization, decoder-only architecture):
Embedding Layer
\[ x = S \cdot E \]
where \(E\) is the embedding matrix and \(S\) is a matrix of one-hot rows (standard basis vectors \([0, \dots, 1, \dots, 0]\)); row \(i\) selects the embedding of the \(i\)-th input token.
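As a sketch (sizes and names are illustrative, not from the text), the one-hot product \(S \cdot E\) is equivalent to a simple row lookup into \(E\), which is how it is implemented in practice:

```python
import numpy as np

# Illustrative sizes: vocabulary of 10 tokens, embedding dimension 4.
vocab_size, d_model = 10, 4
rng = np.random.default_rng(0)
E = rng.normal(size=(vocab_size, d_model))  # embedding matrix

tokens = np.array([3, 1, 7])                # input token ids

# Explicit one-hot selection matrix S times E ...
S = np.eye(vocab_size)[tokens]
x_matmul = S @ E

# ... equals a direct row lookup, which is what real code does.
x_lookup = E[tokens]

assert np.allclose(x_matmul, x_lookup)
```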
Attention Block
\[ \text{residual} = x \]
\[ x = RMSNorm(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot g \]
where \(g\) is a learnable per-dimension scaling vector (applied elementwise) and \(\epsilon\) is a small constant that avoids division by zero.
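A minimal NumPy sketch of the RMSNorm formula above (the function name and sizes are illustrative):

```python
import numpy as np

def rms_norm(x, g, eps=1e-6):
    """Normalize each row by its root-mean-square, then scale by g."""
    rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)
    return x / rms * g

d = 8
x = np.random.default_rng(1).normal(size=(2, d))
g = np.ones(d)  # learnable scale, initialized to 1
y = rms_norm(x, g)

# After normalization each row has (approximately) unit RMS.
assert np.allclose(np.sqrt(np.mean(y**2, axis=-1)), 1.0, atol=1e-3)
```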
\[ Q = x \cdot W_q \quad K = x \cdot W_k \quad V = x \cdot W_v \]
In actual code implementations, these multiplications are fused into a single matrix multiplication for efficiency.
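The fusion works because matrix multiplication acts column-by-column: concatenating \(W_q\), \(W_k\), \(W_v\) along the output dimension yields the same projections in one matmul. A sketch (sizes are illustrative):

```python
import numpy as np

d_model, d_head = 8, 8
rng = np.random.default_rng(2)
Wq, Wk, Wv = (rng.normal(size=(d_model, d_head)) for _ in range(3))
x = rng.normal(size=(5, d_model))  # 5 tokens

# Three separate projections ...
Q, K, V = x @ Wq, x @ Wk, x @ Wv

# ... fused into a single matmul by concatenating weights column-wise,
# then splitting the result back into Q, K, V.
W_qkv = np.concatenate([Wq, Wk, Wv], axis=1)
Qf, Kf, Vf = np.split(x @ W_qkv, 3, axis=1)

assert np.allclose(Q, Qf) and np.allclose(K, Kf) and np.allclose(V, Vf)
```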
\[ Q,K = RoPE(Q,K) \]
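RoPE rotates pairs of dimensions of each query/key vector by an angle that grows with the token's position, so relative positions show up in the \(QK^T\) dot products. A sketch using the half-split pairing convention (some implementations interleave pairs instead; the base frequency 10000 is the common default):

```python
import numpy as np

def rope(x, base=10000.0):
    """Rotary position embedding: rotate dimension pairs (x1[i], x2[i])
    of the vector at position p by angle p * base**(-2i/d)."""
    seq, d = x.shape
    half = d // 2
    freqs = base ** (-np.arange(half) * 2.0 / d)   # one frequency per pair
    angles = np.outer(np.arange(seq), freqs)       # (seq, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # Standard 2D rotation applied to each pair.
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=1)

q = np.random.default_rng(3).normal(size=(4, 8))
q_rot = rope(q)

# Rotations preserve vector norms.
assert np.allclose(np.linalg.norm(q, axis=1), np.linalg.norm(q_rot, axis=1))
```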
\[ x = softmax\left(\frac{QK^T}{\sqrt{d_k}} + M \right) \cdot V \]
where \(M\) is the attention mask (a causal mask for decoder-only models): it contains \(-\infty\) in the positions strictly above the main diagonal and \(0\) elsewhere, so each position can only attend to itself and earlier positions.
In practice you never materialize the full \(QK^T\) matrix, which grows quadratically with sequence length; implementations use FlashAttention, which computes the same result in tiles without ever storing the full score matrix.
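For clarity, here is a naive reference implementation of the masked attention equation above (it does materialize the full score matrix, unlike FlashAttention; names and sizes are illustrative):

```python
import numpy as np

def causal_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k) + M) V with a causal mask M."""
    seq, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    # M: -inf strictly above the main diagonal, 0 elsewhere.
    M = np.triu(np.full((seq, seq), -np.inf), k=1)
    scores = scores + M
    # Numerically stable row-wise softmax.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(4)
Q, K, V = (rng.normal(size=(4, 8)) for _ in range(3))
out, w = causal_attention(Q, K, V)

# Each position attends only to itself and earlier positions,
# and attention weights in each row sum to 1.
assert np.allclose(np.triu(w, k=1), 0.0)
assert np.allclose(w.sum(axis=-1), 1.0)
```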
\[ x = x \cdot W_o \]
\[ x = \text{residual} + x \]
Feed-Forward Network (FFN)
\[ \text{residual} = x \]
\[ x = RMSNorm(x) \]
\[ Gate = x \cdot W_{gate} \quad Up = x \cdot W_{up} \]
\[ x = (f(Gate) \circ Up) \cdot W_{down} \]
where \(f\) is the activation function (commonly SiLU/Swish, giving the SwiGLU variant) and \(\circ\) denotes elementwise multiplication.
\[ x = \text{residual} + x \]
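The gated FFN above can be sketched as follows, assuming \(f\) is SiLU (a common but not universal choice; sizes are illustrative):

```python
import numpy as np

def silu(x):
    """SiLU/Swish activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def gated_ffn(x, W_gate, W_up, W_down):
    """(f(x W_gate) ∘ (x W_up)) W_down, i.e. the SwiGLU-style FFN."""
    return (silu(x @ W_gate) * (x @ W_up)) @ W_down

d_model, d_ff = 8, 32
rng = np.random.default_rng(5)
W_gate = rng.normal(size=(d_model, d_ff))
W_up = rng.normal(size=(d_model, d_ff))
W_down = rng.normal(size=(d_ff, d_model))
x = rng.normal(size=(4, d_model))

out = gated_ffn(x, W_gate, W_up, W_down)
assert out.shape == x.shape  # FFN maps back to the model dimension
```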
Decoder Layer
One attention block followed by one FFN forms a decoder layer; the model stacks \(N\) such layers before the output head.
Final normalization \[ x = RMSNorm(x) \]
\[ Probs = softmax(x \cdot L) \] where \(L\) is the linear classification head (its weights may or may not be tied to the embedding matrix \(E\)).
\[ Cost = CrossEntropy(Probs, Target) \]
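A sketch of the head and loss (illustrative sizes; note that real implementations compute cross-entropy directly from logits via log-softmax for numerical stability, rather than from \(Probs\)):

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean cross-entropy from logits, using a stable log-softmax."""
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

d_model, vocab = 8, 10
rng = np.random.default_rng(6)
L = rng.normal(size=(d_model, vocab))  # head; may be tied to E^T
x = rng.normal(size=(4, d_model))      # final hidden states, one per token

logits = x @ L
targets = np.array([1, 4, 0, 9])       # next-token ids to predict
loss = cross_entropy(logits, targets)
assert loss > 0
```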