The softmax function is ubiquitous in neural networks, especially in transformers, where it plays a central role in the attention mechanism. The mathematical definition is simple: \[
\operatorname{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}
\]
However, the naive implementation has severe numerical stability issues.
The problem: numerical overflow and underflow
Consider what happens when we try to compute softmax on large values:
```python
import numpy as np

def naive_softmax(x):
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x)

# This causes overflow
large_values = np.array([1000, 1001, 1002])
print(f"exp(1000) = {np.exp(1000)}")                          # inf
print(f"naive_softmax result: {naive_softmax(large_values)}")  # nan, nan, nan
```
exp(1000) = inf
naive_softmax result: [nan nan nan]
RuntimeWarning: overflow encountered in exp
RuntimeWarning: invalid value encountered in divide
The exponential function grows extremely fast. For arguments around 1000, exp(x) exceeds the largest representable floating-point number and becomes infinity. The division then evaluates inf/inf, which is nan.
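You can locate the overflow threshold directly: float64 tops out near 1.8 × 10³⁰⁸, so exp overflows once its argument exceeds roughly 709. A quick check:

```python
import numpy as np

# the largest representable float64, and the argument at which exp reaches it
print(np.finfo(np.float64).max)          # ~1.798e308
print(np.log(np.finfo(np.float64).max))  # ~709.78

print(np.isinf(np.exp(709.0)))  # False: still representable
with np.errstate(over="ignore"):
    print(np.isinf(np.exp(np.float64(710.0))))  # True: overflows to inf
```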
The opposite problem occurs with very negative values:
```python
# This causes underflow
small_values = np.array([-1000, -1001, -1002])
print(f"exp(-1000) = {np.exp(-1000)}")                         # 0.0
print(f"naive_softmax result: {naive_softmax(small_values)}")  # nan, nan, nan
```
exp(-1000) = 0.0
naive_softmax result: [nan nan nan]
RuntimeWarning: invalid value encountered in divide
Here every exp(x) underflows to zero, and we get 0/0 = nan.
The constant trick
The solution exploits a mathematical identity: multiplying both the numerator and the denominator by \(e^{-c}\), for any constant \(c\), leaves the result unchanged: \[
\frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{-c}\,e^{x_i}}{e^{-c}\sum_j e^{x_j}} = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
\]
So we can calculate softmax as: \[
\operatorname{softmax}(x_i) = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
\]
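A quick numerical check of this identity on moderate values, where the naive form is still well-behaved:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
c = np.max(x)

# naive form and shifted form give the same probabilities
plain = np.exp(x) / np.sum(np.exp(x))
shifted = np.exp(x - c) / np.sum(np.exp(x - c))

print(np.allclose(plain, shifted))  # True
```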
Choosing \(c = \max(x)\) ensures the largest exponentiated value is exp(0) = 1, preventing overflow. Any values that end up small enough to underflow to zero after subtraction represent genuinely negligible probabilities, so the result remains correct.
```python
def stable_softmax(x):
    c = np.max(x)
    exp_shifted = np.exp(x - c)
    return exp_shifted / np.sum(exp_shifted)

# Now both cases work perfectly
print(f"Large values: {stable_softmax(large_values)}")
print(f"Small values: {stable_softmax(small_values)}")
```
Large values: [0.09003057 0.24472847 0.66524096]
Small values: [0.66524096 0.24472847 0.09003057]
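In practice you rarely need to write this yourself: SciPy exposes the same max-shifted computation as `scipy.special.softmax` (assuming SciPy is installed):

```python
import numpy as np
from scipy.special import softmax

# handles the same extreme inputs without overflow
print(softmax(np.array([1000.0, 1001.0, 1002.0])))
# [0.09003057 0.24472847 0.66524096]
```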
The actual PyTorch implementation
You can see the source code for the softmax implementation in PyTorch. The core logic is essentially the same as described above, but it’s implemented in C++ for performance reasons. The code iterates over the input data, computes the maximum value, and then applies the stable softmax formula.
Here’s a simplified, easy-to-understand version of the core logic in C++:
```cpp
void host_softmax(Tensor output, const Tensor& input, const int64_t dim) {
  // initialize max_input to the first element of the input tensor
  scalar_t max_input = input_data[0];
  // iterate over the input data to find the maximum value
  for (int64_t d = 1; d < dim_size; d++)
    max_input = std::max(max_input, input_data[d * dim_stride]);
  // initialize tmpsum as zero to accumulate the sum of exponentials
  acc_type<scalar_t, false> tmpsum = 0;
  for (int64_t d = 0; d < dim_size; d++) {
    // compute the exponential of the shifted input value
    scalar_t z = std::exp(input_data[d * dim_stride] - max_input);
    // store the result in the output tensor to serve as the
    // numerator of the softmax formula
    output_data[d * dim_stride] = z;
    // accumulate the sum of the exponentials to serve as the
    // denominator of the softmax formula
    tmpsum += z;
  }
  // we need to divide by the sum of exponentials,
  // so we compute the reciprocal of tmpsum
  tmpsum = 1 / tmpsum;
  // finally, we multiply each element in the output tensor by
  // tmpsum to get the final softmax values
  for (int64_t d = 0; d < dim_size; d++)
    output_data[d * dim_stride] *= tmpsum;
}
```
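To make the loop structure concrete, here is a line-by-line Python transcription of that sketch (illustrative only; the real kernel is templated over dtypes and operates on strided memory):

```python
import numpy as np

def host_softmax_py(x):
    """Python sketch of the three-pass loop structure described above."""
    max_input = x[0]
    for v in x[1:]:               # pass 1: find the maximum
        max_input = max(max_input, v)
    out = np.empty_like(x, dtype=np.float64)
    tmpsum = 0.0
    for d in range(len(x)):       # pass 2: exponentiate and accumulate
        z = np.exp(x[d] - max_input)
        out[d] = z
        tmpsum += z
    tmpsum = 1 / tmpsum           # reciprocal of the denominator
    for d in range(len(x)):       # pass 3: normalize in place
        out[d] *= tmpsum
    return out

print(host_softmax_py(np.array([1000.0, 1001.0, 1002.0])))
```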
Why this matters for transformers
The stable softmax implementation ensures that the computation remains numerically sound regardless of the magnitude of the attention scores. This is critical for training stability: a single NaN or a large rounding error in the attention matrix can propagate and corrupt the entire model.
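As a minimal sketch with NumPy, here is a batched stable softmax applied to a row of hypothetical attention logits containing an extreme outlier (the values are made up for illustration):

```python
import numpy as np

def stable_softmax(x, axis=-1):
    # subtract the row-wise max before exponentiating, as described above
    x_max = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - x_max)
    return e / np.sum(e, axis=axis, keepdims=True)

# hypothetical attention scores with one very large logit
scores = np.array([[5.0, 1200.0, -3.0]])
weights = stable_softmax(scores)
print(np.isnan(weights).any())  # False: no NaNs despite the outlier
print(weights.sum())            # 1.0
```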