Flash Attention Fixed Everything

We spent two days debugging NaN in a transformer trained from scratch in Nim with CUDA. Every fix failed. Lower learning rate — delayed the NaN. Gradient clipping — NaN appeared in the forward pass instead. Dropout, QK-Norm, weight decay, learnable gamma, different init std — nothing worked. The model would train beautifully to loss 4.6, then explode.

The fix was one CUDA kernel.


The Bug

The attention mechanism computes softmax(Q·K^T / sqrt(d)) @ V. In float32, exp(x) overflows to infinity when x > 88. Our attention computed the full score matrix, applied a causal mask, then ran softmax in a separate kernel. The softmax subtracts the max before computing exp() — this is the standard numerical stability trick and it should prevent overflow.

It didn't. Here's why.

As the model trained, it learned to sharpen attention — making some Q·K^T scores very large to focus on specific positions. After scaling by 1/sqrt(d), the largest scores approached 80-90. The softmax handled these fine (max subtraction works). But the gradients through the attention became large, slowly pushing the projection weights further. Eventually some weight grew enough that a specific Q·K^T product exceeded 88 after scaling. exp(89) = inf. inf × V = inf. The inf propagated through the residual stream, corrupted the RMSNorm, and the entire forward pass produced NaN.

We traced it precisely:

NaN at step 1558
  INF: L1.attnOut[7424]=4.221511e+34
  NaN: L1.xInput2[7424]
  NaN: L1.xNorm2[7424]
  NaN: L1.fc1Out[29696]
  ...everything downstream...

The first infinity appeared in the attention output of layer 1. All Q, K, V values were finite. The infinity was born inside the softmax(scores) @ V computation.

We tried clamping the scores before softmax. It prevented the forward NaN but introduced a gradient inconsistency — the backward didn't know about the clamp, producing wrong gradients that slowly corrupted the weights until they themselves became NaN.


The Fix

Flash attention computes the entire attention operation — score, scale, mask, softmax, weighted sum — in a single fused kernel using online softmax. The key insight: instead of computing all scores, finding the max, then computing exp(score - max), it processes one key-value pair at a time and maintains a running maximum.

m = -inf          # running max score
l = 0             # running sum of exp
acc = 0           # running output (renormalized each step)

for each key-value pair (k, v):
    score = q · k * scale
    m_new = max(m, score)
    exp_old = exp(m - m_new)      # rescale old accumulator
    exp_new = exp(score - m_new)  # exponent ≤ 0, never overflows
    l_new = l * exp_old + exp_new
    acc = acc * (l * exp_old / l_new) + v * (exp_new / l_new)
    m = m_new
    l = l_new

The critical line: exp(score - m_new). Since m_new is the maximum score seen so far, score - m_new ≤ 0. The exponent is always non-positive. exp(0) = 1. exp(-anything) < 1. Overflow is mathematically impossible.

The full kernel is 40 lines of CUDA:

// One thread computes one output element O[row][j].
// Launch: grid = S blocks (one per query row), block = hd threads,
// which assumes hd <= 1024, the per-block thread limit.
__global__ void k_flash_attn_fwd(
    const float *Q, const float *K, const float *V, float *O,
    int S, int hd, float scale, int causal
) {
    int row = blockIdx.x;    // query position
    int j = threadIdx.x;     // output dimension
    if (row >= S || j >= hd) return;

    const float *q = Q + row * hd;
    float m = -FLT_MAX;      // running max score
    float l = 0.0f;          // running sum of exp
    float acc = 0.0f;        // running (renormalized) output
    int max_col = causal ? (row + 1) : S;

    for (int col = 0; col < max_col; col++) {
        // Score for this key: q . k, scaled by 1/sqrt(hd).
        float score = 0.0f;
        for (int d = 0; d < hd; d++)
            score += q[d] * K[col * hd + d];
        score *= scale;

        // Online softmax update: score - m_new <= 0, so expf
        // can never overflow.
        float m_new = fmaxf(m, score);
        float exp_old = expf(m - m_new);
        float exp_new = expf(score - m_new);
        float l_new = l * exp_old + exp_new;

        // Rescale the old accumulator, fold in the new value.
        acc = acc * (l * exp_old / l_new)
            + V[col * hd + j] * (exp_new / l_new);
        m = m_new;
        l = l_new;
    }
    O[row * hd + j] = acc;
}

The Result

Before flash attention: NaN at every learning rate above 0.00003 after 1000-2000 steps. Loss plateaued at 7.0 because the learning rate was too low to make progress.

After flash attention: stable training at lr=0.0003 (10x higher). Loss dropped from 7.7 to 4.6 in 8000 steps with zero NaN. The model is still training.

opt  1000 | loss 7.11 | lr 0.000150
opt  2000 | loss 6.78 | lr 0.000300
opt  4000 | loss 5.00 | lr 0.000300
opt  8000 | loss 4.62 | lr 0.000299

28 million parameters. 8 layers. 512 dimensions. Training from scratch on 37,000 conversations. Written in Nim. Direct CUDA. No PyTorch.


The Lesson

We spent two days trying every hyperparameter trick: lower learning rate, gradient clipping, dropout, weight decay, QK-Norm, different init std, different beta values, gradient accumulation. None of it mattered because the problem was architectural — the standard separate-kernel attention computation is numerically fragile in float32.

Flash attention isn't just more efficient (it is — O(n) memory for the scores instead of O(n²)). It's numerically sound. The online softmax makes overflow impossible by construction. Every serious training framework uses it or something equivalent. We should have implemented it first instead of last.

The code is 40 lines of CUDA. The bug it fixed took 2000 lines of debugging to find.


See also: NimLLM, 103M Parameters on a 3060.

Co-authored with Claude.