Part 5: Training

We've built a working transformer. It can predict the next token, generate text, and learn patterns from data. But it's not a ChatGPT yet. It processes one sequence at a time and knows nothing about conversations.

This part fixes all of that. We'll implement batch training, learn how pre-training differs from fine-tuning, teach the model to chat using SFT, and add the optimization tricks that make training actually work at scale.

By the end, you'll have a model that responds to prompts like a chatbot, not just a text completion engine.

Training with Batches

The training loop from Part 4 processes one sequence at a time:

for (let start = 0; start < tokens.length - sequenceLength - 1; start += 8) {
    const inputTokens = tokens.slice(start, start + sequenceLength);
    const targetTokens = tokens.slice(start + 1, start + sequenceLength + 1);
    const loss = model.train(inputTokens, targetTokens, learningRate);
}

This works, but it's inefficient and produces noisy gradients. A single example might push weights in a weird direction; averaging across many examples gives a cleaner signal. Batch training processes multiple sequences simultaneously. Instead of shape [seqLength], the inputs become [batchSize][seqLength]. Every layer must handle this extra dimension.

Single sequence:     [16]              -> "The cat sat on..."
Batch of 4:          [4][16]           -> 4 different sequences at once
After embedding:     [4][16][64]       -> 4 sequences, 16 positions, 64-dim embeddings

Batched Embedding Layer

The embedding layer needs minimal changes—we just add an outer loop over the batch:

class EmbeddingLayer {
    weights = null;
    vocabSize = 0;
    embeddingDim = 0;
    cachedInputTokens = null;

    constructor(vocabSize, embeddingDim) {
        this.vocabSize = vocabSize;
        this.embeddingDim = embeddingDim;

        const scale = Math.sqrt(1.0 / embeddingDim);
        this.weights = new Array(vocabSize);
        for (let i = 0; i < vocabSize; i++) {
            this.weights[i] = new Array(embeddingDim);
            for (let j = 0; j < embeddingDim; j++) {
                this.weights[i][j] = (Math.random() * 2 - 1) * scale;
            }
        }
    }

    forward(inputTokens) {
        // inputTokens: [batchSize][seqLength]
        const batchSize = inputTokens.length;
        const seqLength = inputTokens[0].length;

        this.cachedInputTokens = inputTokens;

        // Output: [batchSize][seqLength][embeddingDim]
        const output = new Array(batchSize);
        for (let b = 0; b < batchSize; b++) {
            output[b] = new Array(seqLength);
            for (let t = 0; t < seqLength; t++) {
                const tokenId = inputTokens[b][t];
                output[b][t] = new Array(this.embeddingDim);
                for (let d = 0; d < this.embeddingDim; d++) {
                    output[b][t][d] = this.weights[tokenId][d];
                }
            }
        }

        return output;
    }

    backward(outputGradients, learningRate) {
        // outputGradients: [batchSize][seqLength][embeddingDim]
        const batchSize = this.cachedInputTokens.length;
        const seqLength = this.cachedInputTokens[0].length;

        for (let b = 0; b < batchSize; b++) {
            for (let t = 0; t < seqLength; t++) {
                const tokenId = this.cachedInputTokens[b][t];
                for (let d = 0; d < this.embeddingDim; d++) {
                    this.weights[tokenId][d] -= learningRate * outputGradients[b][t][d];
                }
            }
        }

        return null;
    }
}

The forward pass looks up embeddings for each token in each sequence. The backward pass accumulates gradients—if the same token appears in multiple sequences (or multiple positions), its embedding gets updated multiple times.

Batched Positional Embedding Layer

Same pattern, wrap everything in a batch loop:

class PositionalEmbeddingLayer {
    tokenEmbedding = null;
    positionWeights = null;
    maxSequenceLength = 0;
    embeddingDim = 0;
    cachedBatchSize = 0;
    cachedSeqLength = 0;

    constructor(vocabSize, embeddingDim, maxSequenceLength) {
        this.embeddingDim = embeddingDim;
        this.maxSequenceLength = maxSequenceLength;

        this.tokenEmbedding = new EmbeddingLayer(vocabSize, embeddingDim);

        const scale = Math.sqrt(1.0 / embeddingDim);
        this.positionWeights = new Array(maxSequenceLength);
        for (let pos = 0; pos < maxSequenceLength; pos++) {
            this.positionWeights[pos] = new Array(embeddingDim);
            for (let d = 0; d < embeddingDim; d++) {
                this.positionWeights[pos][d] = (Math.random() * 2 - 1) * scale;
            }
        }
    }

    forward(inputTokens) {
        // inputTokens: [batchSize][seqLength]
        const tokenEmbeddings = this.tokenEmbedding.forward(inputTokens);

        const batchSize = inputTokens.length;
        const seqLength = inputTokens[0].length;
        this.cachedBatchSize = batchSize;
        this.cachedSeqLength = seqLength;

        // Add position embeddings (same positions for all sequences in batch)
        const output = new Array(batchSize);
        for (let b = 0; b < batchSize; b++) {
            output[b] = new Array(seqLength);
            for (let t = 0; t < seqLength; t++) {
                output[b][t] = new Array(this.embeddingDim);
                for (let d = 0; d < this.embeddingDim; d++) {
                    output[b][t][d] = tokenEmbeddings[b][t][d] + this.positionWeights[t][d];
                }
            }
        }

        return output;
    }

    backward(outputGradients, learningRate) {
        // Position gradients: sum across batch (same position used by all sequences)
        for (let t = 0; t < this.cachedSeqLength; t++) {
            for (let d = 0; d < this.embeddingDim; d++) {
                let grad = 0;
                for (let b = 0; b < this.cachedBatchSize; b++) {
                    grad += outputGradients[b][t][d];
                }
                this.positionWeights[t][d] -= learningRate * grad;
            }
        }

        this.tokenEmbedding.backward(outputGradients, learningRate);
        return null;
    }
}

Notice how position gradients sum across the batch. Position 0's embedding is used by every sequence, so its gradient accumulates from all of them.

Batched Layer Normalization

Layer normalization computes statistics per position, per sequence. With batches, we normalize each [embeddingDim] vector independently:

class LayerNormalization {
    gamma = null;
    beta = null;
    featureSize = 0;
    epsilon = 1e-5;

    cachedInputs = null;
    cachedMean = null;
    cachedVariance = null;
    cachedNormalized = null;

    constructor(featureSize) {
        this.featureSize = featureSize;

        this.gamma = new Array(featureSize);
        this.beta = new Array(featureSize);
        for (let i = 0; i < featureSize; i++) {
            this.gamma[i] = 1.0;
            this.beta[i] = 0.0;
        }
    }

    forward(inputs) {
        // inputs: [batchSize][seqLength][featureSize]
        const batchSize = inputs.length;
        const seqLength = inputs[0].length;

        this.cachedInputs = inputs;
        this.cachedMean = new Array(batchSize);
        this.cachedVariance = new Array(batchSize);
        this.cachedNormalized = new Array(batchSize);

        const output = new Array(batchSize);

        for (let b = 0; b < batchSize; b++) {
            this.cachedMean[b] = new Array(seqLength);
            this.cachedVariance[b] = new Array(seqLength);
            this.cachedNormalized[b] = new Array(seqLength);
            output[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                // Compute mean
                let mean = 0;
                for (let i = 0; i < this.featureSize; i++) {
                    mean += inputs[b][t][i];
                }
                mean /= this.featureSize;
                this.cachedMean[b][t] = mean;

                // Compute variance
                let variance = 0;
                for (let i = 0; i < this.featureSize; i++) {
                    const diff = inputs[b][t][i] - mean;
                    variance += diff * diff;
                }
                variance /= this.featureSize;
                this.cachedVariance[b][t] = variance;

                // Normalize and apply gamma/beta
                const stdInv = 1.0 / Math.sqrt(variance + this.epsilon);
                this.cachedNormalized[b][t] = new Array(this.featureSize);
                output[b][t] = new Array(this.featureSize);

                for (let i = 0; i < this.featureSize; i++) {
                    const normalized = (inputs[b][t][i] - mean) * stdInv;
                    this.cachedNormalized[b][t][i] = normalized;
                    output[b][t][i] = this.gamma[i] * normalized + this.beta[i];
                }
            }
        }

        return output;
    }

    backward(outputGradients, learningRate) {
        const batchSize = outputGradients.length;
        const seqLength = outputGradients[0].length;

        const inputGradients = new Array(batchSize);

        // Accumulate gamma/beta gradients across batch and sequence
        const gammaGrad = new Array(this.featureSize).fill(0);
        const betaGrad = new Array(this.featureSize).fill(0);

        for (let b = 0; b < batchSize; b++) {
            inputGradients[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                const mean = this.cachedMean[b][t];
                const variance = this.cachedVariance[b][t];
                const stdInv = 1.0 / Math.sqrt(variance + this.epsilon);

                // Accumulate parameter gradients
                for (let i = 0; i < this.featureSize; i++) {
                    gammaGrad[i] += outputGradients[b][t][i] * this.cachedNormalized[b][t][i];
                    betaGrad[i] += outputGradients[b][t][i];
                }

                // Compute input gradients
                inputGradients[b][t] = new Array(this.featureSize);

                // Gradient through normalization
                let dNormSum = 0;
                let dVarSum = 0;

                for (let i = 0; i < this.featureSize; i++) {
                    const dNorm = outputGradients[b][t][i] * this.gamma[i];
                    dNormSum += dNorm;
                    dVarSum += dNorm * (this.cachedInputs[b][t][i] - mean);
                }

                const dVar = dVarSum * -0.5 * Math.pow(variance + this.epsilon, -1.5);
                const dMean = -stdInv * dNormSum + dVar * -2.0 / this.featureSize *
                    (this.cachedInputs[b][t].reduce((a, v) => a + v - mean, 0));

                for (let i = 0; i < this.featureSize; i++) {
                    const dNorm = outputGradients[b][t][i] * this.gamma[i];
                    inputGradients[b][t][i] = dNorm * stdInv +
                        dVar * 2.0 * (this.cachedInputs[b][t][i] - mean) / this.featureSize +
                        dMean / this.featureSize;
                }
            }
        }

        // Update parameters
        for (let i = 0; i < this.featureSize; i++) {
            this.gamma[i] -= learningRate * gammaGrad[i];
            this.beta[i] -= learningRate * betaGrad[i];
        }

        return inputGradients;
    }
}

Gamma and beta are shared across all positions and all sequences, so their gradients accumulate from everywhere.

Batched Attention

Attention is where batching pays off most. Each sequence in the batch attends independently; there's no cross-sequence attention. We just run the same attention computation for each batch element:

class MultiHeadAttention {
    numHeads = 0;
    headDim = 0;
    embeddingDim = 0;

    queryWeights = null;
    keyWeights = null;
    valueWeights = null;
    outputWeights = null;

    // Cached for backward pass
    cachedInputs = null;
    cachedQueries = null;
    cachedKeys = null;
    cachedValues = null;
    cachedAttentionWeights = null;
    cachedAttentionOutputs = null;

    constructor(embeddingDim, numHeads) {
        this.embeddingDim = embeddingDim;
        this.numHeads = numHeads;
        this.headDim = embeddingDim / numHeads;

        const scale = Math.sqrt(2.0 / embeddingDim);

        this.queryWeights = this.#initializeWeights(embeddingDim, embeddingDim, scale);
        this.keyWeights = this.#initializeWeights(embeddingDim, embeddingDim, scale);
        this.valueWeights = this.#initializeWeights(embeddingDim, embeddingDim, scale);
        this.outputWeights = this.#initializeWeights(embeddingDim, embeddingDim, scale);
    }

    #initializeWeights(inputSize, outputSize, scale) {
        const weights = new Array(inputSize);
        for (let i = 0; i < inputSize; i++) {
            weights[i] = new Array(outputSize);
            for (let j = 0; j < outputSize; j++) {
                weights[i][j] = (Math.random() * 2 - 1) * scale;
            }
        }
        return weights;
    }

    forward(inputs) {
        // inputs: [batchSize][seqLength][embeddingDim]
        const batchSize = inputs.length;
        const seqLength = inputs[0].length;

        this.cachedInputs = inputs;

        // Project to Q, K, V for entire batch
        const queries = this.#batchMatmul(inputs, this.queryWeights);
        const keys = this.#batchMatmul(inputs, this.keyWeights);
        const values = this.#batchMatmul(inputs, this.valueWeights);

        this.cachedQueries = queries;
        this.cachedKeys = keys;
        this.cachedValues = values;

        // Compute attention for each batch element
        const scale = 1.0 / Math.sqrt(this.headDim);
        this.cachedAttentionWeights = new Array(batchSize);
        const attentionOutputs = new Array(batchSize);

        for (let b = 0; b < batchSize; b++) {
            this.cachedAttentionWeights[b] = new Array(this.numHeads);
            attentionOutputs[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                attentionOutputs[b][t] = new Array(this.embeddingDim).fill(0);
            }

            // Process each head
            for (let h = 0; h < this.numHeads; h++) {
                const headStart = h * this.headDim;
                this.cachedAttentionWeights[b][h] = new Array(seqLength);

                for (let t = 0; t < seqLength; t++) {
                    // Compute attention scores for position t
                    const scores = new Array(seqLength);

                    for (let s = 0; s <= t; s++) {  // Causal: only attend to past
                        let score = 0;
                        for (let d = 0; d < this.headDim; d++) {
                            score += queries[b][t][headStart + d] * keys[b][s][headStart + d];
                        }
                        scores[s] = score * scale;
                    }

                    // Mask future positions
                    for (let s = t + 1; s < seqLength; s++) {
                        scores[s] = -Infinity;
                    }

                    // Softmax
                    const maxScore = Math.max(...scores.filter(s => s !== -Infinity));
                    let sumExp = 0;
                    const expScores = new Array(seqLength);
                    for (let s = 0; s < seqLength; s++) {
                        expScores[s] = scores[s] === -Infinity ? 0 : Math.exp(scores[s] - maxScore);
                        sumExp += expScores[s];
                    }

                    const attentionWeights = new Array(seqLength);
                    for (let s = 0; s < seqLength; s++) {
                        attentionWeights[s] = expScores[s] / sumExp;
                    }
                    this.cachedAttentionWeights[b][h][t] = attentionWeights;

                    // Apply attention to values
                    for (let s = 0; s < seqLength; s++) {
                        for (let d = 0; d < this.headDim; d++) {
                            attentionOutputs[b][t][headStart + d] +=
                                attentionWeights[s] * values[b][s][headStart + d];
                        }
                    }
                }
            }
        }

        // Output projection (cache the attention output: it's the input to this matmul in backward)
        this.cachedAttentionOutputs = attentionOutputs;
        const output = this.#batchMatmul(attentionOutputs, this.outputWeights);
        return output;
    }

    #batchMatmul(inputs, weights) {
        // inputs: [batchSize][seqLength][inputDim]
        // weights: [inputDim][outputDim]
        // output: [batchSize][seqLength][outputDim]
        const batchSize = inputs.length;
        const seqLength = inputs[0].length;
        const inputDim = weights.length;
        const outputDim = weights[0].length;

        const output = new Array(batchSize);
        for (let b = 0; b < batchSize; b++) {
            output[b] = new Array(seqLength);
            for (let t = 0; t < seqLength; t++) {
                output[b][t] = new Array(outputDim).fill(0);
                for (let i = 0; i < inputDim; i++) {
                    for (let j = 0; j < outputDim; j++) {
                        output[b][t][j] += inputs[b][t][i] * weights[i][j];
                    }
                }
            }
        }
        return output;
    }

    backward(outputGradients, learningRate) {
        const batchSize = outputGradients.length;
        const seqLength = outputGradients[0].length;

        // Gradient through output projection (its input was the attention output, not the block input)
        const attentionGradients = this.#batchMatmulBackward(
            outputGradients, this.cachedAttentionOutputs, this.outputWeights, learningRate, 'output'
        );

        // Gradient through attention mechanism
        const queryGrads = new Array(batchSize);
        const keyGrads = new Array(batchSize);
        const valueGrads = new Array(batchSize);

        for (let b = 0; b < batchSize; b++) {
            queryGrads[b] = new Array(seqLength);
            keyGrads[b] = new Array(seqLength);
            valueGrads[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                queryGrads[b][t] = new Array(this.embeddingDim).fill(0);
                keyGrads[b][t] = new Array(this.embeddingDim).fill(0);
                valueGrads[b][t] = new Array(this.embeddingDim).fill(0);
            }

            const scale = 1.0 / Math.sqrt(this.headDim);

            for (let h = 0; h < this.numHeads; h++) {
                const headStart = h * this.headDim;

                for (let t = 0; t < seqLength; t++) {
                    const attnWeights = this.cachedAttentionWeights[b][h][t];

                    // Gradient w.r.t. values
                    for (let s = 0; s <= t; s++) {
                        for (let d = 0; d < this.headDim; d++) {
                            valueGrads[b][s][headStart + d] +=
                                attnWeights[s] * attentionGradients[b][t][headStart + d];
                        }
                    }

                    // Gradient w.r.t. attention weights
                    const dAttnWeights = new Array(seqLength).fill(0);
                    for (let s = 0; s <= t; s++) {
                        for (let d = 0; d < this.headDim; d++) {
                            dAttnWeights[s] += attentionGradients[b][t][headStart + d] *
                                this.cachedValues[b][s][headStart + d];
                        }
                    }

                    // Gradient through softmax
                    let dotProduct = 0;
                    for (let s = 0; s <= t; s++) {
                        dotProduct += dAttnWeights[s] * attnWeights[s];
                    }

                    const dScores = new Array(seqLength).fill(0);
                    for (let s = 0; s <= t; s++) {
                        dScores[s] = attnWeights[s] * (dAttnWeights[s] - dotProduct) * scale;
                    }

                    // Gradient w.r.t. queries and keys
                    for (let s = 0; s <= t; s++) {
                        for (let d = 0; d < this.headDim; d++) {
                            queryGrads[b][t][headStart + d] +=
                                dScores[s] * this.cachedKeys[b][s][headStart + d];
                            keyGrads[b][s][headStart + d] +=
                                dScores[s] * this.cachedQueries[b][t][headStart + d];
                        }
                    }
                }
            }
        }

        // Gradient through Q, K, V projections
        const inputGrads1 = this.#batchMatmulBackward(
            queryGrads, this.cachedInputs, this.queryWeights, learningRate, 'query'
        );
        const inputGrads2 = this.#batchMatmulBackward(
            keyGrads, this.cachedInputs, this.keyWeights, learningRate, 'key'
        );
        const inputGrads3 = this.#batchMatmulBackward(
            valueGrads, this.cachedInputs, this.valueWeights, learningRate, 'value'
        );

        // Sum input gradients from all three projections
        const inputGradients = new Array(batchSize);
        for (let b = 0; b < batchSize; b++) {
            inputGradients[b] = new Array(seqLength);
            for (let t = 0; t < seqLength; t++) {
                inputGradients[b][t] = new Array(this.embeddingDim);
                for (let d = 0; d < this.embeddingDim; d++) {
                    inputGradients[b][t][d] = inputGrads1[b][t][d] +
                        inputGrads2[b][t][d] + inputGrads3[b][t][d];
                }
            }
        }

        return inputGradients;
    }

    #batchMatmulBackward(outputGrads, inputs, weights, learningRate, which) {
        const batchSize = outputGrads.length;
        const seqLength = outputGrads[0].length;
        const inputDim = weights.length;
        const outputDim = weights[0].length;

        // Gradient w.r.t. inputs
        const inputGrads = new Array(batchSize);
        for (let b = 0; b < batchSize; b++) {
            inputGrads[b] = new Array(seqLength);
            for (let t = 0; t < seqLength; t++) {
                inputGrads[b][t] = new Array(inputDim).fill(0);
                for (let i = 0; i < inputDim; i++) {
                    for (let j = 0; j < outputDim; j++) {
                        inputGrads[b][t][i] += outputGrads[b][t][j] * weights[i][j];
                    }
                }
            }
        }

        // Gradient w.r.t. weights (accumulate across batch)
        for (let i = 0; i < inputDim; i++) {
            for (let j = 0; j < outputDim; j++) {
                let grad = 0;
                for (let b = 0; b < batchSize; b++) {
                    for (let t = 0; t < seqLength; t++) {
                        grad += inputs[b][t][i] * outputGrads[b][t][j];
                    }
                }
                weights[i][j] -= learningRate * grad;
            }
        }

        return inputGrads;
    }
}

The attention mechanism is fundamentally the same. Each position attends to all previous positions (causal masking). We just do it independently for each sequence in the batch.

Batched MLP

The MLP block applies the same feed-forward network to each position independently:

class MLPBlock {
    weights1 = null;
    bias1 = null;
    weights2 = null;
    bias2 = null;

    inputDim = 0;
    hiddenDim = 0;

    cachedInputs = null;
    cachedHidden = null;

    constructor(embeddingDim) {
        this.inputDim = embeddingDim;
        this.hiddenDim = embeddingDim * 4;

        const scale1 = Math.sqrt(2.0 / embeddingDim);
        const scale2 = Math.sqrt(2.0 / this.hiddenDim);

        this.weights1 = this.#initWeights(embeddingDim, this.hiddenDim, scale1);
        this.bias1 = new Array(this.hiddenDim).fill(0);
        this.weights2 = this.#initWeights(this.hiddenDim, embeddingDim, scale2);
        this.bias2 = new Array(embeddingDim).fill(0);
    }

    #initWeights(inputSize, outputSize, scale) {
        const weights = new Array(inputSize);
        for (let i = 0; i < inputSize; i++) {
            weights[i] = new Array(outputSize);
            for (let j = 0; j < outputSize; j++) {
                weights[i][j] = (Math.random() * 2 - 1) * scale;
            }
        }
        return weights;
    }

    forward(inputs) {
        // inputs: [batchSize][seqLength][embeddingDim]
        const batchSize = inputs.length;
        const seqLength = inputs[0].length;

        this.cachedInputs = inputs;
        this.cachedHidden = new Array(batchSize);
        const output = new Array(batchSize);

        for (let b = 0; b < batchSize; b++) {
            this.cachedHidden[b] = new Array(seqLength);
            output[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                // First linear layer
                const hidden = new Array(this.hiddenDim);
                for (let h = 0; h < this.hiddenDim; h++) {
                    let sum = this.bias1[h];
                    for (let i = 0; i < this.inputDim; i++) {
                        sum += inputs[b][t][i] * this.weights1[i][h];
                    }
                    // GELU activation
                    hidden[h] = sum * 0.5 * (1 + Math.tanh(
                        Math.sqrt(2 / Math.PI) * (sum + 0.044715 * sum * sum * sum)
                    ));
                }
                this.cachedHidden[b][t] = hidden;

                // Second linear layer
                output[b][t] = new Array(this.inputDim);
                for (let o = 0; o < this.inputDim; o++) {
                    let sum = this.bias2[o];
                    for (let h = 0; h < this.hiddenDim; h++) {
                        sum += hidden[h] * this.weights2[h][o];
                    }
                    output[b][t][o] = sum;
                }
            }
        }

        return output;
    }

    backward(outputGradients, learningRate) {
        const batchSize = outputGradients.length;
        const seqLength = outputGradients[0].length;

        const inputGradients = new Array(batchSize);

        // Accumulate weight/bias gradients
        const weights2Grad = new Array(this.hiddenDim);
        for (let h = 0; h < this.hiddenDim; h++) {
            weights2Grad[h] = new Array(this.inputDim).fill(0);
        }
        const bias2Grad = new Array(this.inputDim).fill(0);

        const weights1Grad = new Array(this.inputDim);
        for (let i = 0; i < this.inputDim; i++) {
            weights1Grad[i] = new Array(this.hiddenDim).fill(0);
        }
        const bias1Grad = new Array(this.hiddenDim).fill(0);

        for (let b = 0; b < batchSize; b++) {
            inputGradients[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                // Gradient through second linear layer
                const hiddenGrad = new Array(this.hiddenDim).fill(0);
                for (let h = 0; h < this.hiddenDim; h++) {
                    for (let o = 0; o < this.inputDim; o++) {
                        hiddenGrad[h] += outputGradients[b][t][o] * this.weights2[h][o];
                        weights2Grad[h][o] += this.cachedHidden[b][t][h] * outputGradients[b][t][o];
                    }
                }
                for (let o = 0; o < this.inputDim; o++) {
                    bias2Grad[o] += outputGradients[b][t][o];
                }

                // Gradient through GELU
                const preGelu = new Array(this.hiddenDim);
                for (let h = 0; h < this.hiddenDim; h++) {
                    // Recompute pre-activation
                    let sum = this.bias1[h];
                    for (let i = 0; i < this.inputDim; i++) {
                        sum += this.cachedInputs[b][t][i] * this.weights1[i][h];
                    }
                    preGelu[h] = sum;
                }

                const geluGrad = new Array(this.hiddenDim);
                for (let h = 0; h < this.hiddenDim; h++) {
                    const x = preGelu[h];
                    const cdf = 0.5 * (1 + Math.tanh(
                        Math.sqrt(2 / Math.PI) * (x + 0.044715 * x * x * x)
                    ));
                    const pdf = Math.exp(-0.5 * x * x) / Math.sqrt(2 * Math.PI);
                    geluGrad[h] = hiddenGrad[h] * (cdf + x * pdf);
                }

                // Gradient through first linear layer
                inputGradients[b][t] = new Array(this.inputDim).fill(0);
                for (let i = 0; i < this.inputDim; i++) {
                    for (let h = 0; h < this.hiddenDim; h++) {
                        inputGradients[b][t][i] += geluGrad[h] * this.weights1[i][h];
                        weights1Grad[i][h] += this.cachedInputs[b][t][i] * geluGrad[h];
                    }
                }
                for (let h = 0; h < this.hiddenDim; h++) {
                    bias1Grad[h] += geluGrad[h];
                }
            }
        }

        // Update weights
        for (let h = 0; h < this.hiddenDim; h++) {
            for (let o = 0; o < this.inputDim; o++) {
                this.weights2[h][o] -= learningRate * weights2Grad[h][o];
            }
        }
        for (let o = 0; o < this.inputDim; o++) {
            this.bias2[o] -= learningRate * bias2Grad[o];
        }
        for (let i = 0; i < this.inputDim; i++) {
            for (let h = 0; h < this.hiddenDim; h++) {
                this.weights1[i][h] -= learningRate * weights1Grad[i][h];
            }
        }
        for (let h = 0; h < this.hiddenDim; h++) {
            this.bias1[h] -= learningRate * bias1Grad[h];
        }

        return inputGradients;
    }
}

Batched Transformer Block

The transformer block combines attention, MLP, layer norms, and residual connections:

class TransformerBlock {
    attention = null;
    mlp = null;
    layerNorm1 = null;
    layerNorm2 = null;
    embeddingDim = 0;

    constructor(embeddingDim, numHeads) {
        this.embeddingDim = embeddingDim;
        this.attention = new MultiHeadAttention(embeddingDim, numHeads);
        this.mlp = new MLPBlock(embeddingDim);
        this.layerNorm1 = new LayerNormalization(embeddingDim);
        this.layerNorm2 = new LayerNormalization(embeddingDim);
    }

    forward(inputs) {
        // inputs: [batchSize][seqLength][embeddingDim]
        const batchSize = inputs.length;
        const seqLength = inputs[0].length;

        // Pre-norm architecture: norm -> attention -> residual
        const normed1 = this.layerNorm1.forward(inputs);
        const attended = this.attention.forward(normed1);

        // Residual connection
        const residual1 = new Array(batchSize);
        for (let b = 0; b < batchSize; b++) {
            residual1[b] = new Array(seqLength);
            for (let t = 0; t < seqLength; t++) {
                residual1[b][t] = new Array(this.embeddingDim);
                for (let d = 0; d < this.embeddingDim; d++) {
                    residual1[b][t][d] = inputs[b][t][d] + attended[b][t][d];
                }
            }
        }

        // Pre-norm -> MLP -> residual
        const normed2 = this.layerNorm2.forward(residual1);
        const mlpOutput = this.mlp.forward(normed2);

        // Residual connection
        const output = new Array(batchSize);
        for (let b = 0; b < batchSize; b++) {
            output[b] = new Array(seqLength);
            for (let t = 0; t < seqLength; t++) {
                output[b][t] = new Array(this.embeddingDim);
                for (let d = 0; d < this.embeddingDim; d++) {
                    output[b][t][d] = residual1[b][t][d] + mlpOutput[b][t][d];
                }
            }
        }

        return output;
    }

    backward(outputGradients, learningRate) {
        const batchSize = outputGradients.length;
        const seqLength = outputGradients[0].length;

        // Gradient flows through residual (just copy)
        const mlpGrad = this.mlp.backward(outputGradients, learningRate);
        const norm2Grad = this.layerNorm2.backward(mlpGrad, learningRate);

        // Add gradients from residual path
        const residual1Grad = new Array(batchSize);
        for (let b = 0; b < batchSize; b++) {
            residual1Grad[b] = new Array(seqLength);
            for (let t = 0; t < seqLength; t++) {
                residual1Grad[b][t] = new Array(this.embeddingDim);
                for (let d = 0; d < this.embeddingDim; d++) {
                    residual1Grad[b][t][d] = outputGradients[b][t][d] + norm2Grad[b][t][d];
                }
            }
        }

        // Gradient through attention path
        const attentionGrad = this.attention.backward(residual1Grad, learningRate);
        const norm1Grad = this.layerNorm1.backward(attentionGrad, learningRate);

        // Add gradients from residual path
        const inputGradients = new Array(batchSize);
        for (let b = 0; b < batchSize; b++) {
            inputGradients[b] = new Array(seqLength);
            for (let t = 0; t < seqLength; t++) {
                inputGradients[b][t] = new Array(this.embeddingDim);
                for (let d = 0; d < this.embeddingDim; d++) {
                    inputGradients[b][t][d] = residual1Grad[b][t][d] + norm1Grad[b][t][d];
                }
            }
        }

        return inputGradients;
    }
}

Batched Output Layer

The output layer projects to vocabulary size and computes softmax:

class OutputLayer {
    weights = null;
    bias = null;
    inputDim = 0;
    vocabSize = 0;

    cachedInputs = null;
    cachedProbs = null;

    constructor(embeddingDim, vocabSize) {
        this.inputDim = embeddingDim;
        this.vocabSize = vocabSize;

        const scale = Math.sqrt(2.0 / embeddingDim);

        this.weights = new Array(embeddingDim);
        for (let i = 0; i < embeddingDim; i++) {
            this.weights[i] = new Array(vocabSize);
            for (let j = 0; j < vocabSize; j++) {
                this.weights[i][j] = (Math.random() * 2 - 1) * scale;
            }
        }

        this.bias = new Array(vocabSize).fill(0);
    }

    forward(inputs) {
        // inputs: [batchSize][seqLength][embeddingDim]
        const batchSize = inputs.length;
        const seqLength = inputs[0].length;

        this.cachedInputs = inputs;
        this.cachedProbs = new Array(batchSize);

        const output = new Array(batchSize);

        for (let b = 0; b < batchSize; b++) {
            output[b] = new Array(seqLength);
            this.cachedProbs[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                // Linear projection
                const logits = new Array(this.vocabSize);
                for (let v = 0; v < this.vocabSize; v++) {
                    let sum = this.bias[v];
                    for (let i = 0; i < this.inputDim; i++) {
                        sum += inputs[b][t][i] * this.weights[i][v];
                    }
                    logits[v] = sum;
                }

                // Softmax
                const maxLogit = Math.max(...logits);
                let sumExp = 0;
                const probs = new Array(this.vocabSize);
                for (let v = 0; v < this.vocabSize; v++) {
                    probs[v] = Math.exp(logits[v] - maxLogit);
                    sumExp += probs[v];
                }
                for (let v = 0; v < this.vocabSize; v++) {
                    probs[v] /= sumExp;
                }

                output[b][t] = probs;
                this.cachedProbs[b][t] = probs;
            }
        }

        return output;
    }

    backward(targetTokens, learningRate) {
        // targetTokens: [batchSize][seqLength]
        const batchSize = targetTokens.length;
        const seqLength = targetTokens[0].length;

        const inputGradients = new Array(batchSize);

        // Accumulate weight gradients
        const weightsGrad = new Array(this.inputDim);
        for (let i = 0; i < this.inputDim; i++) {
            weightsGrad[i] = new Array(this.vocabSize).fill(0);
        }
        const biasGrad = new Array(this.vocabSize).fill(0);

        for (let b = 0; b < batchSize; b++) {
            inputGradients[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                // Gradient of cross-entropy + softmax
                const grad = new Array(this.vocabSize);
                for (let v = 0; v < this.vocabSize; v++) {
                    grad[v] = this.cachedProbs[b][t][v];
                }
                grad[targetTokens[b][t]] -= 1;

                // Gradient w.r.t. inputs
                inputGradients[b][t] = new Array(this.inputDim).fill(0);
                for (let i = 0; i < this.inputDim; i++) {
                    for (let v = 0; v < this.vocabSize; v++) {
                        inputGradients[b][t][i] += grad[v] * this.weights[i][v];
                        weightsGrad[i][v] += this.cachedInputs[b][t][i] * grad[v];
                    }
                }

                for (let v = 0; v < this.vocabSize; v++) {
                    biasGrad[v] += grad[v];
                }
            }
        }

        // Update weights
        for (let i = 0; i < this.inputDim; i++) {
            for (let v = 0; v < this.vocabSize; v++) {
                this.weights[i][v] -= learningRate * weightsGrad[i][v];
            }
        }
        for (let v = 0; v < this.vocabSize; v++) {
            this.bias[v] -= learningRate * biasGrad[v];
        }

        return inputGradients;
    }

    computeLoss(probs, targetTokens) {
        // Average cross-entropy loss across batch and sequence
        const batchSize = probs.length;
        const seqLength = probs[0].length;

        let totalLoss = 0;
        for (let b = 0; b < batchSize; b++) {
            for (let t = 0; t < seqLength; t++) {
                const targetProb = probs[b][t][targetTokens[b][t]];
                totalLoss += -Math.log(targetProb + 1e-10);
            }
        }

        return totalLoss / (batchSize * seqLength);
    }
}

The Complete Batched GPT

Now we wire it all together:

class GabGPT {
    embedding = null;
    blocks = null;
    output = null;
    finalNorm = null;
    vocabSize = 0;

    constructor(vocabSize, embeddingDim, numHeads, numBlocks, maxSeqLength) {
        this.vocabSize = vocabSize;

        this.embedding = new PositionalEmbeddingLayer(vocabSize, embeddingDim, maxSeqLength);

        this.blocks = new Array(numBlocks);
        for (let i = 0; i < numBlocks; i++) {
            this.blocks[i] = new TransformerBlock(embeddingDim, numHeads);
        }

        this.finalNorm = new LayerNormalization(embeddingDim);
        this.output = new OutputLayer(embeddingDim, vocabSize);
    }

    forward(inputTokens) {
        // inputTokens: [batchSize][seqLength]
        let hidden = this.embedding.forward(inputTokens);

        for (let i = 0; i < this.blocks.length; i++) {
            hidden = this.blocks[i].forward(hidden);
        }

        hidden = this.finalNorm.forward(hidden);
        return this.output.forward(hidden);
    }

    backward(targetTokens, learningRate) {
        let gradients = this.output.backward(targetTokens, learningRate);
        gradients = this.finalNorm.backward(gradients, learningRate);

        for (let i = this.blocks.length - 1; i >= 0; i--) {
            gradients = this.blocks[i].backward(gradients, learningRate);
        }

        this.embedding.backward(gradients, learningRate);
    }

    train(inputTokens, targetTokens, learningRate) {
        const probs = this.forward(inputTokens);
        const loss = this.output.computeLoss(probs, targetTokens);
        this.backward(targetTokens, learningRate);
        return loss;
    }

    generate(promptTokens, maxLength) {
        // Generation still works on single sequences
        // Wrap in batch dimension
        let tokens = promptTokens.slice();

        for (let i = 0; i < maxLength; i++) {
            const batchedInput = [tokens];  // Batch size 1
            const probs = this.forward(batchedInput);
            const lastProbs = probs[0][tokens.length - 1];

            const nextToken = this.#sampleFromDistribution(lastProbs);
            tokens.push(nextToken);
        }

        return tokens;
    }

    #sampleFromDistribution(probs) {
        const random = Math.random();
        let cumulative = 0;

        for (let i = 0; i < probs.length; i++) {
            cumulative += probs[i];
            if (random < cumulative) {
                return i;
            }
        }

        return probs.length - 1;
    }
}

The Batched Training Loop

The training loop now groups sequences into batches:

// Create batches from training data
function createBatches(tokens, batchSize, seqLength) {
    const batches = [];
    const numSequences = Math.floor((tokens.length - 1) / seqLength);

    for (let batchStart = 0; batchStart < numSequences; batchStart += batchSize) {
        const batchEnd = Math.min(batchStart + batchSize, numSequences);
        const inputBatch = [];
        const targetBatch = [];

        for (let i = batchStart; i < batchEnd; i++) {
            const start = i * seqLength;
            inputBatch.push(tokens.slice(start, start + seqLength));
            targetBatch.push(tokens.slice(start + 1, start + seqLength + 1));
        }

        batches.push({ inputs: inputBatch, targets: targetBatch });
    }

    return batches;
}

// Training loop
const batchSize = 4;
const seqLength = 32;
const batches = createBatches(tokens, batchSize, seqLength);

for (let epoch = 0; epoch < epochs; epoch++) {
    let totalLoss = 0;

    for (const batch of batches) {
        const loss = model.train(batch.inputs, batch.targets, learningRate);
        totalLoss += loss;
    }

    console.log(`Epoch ${epoch}: Loss = ${(totalLoss / batches.length).toFixed(4)}`);
}

Batching gives you two benefits: faster training (less overhead per sequence) and smoother gradients (noise averages out across examples).

Pre-Training

You've already been doing pre-training. The training loop from Part 4, which predicts the next token on raw text, is pre-training. Pre-training is self-supervised learning: there are no human labels. The "label" for each position is simply the next token in the sequence. The model learns by trying to predict what comes next, over and over, on massive amounts of text.

Input:    "The cat sat on the"
Target:   "cat sat on the mat"
          ^    ^   ^   ^   ^
          Predict each of these from the previous tokens
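
In code, a pre-training example is nothing more than a one-token shift. A minimal sketch, assuming the tokenizer from earlier parts (the sentence is purely illustrative):

// The target is the input shifted left by one token; no labels required
const sampleTokens = tokenizer.encode("The cat sat on the mat");
const inputTokens = sampleTokens.slice(0, -1);   // "The cat sat on the"
const targetTokens = sampleTokens.slice(1);      // "cat sat on the mat"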

What does the model learn from this? Everything it needs to predict well: grammar, facts about the world, writing styles, common patterns of reasoning.

The model doesn't "know" it's learning these things. It's just minimizing prediction error. But to predict the next token in "The capital of France is", you'd better know geography.

Pre-training is where most of the compute goes. It creates the foundation—a model that understands language. Everything after is refinement.

Supervised Fine-Tuning

Pre-trained models have a problem: they complete text, but they don't converse. Ask "What is 2+2?" and you might get "What is 2+3? What is 2+4?" because the model is pattern-matching against math worksheets it saw during pre-training. Supervised Fine-Tuning (SFT) teaches the model to follow a conversation format. It's the same training algorithm, just with different data.

Adding Conversation Tokens

First, we need special tokens to mark conversation structure:

// Add to your tokenizer
tokenizer.reserveToken("<|user|>");
tokenizer.reserveToken("<|assistant|>");
tokenizer.reserveToken("<|end|>");
tokenizer.reserveToken("<|pad|>");  // For batching variable-length conversations

These tokens will appear in every training example, so the model will learn their meaning from context. SFT data is just conversations formatted as text:

const conversations = [
    {
        user: "What is the capital of France?",
        assistant: "The capital of France is Paris."
    },
    {
        user: "Who wrote Romeo and Juliet?",
        assistant: "William Shakespeare wrote Romeo and Juliet."
    },
    {
        user: "What is 2 + 2?",
        assistant: "2 + 2 equals 4."
    }
];

function formatConversation(conv) {
    return `<|user|>${conv.user}<|end|><|assistant|>${conv.assistant}<|end|>`;
}

// Convert all conversations to training data
const sftData = [];
for (const conv of conversations) {
    const text = formatConversation(conv);
    const tokens = tokenizer.encode(text);
    sftData.push(tokens);
}

When formatted, a conversation looks like:

<|user|>What is 2 + 2?<|end|><|assistant|>2 + 2 equals 4.<|end|>

The model learns: after <|assistant|>, generate a helpful response. After <|end|>, stop.

Padding

There's a problem with batching SFT data. Each conversation has a different length:

Conversation 1: "<|user|>Hi<|end|><|assistant|>Hello!<|end|>"        -> 12 tokens
Conversation 2: "<|user|>What is the capital of France?<|end|>..."  -> 28 tokens
Conversation 3: "<|user|>Why?<|end|><|assistant|>Because.<|end|>"   -> 10 tokens

Our batched code assumes rectangular arrays—every sequence in a batch has the same length. How do we handle this?

Strictly speaking, padding is optional here, but we'll implement it anyway. JavaScript arrays can be jagged; unlike PyTorch tensors, there's no requirement that batch[0].length === batch[1].length. We could rewrite every layer to handle variable-length sequences:

// Hypothetically, without padding:
for (let b = 0; b < batchSize; b++) {
    const seqLength = inputs[b].length;  // Each sequence has its own length
    for (let t = 0; t < seqLength; t++) {
        // ...
    }
}

This would actually work. Each sequence would process independently with its own length.

So why bother with padding? Three reasons:

  1. It's what real frameworks do. PyTorch and TensorFlow require rectangular tensors. Learning padding prepares you for production tools.
  2. GPU parallelism. Real training runs on GPUs where rectangular batches enable efficient parallel computation. Jagged arrays can't be parallelized the same way.
  3. Simpler code. Handling variable lengths means checking inputs[b].length everywhere. Easy to introduce bugs.

We'll implement padding to match industry practice. First, make sure the padding token is in the vocabulary (we already reserved it alongside the conversation tokens):

tokenizer.reserveToken("<|pad|>");

Padding a Batch

This function takes variable-length sequences and pads them to match the longest:

function padBatch(sequences, padTokenId) {
    // Find the longest sequence
    let maxLen = 0;
    for (let i = 0; i < sequences.length; i++) {
        if (sequences[i].length > maxLen) {
            maxLen = sequences[i].length;
        }
    }

    const paddedSequences = new Array(sequences.length);
    const paddingMask = new Array(sequences.length);  // 1 = real token, 0 = padding

    for (let i = 0; i < sequences.length; i++) {
        const seq = sequences[i];

        // Copy original sequence and add padding
        paddedSequences[i] = new Array(maxLen);
        paddingMask[i] = new Array(maxLen);

        for (let j = 0; j < seq.length; j++) {
            paddedSequences[i][j] = seq[j];
            paddingMask[i][j] = 1;  // Real token
        }

        for (let j = seq.length; j < maxLen; j++) {
            paddedSequences[i][j] = padTokenId;
            paddingMask[i][j] = 0;  // Padding
        }
    }

    return { sequences: paddedSequences, mask: paddingMask };
}

The function returns both the padded sequences and a mask indicating which positions are real (1) vs padding (0).
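
A quick check with made-up token IDs (assuming, just for this example, that the pad token has ID 0):

// Two sequences of different lengths, padded to the length of the longer one
const { sequences, mask } = padBatch([[5, 9, 2], [7, 1]], 0);
// sequences -> [[5, 9, 2], [7, 1, 0]]
// mask      -> [[1, 1, 1], [1, 1, 0]]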

Attention with Padding Mask

Padding tokens shouldn't participate in attention. A real token shouldn't attend to padding, and padding shouldn't attend to anything. We need to modify the attention layer to accept a padding mask and combine it with the causal mask:

// In MultiHeadAttention class
forward(inputs, paddingMask = null) {
    // inputs: [batchSize][seqLength][embeddingDim]
    // paddingMask: [batchSize][seqLength] - 1 for real, 0 for padding
    const batchSize = inputs.length;
    const seqLength = inputs[0].length;

    // ... Q, K, V projection code stays the same ...

    for (let b = 0; b < batchSize; b++) {
        // ... head loop setup ...

        for (let t = 0; t < seqLength; t++) {
            // Skip if this position is padding
            if (paddingMask && paddingMask[b][t] === 0) {
                // Output zeros for padding positions
                for (let d = 0; d < this.embeddingDim; d++) {
                    attentionOutputs[b][t][d] = 0;
                }
                continue;
            }

            const scores = new Array(seqLength);

            for (let s = 0; s < seqLength; s++) {
                // Can't attend to future (causal)
                if (s > t) {
                    scores[s] = -Infinity;
                }
                // Can't attend to padding
                else if (paddingMask && paddingMask[b][s] === 0) {
                    scores[s] = -Infinity;
                }
                else {
                    // Compute attention score normally
                    let score = 0;
                    for (let d = 0; d < this.headDim; d++) {
                        score += queries[b][t][headStart + d] * keys[b][s][headStart + d];
                    }
                    scores[s] = score * scale;
                }
            }

            // ... softmax and value aggregation stay the same ...
        }
    }

    // ... output projection ...
}

The changes:

  1. If the current position is padding, output zeros and skip
  2. When computing attention scores, set padding positions to -Infinity so they get zero weight after softmax

Padding Mask Through the Network

The padding mask needs to flow through the entire network. Update the transformer block and model:

// TransformerBlock
forward(inputs, paddingMask = null) {
    const normed1 = this.layerNorm1.forward(inputs);
    const attended = this.attention.forward(normed1, paddingMask);  // Pass mask
    // ... residual, MLP, etc ...
}

// GabGPT
forward(inputTokens, paddingMask = null) {
    let hidden = this.embedding.forward(inputTokens);

    for (let i = 0; i < this.blocks.length; i++) {
        hidden = this.blocks[i].forward(hidden, paddingMask);  // Pass mask
    }

    hidden = this.finalNorm.forward(hidden);
    return this.output.forward(hidden);
}

What about layer normalization? Since it normalizes each position's vector independently, padding positions never contaminate the statistics of real tokens. The values computed at padding positions are meaningless, but nothing downstream reads them: attention masks them out and the loss ignores them. So we can leave the layer norm code unchanged.

Loss Masking

There's one more refinement: we only care that the model generates good assistant responses. The user's text is given; we don't need to train on predicting it. Loss masking zeros out the loss on user tokens (and padding), focusing all learning on the assistant's output:

function createLossMask(tokens, tokenizer) {
    // Get special token IDs
    const userTokenId = tokenizer.encode("<|user|>")[0];
    const assistantTokenId = tokenizer.encode("<|assistant|>")[0];
    const endTokenId = tokenizer.encode("<|end|>")[0];
    const padTokenId = tokenizer.encode("<|pad|>")[0];

    const mask = new Array(tokens.length).fill(0);
    let inAssistantTurn = false;

    for (let i = 0; i < tokens.length; i++) {
        // Padding tokens are never included in loss
        if (tokens[i] === padTokenId) {
            mask[i] = 0;
            continue;
        }

        if (tokens[i] === assistantTokenId) {
            // The <|assistant|> marker itself is not scored; only the reply that follows is
            inAssistantTurn = true;
            continue;
        } else if (tokens[i] === userTokenId) {
            inAssistantTurn = false;
        } else if (tokens[i] === endTokenId) {
            // Include the end token in assistant's mask, then switch off
            if (inAssistantTurn) {
                mask[i] = 1;
            }
            inAssistantTurn = false;
        }

        if (inAssistantTurn) {
            mask[i] = 1;
        }
    }

    return mask;
}

For the example "What is 2 + 2?":

Tokens: <|user|> What is 2 + 2 ? <|end|> <|assistant|> 2 + 2 equals 4 . <|end|>
Mask:      0      0    0  0 0 0 0   0         0        1 1 1   1    1 1   1

Only positions with mask=1 contribute to the loss.
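
To see it in action (the exact token count depends on how the tokenizer splits the words, so the output below is illustrative):

const convTokens = tokenizer.encode(
    "<|user|>What is 2 + 2?<|end|><|assistant|>2 + 2 equals 4.<|end|>"
);
const lossMask = createLossMask(convTokens, tokenizer);
// -> something like [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
//    zeros over the user turn and the <|assistant|> marker, ones over the reply and its <|end|>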

Training with Loss Masking

Modify the output layer to accept a mask:

// In OutputLayer class
computeLossWithMask(probs, targetTokens, mask) {
    const batchSize = probs.length;
    const seqLength = probs[0].length;

    let totalLoss = 0;
    let count = 0;

    for (let b = 0; b < batchSize; b++) {
        for (let t = 0; t < seqLength; t++) {
            if (mask[b][t] === 1) {
                const targetProb = probs[b][t][targetTokens[b][t]];
                totalLoss += -Math.log(targetProb + 1e-10);
                count++;
            }
        }
    }

    return count > 0 ? totalLoss / count : 0;
}

backwardWithMask(targetTokens, mask, learningRate) {
    const batchSize = targetTokens.length;
    const seqLength = targetTokens[0].length;

    const inputGradients = new Array(batchSize);

    // Weight gradients (accumulated)
    const weightsGrad = new Array(this.inputDim);
    for (let i = 0; i < this.inputDim; i++) {
        weightsGrad[i] = new Array(this.vocabSize).fill(0);
    }
    const biasGrad = new Array(this.vocabSize).fill(0);

    for (let b = 0; b < batchSize; b++) {
        inputGradients[b] = new Array(seqLength);

        for (let t = 0; t < seqLength; t++) {
            if (mask[b][t] === 0) {
                // No gradient for masked positions
                inputGradients[b][t] = new Array(this.inputDim).fill(0);
                continue;
            }

            // Same gradient computation as before
            const grad = new Array(this.vocabSize);
            for (let v = 0; v < this.vocabSize; v++) {
                grad[v] = this.cachedProbs[b][t][v];
            }
            grad[targetTokens[b][t]] -= 1;

            inputGradients[b][t] = new Array(this.inputDim).fill(0);
            for (let i = 0; i < this.inputDim; i++) {
                for (let v = 0; v < this.vocabSize; v++) {
                    inputGradients[b][t][i] += grad[v] * this.weights[i][v];
                    weightsGrad[i][v] += this.cachedInputs[b][t][i] * grad[v];
                }
            }

            for (let v = 0; v < this.vocabSize; v++) {
                biasGrad[v] += grad[v];
            }
        }
    }

    // Update weights
    for (let i = 0; i < this.inputDim; i++) {
        for (let v = 0; v < this.vocabSize; v++) {
            this.weights[i][v] -= learningRate * weightsGrad[i][v];
        }
    }
    for (let v = 0; v < this.vocabSize; v++) {
        this.bias[v] -= learningRate * biasGrad[v];
    }

    return inputGradients;
}

The Complete SFT Training Loop

Now we combine padding and loss masking:

// Get pad token ID
const padTokenId = tokenizer.encode("<|pad|>")[0];

// Prepare SFT batches with padding
function createSFTBatches(sftData, batchSize, tokenizer, padTokenId) {
    const batches = [];

    for (let i = 0; i < sftData.length; i += batchSize) {
        const batchEnd = Math.min(i + batchSize, sftData.length);

        // Collect raw sequences for this batch
        const rawInputs = [];
        const rawTargets = [];

        for (let j = i; j < batchEnd; j++) {
            const tokens = sftData[j];
            rawInputs.push(tokens.slice(0, -1));   // All but last
            rawTargets.push(tokens.slice(1));      // All but first
        }

        // Pad inputs and targets to same length
        const { sequences: paddedInputs, mask: paddingMask } = padBatch(rawInputs, padTokenId);
        const { sequences: paddedTargets } = padBatch(rawTargets, padTokenId);

        // Create loss masks (handles both assistant-only and padding)
        const lossMasks = new Array(paddedTargets.length);
        for (let j = 0; j < paddedTargets.length; j++) {
            // Create mask on the padded target sequence
            lossMasks[j] = createLossMask(paddedTargets[j], tokenizer);
        }

        batches.push({
            inputs: paddedInputs,
            targets: paddedTargets,
            paddingMask: paddingMask,  // For attention
            lossMask: lossMasks        // For loss computation
        });
    }

    return batches;
}

const sftBatches = createSFTBatches(sftData, batchSize, tokenizer, padTokenId);

// SFT training loop
const sftEpochs = 100;
const sftLearningRate = 0.001;  // Lower than pre-training!

for (let epoch = 0; epoch < sftEpochs; epoch++) {
    let totalLoss = 0;

    for (const batch of sftBatches) {
        // Forward pass with padding mask
        const probs = model.forward(batch.inputs, batch.paddingMask);

        // Compute loss only on non-padded assistant tokens
        const loss = model.output.computeLossWithMask(probs, batch.targets, batch.lossMask);

        // Backward pass with loss mask (returns gradients w.r.t. the final hidden states)
        let gradients = model.output.backwardWithMask(batch.targets, batch.lossMask, sftLearningRate);

        // Backward through rest of network
        gradients = model.finalNorm.backward(gradients, sftLearningRate);
        for (let i = model.blocks.length - 1; i >= 0; i--) {
            gradients = model.blocks[i].backward(gradients, sftLearningRate);
        }
        model.embedding.backward(gradients, sftLearningRate);

        totalLoss += loss;
    }

    if (epoch % 10 === 0) {
        console.log(`SFT Epoch ${epoch}: Loss = ${(totalLoss / sftBatches.length).toFixed(4)}`);
    }
}

Notice the lower learning rate. SFT builds on pre-trained knowledge—we don't want to destroy what the model already learned.

The padding mask ensures attention doesn't leak into padding positions. The loss mask ensures we only train on assistant tokens (not user text, not padding). Together, they make SFT work correctly on batched, variable-length conversations.
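To make the two masks concrete, here's a purely illustrative batch of two conversations with made-up token IDs (3 standing in for <|pad|>); the real values depend on your tokenizer:

// Two conversations, B shorter than A, padded to the same length
const examplePaddedInputs = [
    [7, 12, 9, 5, 21, 14],   // A: <|user|> ... <|end|> <|assistant|> ...
    [7, 18, 9, 5, 3, 3]      // B: last two positions are <|pad|>
];
const examplePaddingMask = [
    [1, 1, 1, 1, 1, 1],      // attention may use every position of A
    [1, 1, 1, 1, 0, 0]       // attention never looks at B's padding
];
const exampleLossMask = [
    [0, 0, 0, 1, 1, 1],      // loss only where the target is an assistant token
    [0, 0, 0, 1, 0, 0]       // never on user text, never on padding
];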

User / Assistant Roles

The special tokens create implicit "roles." After seeing thousands of examples where helpful responses follow <|assistant|>, the model learns the pattern.

During generation, we format the prompt the same way:

function chat(model, tokenizer, userMessage) {
    const prompt = `<|user|>${userMessage}<|end|><|assistant|>`;
    const promptTokens = tokenizer.encode(prompt);

    const generated = model.generate(promptTokens, 50);
    const response = tokenizer.decode(generated);

    // Extract just the assistant's response
    const assistantStart = response.indexOf("<|assistant|>") + "<|assistant|>".length;
    const assistantEnd = response.indexOf("<|end|>", assistantStart);

    return response.slice(assistantStart, assistantEnd !== -1 ? assistantEnd : undefined);
}

// Usage
const response = chat(model, tokenizer, "What is the capital of France?");
console.log(response);  // "The capital of France is Paris."

The model sees the familiar pattern—<|user|> followed by a question, then <|assistant|>—and generates what it learned comes next: a helpful answer.

Thinking

You've probably seen models that show their reasoning in <think> tags before giving an answer. This isn't a special architecture—it's just SFT with a different format.

const thinkingConversations = [
    {
        user: "What is 15 + 27?",
        thinking: "I need to add 15 and 27. 15 + 27 = 42.",
        assistant: "15 + 27 equals 42."
    },
    {
        user: "Is a whale a fish?",
        thinking: "Whales live in water, but they breathe air and nurse their young. These are mammal characteristics.",
        assistant: "No, a whale is not a fish. Whales are mammals."
    }
];

function formatThinkingConversation(conv) {
    return `<|user|>${conv.user}<|end|><|think|>${conv.thinking}<|end|><|assistant|>${conv.assistant}<|end|>`;
}

Add <|think|> as a reserved token, format your training data with thinking sections, and the model learns to "think" before responding. The thinking isn't magic cognition—it's the model generating intermediate text that helps it produce better final answers. By writing out reasoning steps, the model conditions its output on that reasoning, leading to more accurate responses.

Thinking tokens work because transformers are autoregressive. Each generated token influences the next. Generating "15 + 27 = 42" gives the model something concrete to reference when generating the final answer.
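Generation for a thinking model only changes the prompt: end it with <|think|> so the model writes its reasoning before the answer. Here's a minimal sketch, assuming the same tokenizer, model.generate(), and special tokens as above; the helper name chatWithThinking is made up for illustration:

function chatWithThinking(model, tokenizer, userMessage) {
    // The prompt ends with <|think|>, so the model generates its reasoning first,
    // then <|end|><|assistant|> and the final answer, matching the training format
    const prompt = `<|user|>${userMessage}<|end|><|think|>`;
    const promptTokens = tokenizer.encode(prompt);

    const generated = model.generate(promptTokens, 100);
    const response = tokenizer.decode(generated);

    // Return only the answer; the thinking stays hidden from the user
    const assistantStart = response.indexOf("<|assistant|>");
    if (assistantStart === -1) return response;  // model never reached the answer

    const answer = response.slice(assistantStart + "<|assistant|>".length);
    const end = answer.indexOf("<|end|>");
    return end !== -1 ? answer.slice(0, end) : answer;
}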

Training Tricks

Our simple gradient descent works, but there are tricks that make training faster and more stable.

Learning Rate Warmup & Decay

Using the same learning rate throughout training is suboptimal. At the start, the weights are random, and large updates might send them in bad directions. At the end, we want to fine-tune carefully. Warmup starts with a tiny learning rate and ramps it up over the first few steps; decay then gradually reduces it as training progresses.

function getLearningRate(step, warmupSteps, totalSteps, maxLR) {
    if (step < warmupSteps) {
        // Linear warmup: 0 -> maxLR
        return maxLR * (step / warmupSteps);
    }

    // Cosine decay: maxLR -> 0
    const progress = (step - warmupSteps) / (totalSteps - warmupSteps);
    return maxLR * 0.5 * (1 + Math.cos(Math.PI * progress));
}

The cosine shape is smooth—no sudden changes that might destabilize training. Usage:

const warmupSteps = 100;
const totalSteps = epochs * batches.length;
const maxLR = 0.001;

let step = 0;
for (let epoch = 0; epoch < epochs; epoch++) {
    for (const batch of batches) {
        const lr = getLearningRate(step, warmupSteps, totalSteps, maxLR);
        const loss = model.train(batch.inputs, batch.targets, lr);
        step++;
    }
}

Gradient Clipping

Sometimes gradients explode. They become huge, causing massive weight updates that destroy the model. This happens especially with deep networks or long sequences.

Gradient clipping caps gradient magnitude:

function clipGradient(gradient, maxValue) {
    return Math.max(-maxValue, Math.min(maxValue, gradient));
}

Apply this wherever you compute gradients. For example, in a backward pass:

for (let i = 0; i < this.inputDim; i++) {
    for (let h = 0; h < this.hiddenDim; h++) {
        let grad = this.cachedInputs[b][t][i] * geluGrad[h];
        grad = clipGradient(grad, 1.0);  // Clip to [-1, 1]
        weights1Grad[i][h] += grad;
    }
}

A cleaner approach is to clip the global gradient norm: compute the total magnitude of all gradients and scale them all down proportionally if it exceeds a threshold. Per-value clipping is simpler, though, and often sufficient.
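For reference, here's a minimal sketch of global-norm clipping. It assumes you collect a layer's gradient matrices into an array before applying updates; the function name and that calling convention are illustrative:

function clipGradientsByGlobalNorm(gradMatrices, maxNorm) {
    // Total L2 norm across every gradient value in every matrix
    let sumSquares = 0;
    for (const matrix of gradMatrices) {
        for (const row of matrix) {
            for (const value of row) {
                sumSquares += value * value;
            }
        }
    }
    const globalNorm = Math.sqrt(sumSquares);

    // Scale everything down proportionally if the norm exceeds the threshold
    if (globalNorm > maxNorm) {
        const scale = maxNorm / globalNorm;
        for (const matrix of gradMatrices) {
            for (const row of matrix) {
                for (let i = 0; i < row.length; i++) {
                    row[i] *= scale;
                }
            }
        }
    }
}

// Usage: clipGradientsByGlobalNorm([weights1Grad, weights2Grad], 1.0);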

The Adam Optimizer

Gradient descent uses the same learning rate for every parameter. Adam (Adaptive Moment Estimation) adapts the learning rate per parameter based on gradient history. Adam tracks two quantities for each parameter:

  1. The first moment (m): an exponentially decaying average of past gradients, i.e. momentum.
  2. The second moment (v): an exponentially decaying average of past squared gradients, i.e. how large the gradients tend to be.

Parameters with large, noisy gradients get smaller updates; parameters with small, consistent gradients get larger ones. To use Adam, each layer needs to store this optimizer state alongside its weights. First, let's create an Adam helper class to keep the logic centralized:

class AdamOptimizer {
    learningRate = 0.001;
    beta1 = 0.9;      // Momentum decay
    beta2 = 0.999;    // Squared gradient decay
    epsilon = 1e-8;   // Prevents division by zero
    step = 0;

    constructor(learningRate = 0.001) {
        this.learningRate = learningRate;
    }

    createState(shape) {
        // Initialize m and v to zeros with the given shape
        if (typeof shape === 'number') {
            return {
                m: new Array(shape).fill(0),
                v: new Array(shape).fill(0)
            };
        } else if (shape.length === 2) {
            const m = new Array(shape[0]);
            const v = new Array(shape[0]);
            for (let i = 0; i < shape[0]; i++) {
                m[i] = new Array(shape[1]).fill(0);
                v[i] = new Array(shape[1]).fill(0);
            }
            return { m, v };
        }
    }

    update1D(params, grads, state) {
        // Note: step increments once per update call and the optimizer is shared across
        // layers, so bias correction fades faster than in textbook Adam (one step per
        // training iteration). This only matters during the first few updates.
        this.step++;

        for (let i = 0; i < params.length; i++) {
            // Update biased first moment estimate
            state.m[i] = this.beta1 * state.m[i] + (1 - this.beta1) * grads[i];

            // Update biased second moment estimate
            state.v[i] = this.beta2 * state.v[i] + (1 - this.beta2) * grads[i] * grads[i];

            // Bias correction
            const mHat = state.m[i] / (1 - Math.pow(this.beta1, this.step));
            const vHat = state.v[i] / (1 - Math.pow(this.beta2, this.step));

            // Update parameter
            params[i] -= this.learningRate * mHat / (Math.sqrt(vHat) + this.epsilon);
        }
    }

    update2D(params, grads, state) {
        this.step++;

        for (let i = 0; i < params.length; i++) {
            for (let j = 0; j < params[i].length; j++) {
                state.m[i][j] = this.beta1 * state.m[i][j] + (1 - this.beta1) * grads[i][j];
                state.v[i][j] = this.beta2 * state.v[i][j] + (1 - this.beta2) * grads[i][j] * grads[i][j];

                const mHat = state.m[i][j] / (1 - Math.pow(this.beta1, this.step));
                const vHat = state.v[i][j] / (1 - Math.pow(this.beta2, this.step));

                params[i][j] -= this.learningRate * mHat / (Math.sqrt(vHat) + this.epsilon);
            }
        }
    }
}

Now let's modify the MLP block to use Adam. The pattern is: store optimizer state, accumulate gradients, then call the optimizer to update:

class MLPBlock {
    weights1 = null;
    bias1 = null;
    weights2 = null;
    bias2 = null;

    inputDim = 0;
    hiddenDim = 0;

    cachedInputs = null;
    cachedHidden = null;

    // Adam state
    optimizer = null;
    weights1State = null;
    bias1State = null;
    weights2State = null;
    bias2State = null;

    constructor(embeddingDim, optimizer = null) {
        this.inputDim = embeddingDim;
        this.hiddenDim = embeddingDim * 4;
        this.optimizer = optimizer;

        const scale1 = Math.sqrt(2.0 / embeddingDim);
        const scale2 = Math.sqrt(2.0 / this.hiddenDim);

        this.weights1 = this.#initWeights(embeddingDim, this.hiddenDim, scale1);
        this.bias1 = new Array(this.hiddenDim).fill(0);
        this.weights2 = this.#initWeights(this.hiddenDim, embeddingDim, scale2);
        this.bias2 = new Array(embeddingDim).fill(0);

        // Initialize Adam state if optimizer provided
        if (optimizer) {
            this.weights1State = optimizer.createState([embeddingDim, this.hiddenDim]);
            this.bias1State = optimizer.createState(this.hiddenDim);
            this.weights2State = optimizer.createState([this.hiddenDim, embeddingDim]);
            this.bias2State = optimizer.createState(embeddingDim);
        }
    }

    #initWeights(inputSize, outputSize, scale) {
        const weights = new Array(inputSize);
        for (let i = 0; i < inputSize; i++) {
            weights[i] = new Array(outputSize);
            for (let j = 0; j < outputSize; j++) {
                weights[i][j] = (Math.random() * 2 - 1) * scale;
            }
        }
        return weights;
    }

    forward(inputs) {
        // Same as before
        const batchSize = inputs.length;
        const seqLength = inputs[0].length;

        this.cachedInputs = inputs;
        this.cachedHidden = new Array(batchSize);
        const output = new Array(batchSize);

        for (let b = 0; b < batchSize; b++) {
            this.cachedHidden[b] = new Array(seqLength);
            output[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                const hidden = new Array(this.hiddenDim);
                for (let h = 0; h < this.hiddenDim; h++) {
                    let sum = this.bias1[h];
                    for (let i = 0; i < this.inputDim; i++) {
                        sum += inputs[b][t][i] * this.weights1[i][h];
                    }
                    // GELU
                    hidden[h] = sum * 0.5 * (1 + Math.tanh(
                        Math.sqrt(2 / Math.PI) * (sum + 0.044715 * sum * sum * sum)
                    ));
                }
                this.cachedHidden[b][t] = hidden;

                output[b][t] = new Array(this.inputDim);
                for (let o = 0; o < this.inputDim; o++) {
                    let sum = this.bias2[o];
                    for (let h = 0; h < this.hiddenDim; h++) {
                        sum += hidden[h] * this.weights2[h][o];
                    }
                    output[b][t][o] = sum;
                }
            }
        }

        return output;
    }

    backward(outputGradients) {
        const batchSize = outputGradients.length;
        const seqLength = outputGradients[0].length;

        const inputGradients = new Array(batchSize);

        // Initialize gradient accumulators
        const weights2Grad = new Array(this.hiddenDim);
        for (let h = 0; h < this.hiddenDim; h++) {
            weights2Grad[h] = new Array(this.inputDim).fill(0);
        }
        const bias2Grad = new Array(this.inputDim).fill(0);

        const weights1Grad = new Array(this.inputDim);
        for (let i = 0; i < this.inputDim; i++) {
            weights1Grad[i] = new Array(this.hiddenDim).fill(0);
        }
        const bias1Grad = new Array(this.hiddenDim).fill(0);

        for (let b = 0; b < batchSize; b++) {
            inputGradients[b] = new Array(seqLength);

            for (let t = 0; t < seqLength; t++) {
                // Gradient through second linear
                const hiddenGrad = new Array(this.hiddenDim).fill(0);
                for (let h = 0; h < this.hiddenDim; h++) {
                    for (let o = 0; o < this.inputDim; o++) {
                        hiddenGrad[h] += outputGradients[b][t][o] * this.weights2[h][o];
                        weights2Grad[h][o] += this.cachedHidden[b][t][h] * outputGradients[b][t][o];
                    }
                }
                for (let o = 0; o < this.inputDim; o++) {
                    bias2Grad[o] += outputGradients[b][t][o];
                }

                // Gradient through GELU
                const preGelu = new Array(this.hiddenDim);
                for (let h = 0; h < this.hiddenDim; h++) {
                    let sum = this.bias1[h];
                    for (let i = 0; i < this.inputDim; i++) {
                        sum += this.cachedInputs[b][t][i] * this.weights1[i][h];
                    }
                    preGelu[h] = sum;
                }

                const geluGrad = new Array(this.hiddenDim);
                for (let h = 0; h < this.hiddenDim; h++) {
                    const x = preGelu[h];
                    const cdf = 0.5 * (1 + Math.tanh(
                        Math.sqrt(2 / Math.PI) * (x + 0.044715 * x * x * x)
                    ));
                    const pdf = Math.exp(-0.5 * x * x) / Math.sqrt(2 * Math.PI);
                    geluGrad[h] = hiddenGrad[h] * (cdf + x * pdf);
                }

                // Gradient through first linear
                inputGradients[b][t] = new Array(this.inputDim).fill(0);
                for (let i = 0; i < this.inputDim; i++) {
                    for (let h = 0; h < this.hiddenDim; h++) {
                        inputGradients[b][t][i] += geluGrad[h] * this.weights1[i][h];
                        weights1Grad[i][h] += this.cachedInputs[b][t][i] * geluGrad[h];
                    }
                }
                for (let h = 0; h < this.hiddenDim; h++) {
                    bias1Grad[h] += geluGrad[h];
                }
            }
        }

        // Apply the accumulated gradients via the optimizer (nothing updates without one)
        if (this.optimizer) {
            this.optimizer.update2D(this.weights2, weights2Grad, this.weights2State);
            this.optimizer.update1D(this.bias2, bias2Grad, this.bias2State);
            this.optimizer.update2D(this.weights1, weights1Grad, this.weights1State);
            this.optimizer.update1D(this.bias1, bias1Grad, this.bias1State);
        }

        return inputGradients;
    }
}

The key changes:

  1. Store optimizer state alongside weights
  2. Accumulate gradients instead of applying them immediately
  3. Call optimizer.update() at the end of backward()

Apply the same pattern to other layers: embedding, attention, output layer, layer normalization.
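As one example, here's a sketch of the pattern applied to the embedding layer's backward pass. It assumes the constructor stored the shared optimizer as this.optimizer and created this.weightsState = optimizer.createState([vocabSize, embeddingDim]); those field names are assumptions for illustration:

backward(outputGradients) {
    const batchSize = this.cachedInputTokens.length;
    const seqLength = this.cachedInputTokens[0].length;

    // Accumulate gradients for the whole embedding table instead of updating in place
    const weightsGrad = new Array(this.vocabSize);
    for (let v = 0; v < this.vocabSize; v++) {
        weightsGrad[v] = new Array(this.embeddingDim).fill(0);
    }

    for (let b = 0; b < batchSize; b++) {
        for (let t = 0; t < seqLength; t++) {
            const tokenId = this.cachedInputTokens[b][t];
            for (let d = 0; d < this.embeddingDim; d++) {
                weightsGrad[tokenId][d] += outputGradients[b][t][d];
            }
        }
    }

    // Let the optimizer apply the update
    if (this.optimizer) {
        this.optimizer.update2D(this.weights, weightsGrad, this.weightsState);
    }
}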

Using Adam in Training

// Create shared optimizer
const optimizer = new AdamOptimizer(0.001);

// Pass to model components
const model = new GabGPT(vocabSize, embeddingDim, numHeads, numBlocks, maxSeqLength, optimizer);

// Training loop no longer needs to pass learning rate to train()
for (let epoch = 0; epoch < epochs; epoch++) {
    let totalLoss = 0;

    for (const batch of batches) {
        const loss = model.train(batch.inputs, batch.targets);  // No LR argument
        totalLoss += loss;
    }

    console.log(`Epoch ${epoch}: Loss = ${(totalLoss / batches.length).toFixed(4)}`);
}

With Adam, you can often use a fixed learning rate (like 0.001) without warmup. But the best results come from combining Adam with warmup and cosine decay—modify the optimizer's learning rate per step.
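A minimal sketch of that combination, reusing the getLearningRate() helper and the shared optimizer from above:

const warmupSteps = 100;
const totalSteps = epochs * batches.length;
const maxLR = 0.001;

let step = 0;
for (let epoch = 0; epoch < epochs; epoch++) {
    for (const batch of batches) {
        // Override the optimizer's learning rate for this step
        optimizer.learningRate = getLearningRate(step, warmupSteps, totalSteps, maxLR);
        const loss = model.train(batch.inputs, batch.targets);
        step++;
    }
}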

Putting It All Together

Let's train a small model from scratch through pre-training and SFT:

// ==========================================
// Complete Training Pipeline
// ==========================================

// 1. Initialize tokenizer
const tokenizer = new Tokenizer();
tokenizer.reserveToken("<|user|>");
tokenizer.reserveToken("<|assistant|>");
tokenizer.reserveToken("<|end|>");
tokenizer.reserveToken("<|pad|>");
tokenizer.reserveToken("<|think|>");

// 2. Pre-training data
const pretrainingText = `
The cat sat on the mat. The dog sat on the log.
The quick brown fox jumps over the lazy dog.
In the morning, the sun rises in the east.
In the evening, the sun sets in the west.
Paris is the capital of France.
London is the capital of the United Kingdom.
The Earth orbits around the Sun.
Water freezes at zero degrees Celsius.
`;

tokenizer.train(pretrainingText, 100);
const pretrainingTokens = tokenizer.encode(pretrainingText);
console.log(`Vocabulary size: ${tokenizer.getVocabSize()}`);
console.log(`Pre-training tokens: ${pretrainingTokens.length}`);

// Get pad token ID for later
const padTokenId = tokenizer.encode("<|pad|>")[0];

// 3. Create model with Adam optimizer
const embeddingDim = 64;
const numHeads = 4;
const numBlocks = 2;
const maxSeqLength = 64;
const optimizer = new AdamOptimizer(0.001);

const model = new GabGPT(
    tokenizer.getVocabSize(),
    embeddingDim,
    numHeads,
    numBlocks,
    maxSeqLength,
    optimizer
);

// 4. Pre-training
console.log("\n=== Pre-training ===");
const batchSize = 4;
const seqLength = 32;
const pretrainingBatches = createBatches(pretrainingTokens, batchSize, seqLength);
const pretrainingEpochs = 100;

for (let epoch = 0; epoch < pretrainingEpochs; epoch++) {
    let totalLoss = 0;

    for (const batch of pretrainingBatches) {
        const loss = model.train(batch.inputs, batch.targets);
        totalLoss += loss;
    }

    if (epoch % 20 === 0) {
        console.log(`Epoch ${epoch}: Loss = ${(totalLoss / pretrainingBatches.length).toFixed(4)}`);
    }
}

// 5. SFT data
const sftConversations = [
    { user: "What is the capital of France?", assistant: "The capital of France is Paris." },
    { user: "What is the capital of the UK?", assistant: "The capital of the United Kingdom is London." },
    { user: "When does the sun rise?", assistant: "The sun rises in the morning, in the east." },
    { user: "What does water do at zero degrees?", assistant: "Water freezes at zero degrees Celsius." },
];

const sftData = [];
for (const conv of sftConversations) {
    const text = `<|user|>${conv.user}<|end|><|assistant|>${conv.assistant}<|end|>`;
    sftData.push(tokenizer.encode(text));
}

// 6. SFT training with padding
console.log("\n=== Supervised Fine-Tuning ===");
optimizer.learningRate = 0.0001;  // Lower for fine-tuning
const sftEpochs = 200;

// Create batches with padding
const sftBatches = createSFTBatches(sftData, batchSize, tokenizer, padTokenId);

for (let epoch = 0; epoch < sftEpochs; epoch++) {
    let totalLoss = 0;

    for (const batch of sftBatches) {
        // Forward with padding mask
        const probs = model.forward(batch.inputs, batch.paddingMask);
        const loss = model.output.computeLossWithMask(probs, batch.targets, batch.lossMask);

        // Backward with loss mask
        model.output.backwardWithMask(batch.targets, batch.lossMask);
        let gradients = model.output.inputGradients;
        gradients = model.finalNorm.backward(gradients);
        for (let i = model.blocks.length - 1; i >= 0; i--) {
            gradients = model.blocks[i].backward(gradients);
        }
        model.embedding.backward(gradients);

        totalLoss += loss;
    }

    if (epoch % 50 === 0) {
        console.log(`Epoch ${epoch}: Loss = ${(totalLoss / sftBatches.length).toFixed(4)}`);
    }
}

// 7. Test the chatbot
console.log("\n=== Testing ===");
const testQuestions = [
    "What is the capital of France?",
    "When does the sun set?",
];

for (const question of testQuestions) {
    const response = chat(model, tokenizer, question);
    console.log(`User: ${question}`);
    console.log(`Assistant: ${response}\n`);
}

Resources