Part 5: Training
We've built a working transformer. It can predict the next token, generate text, and learn patterns from data. But it's not a ChatGPT yet. It processes one sequence at a time and knows nothing about conversations.
This part fixes all of that. We'll implement batch training, learn how pre-training differs from fine-tuning, teach the model to chat using SFT, and add the optimization tricks that make training actually work at scale.
By the end, you'll have a model that responds to prompts like a chatbot, not just a text completion engine.
Training with Batches
The training loop from Part 4 processes one sequence at a time:
// One training step per window of `sequenceLength` tokens, sliding by 8.
// The bound is inclusive: the targets of a window starting at `start` end at
// index start + sequenceLength, which is still in range when
// start === tokens.length - sequenceLength - 1.
for (let start = 0; start <= tokens.length - sequenceLength - 1; start += 8) {
  const inputTokens = tokens.slice(start, start + sequenceLength);
  const targetTokens = tokens.slice(start + 1, start + sequenceLength + 1);
  const loss = model.train(inputTokens, targetTokens, learningRate);
}
This works, but it's inefficient and produces noisy gradients. A single example might push weights in a weird direction; averaging across many examples gives a cleaner signal. Batch training processes multiple sequences simultaneously. Instead of shape [seqLength], the inputs become [batchSize][seqLength]. Every layer must handle this extra dimension.
Single sequence: [16] -> "The cat sat on..."
Batch of 4: [4][16] -> 4 different sequences at once
After embedding: [4][16][64] -> 4 sequences, 16 positions, 64-dim embeddings
Batched Embedding Layer
The embedding layer needs minimal changes—we just add an outer loop over the batch:
/**
 * Token-embedding lookup table for batched input.
 *
 * `weights` is a [vocabSize][embeddingDim] matrix initialised uniformly in
 * [-sqrt(1/embeddingDim), sqrt(1/embeddingDim)). `forward` copies one row per
 * token; `backward` applies SGD directly to the rows that were looked up.
 */
class EmbeddingLayer {
  weights = null;
  vocabSize = 0;
  embeddingDim = 0;
  cachedInputTokens = null;

  constructor(vocabSize, embeddingDim) {
    this.vocabSize = vocabSize;
    this.embeddingDim = embeddingDim;
    const limit = Math.sqrt(1.0 / embeddingDim);
    // Rows are drawn one after another, each row left to right.
    this.weights = Array.from({ length: vocabSize }, () =>
      Array.from({ length: embeddingDim }, () => (Math.random() * 2 - 1) * limit),
    );
  }

  /**
   * Look up the embedding of every token in the batch.
   *
   * @param {number[][]} inputTokens - token ids, [batchSize][seqLength]; every
   *   row is read up to the length of the first row.
   * @returns {number[][][]} fresh copies of the embedding rows,
   *   [batchSize][seqLength][embeddingDim].
   */
  forward(inputTokens) {
    const seqLength = inputTokens[0].length;
    this.cachedInputTokens = inputTokens;
    return Array.from({ length: inputTokens.length }, (_, b) => {
      const sequence = inputTokens[b];
      return Array.from({ length: seqLength }, (_, t) => {
        const row = this.weights[sequence[t]];
        return Array.from({ length: this.embeddingDim }, (_, d) => row[d]);
      });
    });
  }

  /**
   * Apply SGD to the embedding rows used in the last forward pass.
   *
   * A token that occurs several times (in one sequence or across the batch)
   * receives one in-place update per occurrence.
   *
   * @param {number[][][]} outputGradients - [batchSize][seqLength][embeddingDim].
   * @param {number} learningRate - SGD step size.
   * @returns {null} the embedding is the first layer, so no gradient flows further.
   */
  backward(outputGradients, learningRate) {
    const tokens = this.cachedInputTokens;
    const seqLength = tokens[0].length;
    for (const [b, sequence] of tokens.entries()) {
      for (let t = 0; t < seqLength; t++) {
        const row = this.weights[sequence[t]];
        for (let d = 0; d < this.embeddingDim; d++) {
          row[d] -= learningRate * outputGradients[b][t][d];
        }
      }
    }
    return null;
  }
}
The forward pass looks up embeddings for each token in each sequence. The backward pass accumulates gradients—if the same token appears in multiple sequences (or multiple positions), its embedding gets updated multiple times.
Batched Positional Embedding Layer
Same pattern, wrap everything in a batch loop:
/**
 * Token embeddings plus learned absolute position embeddings, batched.
 *
 * `positionWeights` is [maxSequenceLength][embeddingDim]; position t of every
 * sequence in the batch adds row t. Both tables are trained with plain SGD.
 */
class PositionalEmbeddingLayer {
  tokenEmbedding = null;
  positionWeights = null;
  maxSequenceLength = 0;
  embeddingDim = 0;
  cachedBatchSize = 0;
  cachedSeqLength = 0;

  constructor(vocabSize, embeddingDim, maxSequenceLength) {
    this.embeddingDim = embeddingDim;
    this.maxSequenceLength = maxSequenceLength;
    this.tokenEmbedding = new EmbeddingLayer(vocabSize, embeddingDim);
    const scale = Math.sqrt(1.0 / embeddingDim);
    this.positionWeights = new Array(maxSequenceLength);
    for (let pos = 0; pos < maxSequenceLength; pos++) {
      this.positionWeights[pos] = new Array(embeddingDim);
      for (let d = 0; d < embeddingDim; d++) {
        this.positionWeights[pos][d] = (Math.random() * 2 - 1) * scale;
      }
    }
  }

  /**
   * @param {number[][]} inputTokens - token ids, [batchSize][seqLength].
   * @returns {number[][][]} token + position embeddings, [batchSize][seqLength][embeddingDim].
   * @throws {RangeError} if seqLength exceeds maxSequenceLength (those
   *   positions have no embedding row).
   */
  forward(inputTokens) {
    const batchSize = inputTokens.length;
    const seqLength = inputTokens[0].length;
    // Fail fast with a clear message instead of a TypeError from reading
    // positionWeights[t] past the end of the table; checked before any
    // sub-layer caches are overwritten.
    if (seqLength > this.maxSequenceLength) {
      throw new RangeError(
        `sequence length ${seqLength} exceeds maxSequenceLength ${this.maxSequenceLength}`,
      );
    }
    const tokenEmbeddings = this.tokenEmbedding.forward(inputTokens);
    this.cachedBatchSize = batchSize;
    this.cachedSeqLength = seqLength;
    // Add position embeddings (same positions for all sequences in batch)
    const output = new Array(batchSize);
    for (let b = 0; b < batchSize; b++) {
      output[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        output[b][t] = new Array(this.embeddingDim);
        for (let d = 0; d < this.embeddingDim; d++) {
          output[b][t][d] = tokenEmbeddings[b][t][d] + this.positionWeights[t][d];
        }
      }
    }
    return output;
  }

  /**
   * Apply SGD to the position table, then delegate to the token embedding.
   *
   * @param {number[][][]} outputGradients - [batchSize][seqLength][embeddingDim].
   * @param {number} learningRate - SGD step size.
   * @returns {null} nothing flows further back.
   */
  backward(outputGradients, learningRate) {
    // Position gradients: sum across batch (same position used by all sequences)
    for (let t = 0; t < this.cachedSeqLength; t++) {
      for (let d = 0; d < this.embeddingDim; d++) {
        let grad = 0;
        for (let b = 0; b < this.cachedBatchSize; b++) {
          grad += outputGradients[b][t][d];
        }
        this.positionWeights[t][d] -= learningRate * grad;
      }
    }
    this.tokenEmbedding.backward(outputGradients, learningRate);
    return null;
  }
}
Notice how position gradients sum across the batch. Position 0's embedding is used by every sequence, so its gradient accumulates from all of them.
Batched Layer Normalization
Layer normalization computes statistics per position, per sequence. With batches, we normalize each [embeddingDim] vector independently:
/**
 * Layer normalization over the feature axis for batched input.
 *
 * Each [featureSize] vector at (batch b, position t) is normalized with its own
 * mean and (biased, divide-by-N) variance, then scaled by gamma and shifted by
 * beta. gamma/beta are shared by every position and sequence and are updated
 * in place by plain SGD in backward().
 */
class LayerNormalization {
gamma = null;
beta = null;
featureSize = 0;
// Added to the variance before the square root to avoid division by zero.
epsilon = 1e-5;
// Values from the last forward() call, reused by backward().
cachedInputs = null;
cachedMean = null;
cachedVariance = null;
cachedNormalized = null;
// gamma starts at 1 and beta at 0, so the initial transform is pure normalization.
constructor(featureSize) {
this.featureSize = featureSize;
this.gamma = new Array(featureSize);
this.beta = new Array(featureSize);
for (let i = 0; i < featureSize; i++) {
this.gamma[i] = 1.0;
this.beta[i] = 0.0;
}
}
/**
 * @param inputs - [batchSize][seqLength][featureSize]; the sequence length is
 *   taken from inputs[0].
 * @returns a new [batchSize][seqLength][featureSize] array. The inputs, per-vector
 *   mean/variance and normalized values are cached for backward().
 */
forward(inputs) {
// inputs: [batchSize][seqLength][featureSize]
const batchSize = inputs.length;
const seqLength = inputs[0].length;
this.cachedInputs = inputs;
this.cachedMean = new Array(batchSize);
this.cachedVariance = new Array(batchSize);
this.cachedNormalized = new Array(batchSize);
const output = new Array(batchSize);
for (let b = 0; b < batchSize; b++) {
this.cachedMean[b] = new Array(seqLength);
this.cachedVariance[b] = new Array(seqLength);
this.cachedNormalized[b] = new Array(seqLength);
output[b] = new Array(seqLength);
for (let t = 0; t < seqLength; t++) {
// Compute mean
let mean = 0;
for (let i = 0; i < this.featureSize; i++) {
mean += inputs[b][t][i];
}
mean /= this.featureSize;
this.cachedMean[b][t] = mean;
// Compute variance
let variance = 0;
for (let i = 0; i < this.featureSize; i++) {
const diff = inputs[b][t][i] - mean;
variance += diff * diff;
}
variance /= this.featureSize;
this.cachedVariance[b][t] = variance;
// Normalize and apply gamma/beta
const stdInv = 1.0 / Math.sqrt(variance + this.epsilon);
this.cachedNormalized[b][t] = new Array(this.featureSize);
output[b][t] = new Array(this.featureSize);
for (let i = 0; i < this.featureSize; i++) {
const normalized = (inputs[b][t][i] - mean) * stdInv;
this.cachedNormalized[b][t][i] = normalized;
output[b][t][i] = this.gamma[i] * normalized + this.beta[i];
}
}
}
return output;
}
/**
 * Backpropagate through the normalization and apply SGD to gamma and beta.
 *
 * Must follow forward() on the same batch, because it reads the cached values.
 * Input gradients are computed with the current (pre-update) gamma. gamma and
 * beta are updated once at the end, with gradients summed over every batch
 * element and position.
 * @param outputGradients - dLoss/dOutput, [batchSize][seqLength][featureSize].
 * @param learningRate - SGD step size.
 * @returns dLoss/dInputs with the same shape.
 */
backward(outputGradients, learningRate) {
const batchSize = outputGradients.length;
const seqLength = outputGradients[0].length;
const inputGradients = new Array(batchSize);
// Accumulate gamma/beta gradients across batch and sequence
const gammaGrad = new Array(this.featureSize).fill(0);
const betaGrad = new Array(this.featureSize).fill(0);
for (let b = 0; b < batchSize; b++) {
inputGradients[b] = new Array(seqLength);
for (let t = 0; t < seqLength; t++) {
const mean = this.cachedMean[b][t];
const variance = this.cachedVariance[b][t];
// Recomputed from the cached variance (same value as in forward()).
const stdInv = 1.0 / Math.sqrt(variance + this.epsilon);
// Accumulate parameter gradients
for (let i = 0; i < this.featureSize; i++) {
gammaGrad[i] += outputGradients[b][t][i] * this.cachedNormalized[b][t][i];
betaGrad[i] += outputGradients[b][t][i];
}
// Compute input gradients
inputGradients[b][t] = new Array(this.featureSize);
// Gradient through normalization
// dNorm_i = dOut_i * gamma_i is the gradient w.r.t. the normalized value.
let dNormSum = 0;
let dVarSum = 0;
for (let i = 0; i < this.featureSize; i++) {
const dNorm = outputGradients[b][t][i] * this.gamma[i];
dNormSum += dNorm;
dVarSum += dNorm * (this.cachedInputs[b][t][i] - mean);
}
// dVar = sum_i dNorm_i * (x_i - mean) * (-1/2) * (variance + eps)^(-3/2)
const dVar = dVarSum * -0.5 * Math.pow(variance + this.epsilon, -1.5);
// dMean = -stdInv * sum_i dNorm_i + dVar * (-2/N) * sum_i (x_i - mean).
// The second sum is analytically zero (mean is the average of these inputs);
// in floating point it only contributes rounding error.
const dMean = -stdInv * dNormSum + dVar * -2.0 / this.featureSize *
(this.cachedInputs[b][t].reduce((a, v) => a + v - mean, 0));
// dx_i = dNorm_i * stdInv + dVar * 2 (x_i - mean) / N + dMean / N
for (let i = 0; i < this.featureSize; i++) {
const dNorm = outputGradients[b][t][i] * this.gamma[i];
inputGradients[b][t][i] = dNorm * stdInv +
dVar * 2.0 * (this.cachedInputs[b][t][i] - mean) / this.featureSize +
dMean / this.featureSize;
}
}
}
// Update parameters
for (let i = 0; i < this.featureSize; i++) {
this.gamma[i] -= learningRate * gammaGrad[i];
this.beta[i] -= learningRate * betaGrad[i];
}
return inputGradients;
}
}
Gamma and beta are shared across all positions and all sequences, so their gradients accumulate from everywhere.
Batched Attention
Attention is where batching pays off most. Each sequence in the batch attends independently; there's no cross-sequence attention. We just run the same attention computation for each batch element:
/**
 * Causal multi-head self-attention for batched input.
 *
 * Inputs and outputs are [batchSize][seqLength][embeddingDim]. The embedding is
 * split into `numHeads` contiguous slices of `headDim` features. Each head
 * attends separately, and the head outputs are concatenated before the output
 * projection. Sequences in a batch never attend to each other.
 *
 * All four projection matrices are [embeddingDim][embeddingDim] and are updated
 * in place by plain SGD inside `backward`.
 */
class MultiHeadAttention {
  numHeads = 0;
  headDim = 0;
  embeddingDim = 0;
  queryWeights = null;
  keyWeights = null;
  valueWeights = null;
  outputWeights = null;
  // Cached for backward pass
  cachedInputs = null;
  cachedQueries = null;
  cachedKeys = null;
  cachedValues = null;
  // [batch][head][t][s]; entries with s > t are 0 (causal mask).
  cachedAttentionWeights = null;
  // Concatenated head outputs = the input of the output projection.
  cachedAttentionOutputs = null;

  /**
   * @param {number} embeddingDim - model width (assumed divisible by numHeads).
   * @param {number} numHeads - number of attention heads.
   */
  constructor(embeddingDim, numHeads) {
    this.embeddingDim = embeddingDim;
    this.numHeads = numHeads;
    this.headDim = embeddingDim / numHeads;
    const scale = Math.sqrt(2.0 / embeddingDim);
    this.queryWeights = this.#initializeWeights(embeddingDim, embeddingDim, scale);
    this.keyWeights = this.#initializeWeights(embeddingDim, embeddingDim, scale);
    this.valueWeights = this.#initializeWeights(embeddingDim, embeddingDim, scale);
    this.outputWeights = this.#initializeWeights(embeddingDim, embeddingDim, scale);
  }

  // Uniform random [inputSize][outputSize] matrix in [-scale, scale).
  #initializeWeights(inputSize, outputSize, scale) {
    const weights = new Array(inputSize);
    for (let i = 0; i < inputSize; i++) {
      weights[i] = new Array(outputSize);
      for (let j = 0; j < outputSize; j++) {
        weights[i][j] = (Math.random() * 2 - 1) * scale;
      }
    }
    return weights;
  }

  /**
   * @param {number[][][]} inputs - [batchSize][seqLength][embeddingDim].
   * @returns {number[][][]} [batchSize][seqLength][embeddingDim].
   */
  forward(inputs) {
    const batchSize = inputs.length;
    const seqLength = inputs[0].length;
    this.cachedInputs = inputs;
    // Project to Q, K, V for entire batch
    const queries = this.#batchMatmul(inputs, this.queryWeights);
    const keys = this.#batchMatmul(inputs, this.keyWeights);
    const values = this.#batchMatmul(inputs, this.valueWeights);
    this.cachedQueries = queries;
    this.cachedKeys = keys;
    this.cachedValues = values;
    const scale = 1.0 / Math.sqrt(this.headDim);
    this.cachedAttentionWeights = new Array(batchSize);
    const attentionOutputs = new Array(batchSize);
    for (let b = 0; b < batchSize; b++) {
      this.cachedAttentionWeights[b] = new Array(this.numHeads);
      attentionOutputs[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        attentionOutputs[b][t] = new Array(this.embeddingDim).fill(0);
      }
      for (let h = 0; h < this.numHeads; h++) {
        const headStart = h * this.headDim;
        this.cachedAttentionWeights[b][h] = new Array(seqLength);
        for (let t = 0; t < seqLength; t++) {
          // Causal mask: position t only sees positions 0..t, so future
          // positions are never scored and keep attention weight 0.
          const scores = new Array(t + 1);
          let maxScore = -Infinity;
          for (let s = 0; s <= t; s++) {
            let score = 0;
            for (let d = 0; d < this.headDim; d++) {
              score += queries[b][t][headStart + d] * keys[b][s][headStart + d];
            }
            scores[s] = score * scale;
            // Running max for a numerically stable softmax.
            if (scores[s] > maxScore) maxScore = scores[s];
          }
          // Softmax over the visible positions.
          const attentionWeights = new Array(seqLength).fill(0);
          let sumExp = 0;
          for (let s = 0; s <= t; s++) {
            attentionWeights[s] = Math.exp(scores[s] - maxScore);
            sumExp += attentionWeights[s];
          }
          for (let s = 0; s <= t; s++) {
            attentionWeights[s] /= sumExp;
          }
          this.cachedAttentionWeights[b][h][t] = attentionWeights;
          // Weighted sum of the visible value vectors for this head.
          for (let s = 0; s <= t; s++) {
            for (let d = 0; d < this.headDim; d++) {
              attentionOutputs[b][t][headStart + d] +=
                attentionWeights[s] * values[b][s][headStart + d];
            }
          }
        }
      }
    }
    // Cached so backward() can form the output-projection weight gradient.
    this.cachedAttentionOutputs = attentionOutputs;
    return this.#batchMatmul(attentionOutputs, this.outputWeights);
  }

  // [batch][seq][inputDim] x [inputDim][outputDim] -> [batch][seq][outputDim].
  #batchMatmul(inputs, weights) {
    const batchSize = inputs.length;
    const seqLength = inputs[0].length;
    const inputDim = weights.length;
    const outputDim = weights[0].length;
    const output = new Array(batchSize);
    for (let b = 0; b < batchSize; b++) {
      output[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        output[b][t] = new Array(outputDim).fill(0);
        for (let i = 0; i < inputDim; i++) {
          for (let j = 0; j < outputDim; j++) {
            output[b][t][j] += inputs[b][t][i] * weights[i][j];
          }
        }
      }
    }
    return output;
  }

  /**
   * Backpropagate through the layer and apply SGD to all four projections.
   *
   * Must follow `forward` on the same batch (it reads the cached values).
   * @param {number[][][]} outputGradients - dLoss/dOutput, [batchSize][seqLength][embeddingDim].
   * @param {number} learningRate - SGD step size.
   * @returns {number[][][]} dLoss/dInputs, [batchSize][seqLength][embeddingDim].
   */
  backward(outputGradients, learningRate) {
    const batchSize = outputGradients.length;
    const seqLength = outputGradients[0].length;
    // Output projection. Its input was the concatenated head outputs, so its
    // weight gradient must be built from those, not from the layer input.
    const attentionGradients = this.#batchMatmulBackward(
      outputGradients, this.cachedAttentionOutputs, this.outputWeights, learningRate,
    );
    const queryGrads = new Array(batchSize);
    const keyGrads = new Array(batchSize);
    const valueGrads = new Array(batchSize);
    const scale = 1.0 / Math.sqrt(this.headDim);
    for (let b = 0; b < batchSize; b++) {
      queryGrads[b] = new Array(seqLength);
      keyGrads[b] = new Array(seqLength);
      valueGrads[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        queryGrads[b][t] = new Array(this.embeddingDim).fill(0);
        keyGrads[b][t] = new Array(this.embeddingDim).fill(0);
        valueGrads[b][t] = new Array(this.embeddingDim).fill(0);
      }
      for (let h = 0; h < this.numHeads; h++) {
        const headStart = h * this.headDim;
        for (let t = 0; t < seqLength; t++) {
          const attnWeights = this.cachedAttentionWeights[b][h][t];
          const outGrad = attentionGradients[b][t];
          // dV: value s contributed with weight attnWeights[s].
          for (let s = 0; s <= t; s++) {
            for (let d = 0; d < this.headDim; d++) {
              valueGrads[b][s][headStart + d] += attnWeights[s] * outGrad[headStart + d];
            }
          }
          // dLoss/dAttentionWeight[s] = outGrad . V[s] (this head's slice).
          const dAttnWeights = new Array(t + 1).fill(0);
          for (let s = 0; s <= t; s++) {
            for (let d = 0; d < this.headDim; d++) {
              dAttnWeights[s] += outGrad[headStart + d] * this.cachedValues[b][s][headStart + d];
            }
          }
          // Softmax Jacobian: dScore[s] = w[s] * (dw[s] - sum_k dw[k] * w[k]),
          // with the 1/sqrt(headDim) score scale folded in.
          let dotProduct = 0;
          for (let s = 0; s <= t; s++) {
            dotProduct += dAttnWeights[s] * attnWeights[s];
          }
          const dScores = new Array(t + 1);
          for (let s = 0; s <= t; s++) {
            dScores[s] = attnWeights[s] * (dAttnWeights[s] - dotProduct) * scale;
          }
          // score(t, s) = (q_t . k_s) * scale
          for (let s = 0; s <= t; s++) {
            for (let d = 0; d < this.headDim; d++) {
              queryGrads[b][t][headStart + d] +=
                dScores[s] * this.cachedKeys[b][s][headStart + d];
              keyGrads[b][s][headStart + d] +=
                dScores[s] * this.cachedQueries[b][t][headStart + d];
            }
          }
        }
      }
    }
    // Gradient through Q, K, V projections (all three read the layer input).
    const inputGradsQ = this.#batchMatmulBackward(
      queryGrads, this.cachedInputs, this.queryWeights, learningRate,
    );
    const inputGradsK = this.#batchMatmulBackward(
      keyGrads, this.cachedInputs, this.keyWeights, learningRate,
    );
    const inputGradsV = this.#batchMatmulBackward(
      valueGrads, this.cachedInputs, this.valueWeights, learningRate,
    );
    // Sum input gradients from all three projections
    const inputGradients = new Array(batchSize);
    for (let b = 0; b < batchSize; b++) {
      inputGradients[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        inputGradients[b][t] = new Array(this.embeddingDim);
        for (let d = 0; d < this.embeddingDim; d++) {
          inputGradients[b][t][d] = inputGradsQ[b][t][d] +
            inputGradsK[b][t][d] + inputGradsV[b][t][d];
        }
      }
    }
    return inputGradients;
  }

  /**
   * Backward pass of `#batchMatmul(inputs, weights)`.
   *
   * Returns dLoss/dInputs, computed with the weights as they were before this
   * call. It then applies SGD to `weights` in place, using the gradient summed
   * over batch and sequence.
   */
  #batchMatmulBackward(outputGrads, inputs, weights, learningRate) {
    const batchSize = outputGrads.length;
    const seqLength = outputGrads[0].length;
    const inputDim = weights.length;
    const outputDim = weights[0].length;
    // Gradient w.r.t. inputs
    const inputGrads = new Array(batchSize);
    for (let b = 0; b < batchSize; b++) {
      inputGrads[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        inputGrads[b][t] = new Array(inputDim).fill(0);
        for (let i = 0; i < inputDim; i++) {
          for (let j = 0; j < outputDim; j++) {
            inputGrads[b][t][i] += outputGrads[b][t][j] * weights[i][j];
          }
        }
      }
    }
    // Gradient w.r.t. weights (accumulate across batch)
    for (let i = 0; i < inputDim; i++) {
      for (let j = 0; j < outputDim; j++) {
        let grad = 0;
        for (let b = 0; b < batchSize; b++) {
          for (let t = 0; t < seqLength; t++) {
            grad += inputs[b][t][i] * outputGrads[b][t][j];
          }
        }
        weights[i][j] -= learningRate * grad;
      }
    }
    return inputGrads;
  }
}
The attention mechanism is fundamentally the same. Each position attends to itself and all earlier positions (causal masking). We just do it independently for each sequence in the batch.
Batched MLP
The MLP block applies the same feed-forward network to each position independently:
/**
 * Position-wise feed-forward network: Linear(d -> 4d) -> GELU -> Linear(4d -> d).
 *
 * Applied independently to every [batch][position] vector. GELU uses the tanh
 * approximation, and backward uses that same function's exact derivative.
 * Weights and biases are updated in place by plain SGD inside `backward`.
 */
class MLPBlock {
  weights1 = null;
  bias1 = null;
  weights2 = null;
  bias2 = null;
  inputDim = 0;
  hiddenDim = 0;
  cachedInputs = null;
  // GELU outputs, [batch][seq][hiddenDim].
  cachedHidden = null;
  // GELU inputs, [batch][seq][hiddenDim]; cached so backward need not recompute them.
  cachedPreActivations = null;

  // Constants of the tanh approximation of GELU.
  static #SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
  static #GELU_CUBIC = 0.044715;

  constructor(embeddingDim) {
    this.inputDim = embeddingDim;
    this.hiddenDim = embeddingDim * 4;
    const scale1 = Math.sqrt(2.0 / embeddingDim);
    const scale2 = Math.sqrt(2.0 / this.hiddenDim);
    this.weights1 = this.#initWeights(embeddingDim, this.hiddenDim, scale1);
    this.bias1 = new Array(this.hiddenDim).fill(0);
    this.weights2 = this.#initWeights(this.hiddenDim, embeddingDim, scale2);
    this.bias2 = new Array(embeddingDim).fill(0);
  }

  // Uniform random [inputSize][outputSize] matrix in [-scale, scale).
  #initWeights(inputSize, outputSize, scale) {
    const weights = new Array(inputSize);
    for (let i = 0; i < inputSize; i++) {
      weights[i] = new Array(outputSize);
      for (let j = 0; j < outputSize; j++) {
        weights[i][j] = (Math.random() * 2 - 1) * scale;
      }
    }
    return weights;
  }

  // GELU, tanh approximation: 0.5 x (1 + tanh(c (x + a x^3))).
  static #gelu(x) {
    return x * 0.5 * (1 + Math.tanh(
      MLPBlock.#SQRT_2_OVER_PI * (x + MLPBlock.#GELU_CUBIC * x * x * x)
    ));
  }

  // Exact derivative of #gelu:
  //   0.5 (1 + tanh u) + 0.5 x (1 - tanh^2 u) du/dx,
  //   with u = c (x + a x^3) and du/dx = c (1 + 3 a x^2).
  static #geluDerivative(x) {
    const c = MLPBlock.#SQRT_2_OVER_PI;
    const a = MLPBlock.#GELU_CUBIC;
    const tanhU = Math.tanh(c * (x + a * x * x * x));
    return 0.5 * (1 + tanhU) + 0.5 * x * (1 - tanhU * tanhU) * c * (1 + 3 * a * x * x);
  }

  /**
   * @param {number[][][]} inputs - [batchSize][seqLength][embeddingDim].
   * @returns {number[][][]} [batchSize][seqLength][embeddingDim].
   */
  forward(inputs) {
    const batchSize = inputs.length;
    const seqLength = inputs[0].length;
    this.cachedInputs = inputs;
    this.cachedPreActivations = new Array(batchSize);
    this.cachedHidden = new Array(batchSize);
    const output = new Array(batchSize);
    for (let b = 0; b < batchSize; b++) {
      this.cachedPreActivations[b] = new Array(seqLength);
      this.cachedHidden[b] = new Array(seqLength);
      output[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        // First linear layer + GELU
        const preActivation = new Array(this.hiddenDim);
        const hidden = new Array(this.hiddenDim);
        for (let h = 0; h < this.hiddenDim; h++) {
          let sum = this.bias1[h];
          for (let i = 0; i < this.inputDim; i++) {
            sum += inputs[b][t][i] * this.weights1[i][h];
          }
          preActivation[h] = sum;
          hidden[h] = MLPBlock.#gelu(sum);
        }
        this.cachedPreActivations[b][t] = preActivation;
        this.cachedHidden[b][t] = hidden;
        // Second linear layer
        output[b][t] = new Array(this.inputDim);
        for (let o = 0; o < this.inputDim; o++) {
          let sum = this.bias2[o];
          for (let h = 0; h < this.hiddenDim; h++) {
            sum += hidden[h] * this.weights2[h][o];
          }
          output[b][t][o] = sum;
        }
      }
    }
    return output;
  }

  /**
   * Backpropagate and apply SGD to both linear layers.
   *
   * Must follow `forward` on the same batch. Input gradients use the current
   * weights; parameter gradients are summed over batch and sequence and
   * applied once at the end.
   * @param {number[][][]} outputGradients - dLoss/dOutput, [batchSize][seqLength][embeddingDim].
   * @param {number} learningRate - SGD step size.
   * @returns {number[][][]} dLoss/dInputs, same shape.
   */
  backward(outputGradients, learningRate) {
    const batchSize = outputGradients.length;
    const seqLength = outputGradients[0].length;
    const inputGradients = new Array(batchSize);
    const weights2Grad = new Array(this.hiddenDim);
    for (let h = 0; h < this.hiddenDim; h++) {
      weights2Grad[h] = new Array(this.inputDim).fill(0);
    }
    const bias2Grad = new Array(this.inputDim).fill(0);
    const weights1Grad = new Array(this.inputDim);
    for (let i = 0; i < this.inputDim; i++) {
      weights1Grad[i] = new Array(this.hiddenDim).fill(0);
    }
    const bias1Grad = new Array(this.hiddenDim).fill(0);
    for (let b = 0; b < batchSize; b++) {
      inputGradients[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        const outGrad = outputGradients[b][t];
        const hidden = this.cachedHidden[b][t];
        const preActivation = this.cachedPreActivations[b][t];
        const input = this.cachedInputs[b][t];
        // Gradient through second linear layer
        const hiddenGrad = new Array(this.hiddenDim).fill(0);
        for (let h = 0; h < this.hiddenDim; h++) {
          for (let o = 0; o < this.inputDim; o++) {
            hiddenGrad[h] += outGrad[o] * this.weights2[h][o];
            weights2Grad[h][o] += hidden[h] * outGrad[o];
          }
        }
        for (let o = 0; o < this.inputDim; o++) {
          bias2Grad[o] += outGrad[o];
        }
        // Gradient through GELU (derivative of the same approximation used in forward)
        const geluGrad = new Array(this.hiddenDim);
        for (let h = 0; h < this.hiddenDim; h++) {
          geluGrad[h] = hiddenGrad[h] * MLPBlock.#geluDerivative(preActivation[h]);
        }
        // Gradient through first linear layer
        const inGrad = new Array(this.inputDim).fill(0);
        for (let i = 0; i < this.inputDim; i++) {
          for (let h = 0; h < this.hiddenDim; h++) {
            inGrad[i] += geluGrad[h] * this.weights1[i][h];
            weights1Grad[i][h] += input[i] * geluGrad[h];
          }
        }
        for (let h = 0; h < this.hiddenDim; h++) {
          bias1Grad[h] += geluGrad[h];
        }
        inputGradients[b][t] = inGrad;
      }
    }
    // Update weights
    for (let h = 0; h < this.hiddenDim; h++) {
      for (let o = 0; o < this.inputDim; o++) {
        this.weights2[h][o] -= learningRate * weights2Grad[h][o];
      }
    }
    for (let o = 0; o < this.inputDim; o++) {
      this.bias2[o] -= learningRate * bias2Grad[o];
    }
    for (let i = 0; i < this.inputDim; i++) {
      for (let h = 0; h < this.hiddenDim; h++) {
        this.weights1[i][h] -= learningRate * weights1Grad[i][h];
      }
    }
    for (let h = 0; h < this.hiddenDim; h++) {
      this.bias1[h] -= learningRate * bias1Grad[h];
    }
    return inputGradients;
  }
}
Batched Transformer Block
The transformer block combines attention, MLP, layer norms, and residual connections:
/**
 * One pre-norm transformer block:
 *   residual1 = x + attention(layerNorm1(x))
 *   output    = residual1 + mlp(layerNorm2(residual1))
 * Inputs and outputs are [batchSize][seqLength][embeddingDim].
 */
class TransformerBlock {
  attention = null;
  mlp = null;
  layerNorm1 = null;
  layerNorm2 = null;
  embeddingDim = 0;

  constructor(embeddingDim, numHeads) {
    this.embeddingDim = embeddingDim;
    // Sub-layers are created in this fixed order (each draws its own random weights).
    this.attention = new MultiHeadAttention(embeddingDim, numHeads);
    this.mlp = new MLPBlock(embeddingDim);
    this.layerNorm1 = new LayerNormalization(embeddingDim);
    this.layerNorm2 = new LayerNormalization(embeddingDim);
  }

  // Element-wise left + right over batchSize x seqLength x embeddingDim entries.
  #addTensors(left, right, batchSize, seqLength) {
    return Array.from({ length: batchSize }, (_, b) =>
      Array.from({ length: seqLength }, (_, t) =>
        Array.from({ length: this.embeddingDim }, (_, d) => left[b][t][d] + right[b][t][d]),
      ),
    );
  }

  /**
   * @param {number[][][]} inputs - [batchSize][seqLength][embeddingDim].
   * @returns {number[][][]} same shape.
   */
  forward(inputs) {
    const [batchSize, seqLength] = [inputs.length, inputs[0].length];
    // Attention sub-layer with residual connection.
    const attended = this.attention.forward(this.layerNorm1.forward(inputs));
    const residual1 = this.#addTensors(inputs, attended, batchSize, seqLength);
    // MLP sub-layer with residual connection.
    const mlpOutput = this.mlp.forward(this.layerNorm2.forward(residual1));
    return this.#addTensors(residual1, mlpOutput, batchSize, seqLength);
  }

  /**
   * Backpropagate in reverse order (MLP branch, then attention branch); every
   * sub-layer applies its own SGD update.
   *
   * Each residual connection passes its incoming gradient through unchanged,
   * so it is added to the gradient coming back through the branch.
   * @returns {number[][][]} dLoss/dInputs, [batchSize][seqLength][embeddingDim].
   */
  backward(outputGradients, learningRate) {
    const batchSize = outputGradients.length;
    const seqLength = outputGradients[0].length;
    const throughNorm2 = this.layerNorm2.backward(
      this.mlp.backward(outputGradients, learningRate),
      learningRate,
    );
    const residual1Grad = this.#addTensors(outputGradients, throughNorm2, batchSize, seqLength);
    const throughNorm1 = this.layerNorm1.backward(
      this.attention.backward(residual1Grad, learningRate),
      learningRate,
    );
    return this.#addTensors(residual1Grad, throughNorm1, batchSize, seqLength);
  }
}
Batched Output Layer
The output layer projects to vocabulary size and computes softmax:
/**
 * Final vocabulary projection + softmax for batched input.
 *
 * `forward` returns probabilities (not logits). `backward` takes the target
 * token ids and uses the combined softmax + cross-entropy gradient
 * (probs - oneHot(target)). That gradient is summed over every batch element
 * and position; it is not divided by batchSize * seqLength, even though
 * `computeLoss` reports the mean. The effective step size therefore grows
 * with the number of positions per batch.
 */
class OutputLayer {
  weights = null;
  bias = null;
  inputDim = 0;
  vocabSize = 0;
  cachedInputs = null;
  cachedProbs = null;

  constructor(embeddingDim, vocabSize) {
    this.inputDim = embeddingDim;
    this.vocabSize = vocabSize;
    const scale = Math.sqrt(2.0 / embeddingDim);
    this.weights = new Array(embeddingDim);
    for (let i = 0; i < embeddingDim; i++) {
      this.weights[i] = new Array(vocabSize);
      for (let j = 0; j < vocabSize; j++) {
        this.weights[i][j] = (Math.random() * 2 - 1) * scale;
      }
    }
    this.bias = new Array(vocabSize).fill(0);
  }

  /**
   * @param {number[][][]} inputs - [batchSize][seqLength][embeddingDim].
   * @returns {number[][][]} probabilities, [batchSize][seqLength][vocabSize]
   *   (the same arrays are cached for backward).
   */
  forward(inputs) {
    const batchSize = inputs.length;
    const seqLength = inputs[0].length;
    this.cachedInputs = inputs;
    this.cachedProbs = new Array(batchSize);
    const output = new Array(batchSize);
    for (let b = 0; b < batchSize; b++) {
      output[b] = new Array(seqLength);
      this.cachedProbs[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        // Linear projection
        const logits = new Array(this.vocabSize);
        for (let v = 0; v < this.vocabSize; v++) {
          let sum = this.bias[v];
          for (let i = 0; i < this.inputDim; i++) {
            sum += inputs[b][t][i] * this.weights[i][v];
          }
          logits[v] = sum;
        }
        // Softmax, shifted by the max logit for numerical stability. The max
        // is found with a loop: spreading a vocabulary-sized array into
        // Math.max(...) can exceed the engine's argument limit (RangeError).
        let maxLogit = -Infinity;
        for (const logit of logits) {
          if (logit > maxLogit) maxLogit = logit;
        }
        let sumExp = 0;
        const probs = new Array(this.vocabSize);
        for (let v = 0; v < this.vocabSize; v++) {
          probs[v] = Math.exp(logits[v] - maxLogit);
          sumExp += probs[v];
        }
        for (let v = 0; v < this.vocabSize; v++) {
          probs[v] /= sumExp;
        }
        output[b][t] = probs;
        this.cachedProbs[b][t] = probs;
      }
    }
    return output;
  }

  /**
   * Backpropagate softmax + cross-entropy and apply SGD to weights and bias.
   *
   * Must follow `forward` on the same batch.
   * @param {number[][]} targetTokens - [batchSize][seqLength] target ids.
   * @param {number} learningRate - SGD step size.
   * @returns {number[][][]} dLoss/dInputs, [batchSize][seqLength][embeddingDim],
   *   for the summed (not averaged) cross-entropy.
   */
  backward(targetTokens, learningRate) {
    const batchSize = targetTokens.length;
    const seqLength = targetTokens[0].length;
    const inputGradients = new Array(batchSize);
    const weightsGrad = new Array(this.inputDim);
    for (let i = 0; i < this.inputDim; i++) {
      weightsGrad[i] = new Array(this.vocabSize).fill(0);
    }
    const biasGrad = new Array(this.vocabSize).fill(0);
    for (let b = 0; b < batchSize; b++) {
      inputGradients[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        // dLoss/dLogits for softmax + cross-entropy: probs - oneHot(target)
        const grad = this.cachedProbs[b][t].slice();
        grad[targetTokens[b][t]] -= 1;
        // Gradient w.r.t. inputs (pre-update weights) and accumulated weight gradient
        inputGradients[b][t] = new Array(this.inputDim).fill(0);
        for (let i = 0; i < this.inputDim; i++) {
          for (let v = 0; v < this.vocabSize; v++) {
            inputGradients[b][t][i] += grad[v] * this.weights[i][v];
            weightsGrad[i][v] += this.cachedInputs[b][t][i] * grad[v];
          }
        }
        for (let v = 0; v < this.vocabSize; v++) {
          biasGrad[v] += grad[v];
        }
      }
    }
    // Update weights
    for (let i = 0; i < this.inputDim; i++) {
      for (let v = 0; v < this.vocabSize; v++) {
        this.weights[i][v] -= learningRate * weightsGrad[i][v];
      }
    }
    for (let v = 0; v < this.vocabSize; v++) {
      this.bias[v] -= learningRate * biasGrad[v];
    }
    return inputGradients;
  }

  /**
   * Mean cross-entropy over batch and sequence.
   * @param {number[][][]} probs - [batchSize][seqLength][vocabSize] probabilities.
   * @param {number[][]} targetTokens - [batchSize][seqLength] target ids.
   * @returns {number} average of -log(p_target); 1e-10 guards against log(0).
   */
  computeLoss(probs, targetTokens) {
    const batchSize = probs.length;
    const seqLength = probs[0].length;
    let totalLoss = 0;
    for (let b = 0; b < batchSize; b++) {
      for (let t = 0; t < seqLength; t++) {
        const targetProb = probs[b][t][targetTokens[b][t]];
        totalLoss += -Math.log(targetProb + 1e-10);
      }
    }
    return totalLoss / (batchSize * seqLength);
  }
}
The Complete Batched GPT
Now we wire it all together:
/**
 * Batched GPT: positional embeddings -> N transformer blocks -> final
 * LayerNorm -> vocabulary softmax. Training updates every layer in place with
 * plain SGD.
 */
class GabGPT {
  embedding = null;
  blocks = null;
  output = null;
  finalNorm = null;
  vocabSize = 0;
  // Longest context the positional embedding table covers.
  maxSeqLength = 0;

  /**
   * @param {number} vocabSize - number of token ids.
   * @param {number} embeddingDim - model width.
   * @param {number} numHeads - attention heads per block.
   * @param {number} numBlocks - number of transformer blocks.
   * @param {number} maxSeqLength - size of the positional embedding table.
   */
  constructor(vocabSize, embeddingDim, numHeads, numBlocks, maxSeqLength) {
    this.vocabSize = vocabSize;
    this.maxSeqLength = maxSeqLength;
    this.embedding = new PositionalEmbeddingLayer(vocabSize, embeddingDim, maxSeqLength);
    this.blocks = new Array(numBlocks);
    for (let i = 0; i < numBlocks; i++) {
      this.blocks[i] = new TransformerBlock(embeddingDim, numHeads);
    }
    this.finalNorm = new LayerNormalization(embeddingDim);
    this.output = new OutputLayer(embeddingDim, vocabSize);
  }

  /**
   * @param {number[][]} inputTokens - [batchSize][seqLength] token ids.
   * @returns {number[][][]} next-token probabilities, [batchSize][seqLength][vocabSize].
   */
  forward(inputTokens) {
    let hidden = this.embedding.forward(inputTokens);
    for (let i = 0; i < this.blocks.length; i++) {
      hidden = this.blocks[i].forward(hidden);
    }
    hidden = this.finalNorm.forward(hidden);
    return this.output.forward(hidden);
  }

  // Backpropagate from the output layer down to the embeddings (layers update themselves).
  backward(targetTokens, learningRate) {
    let gradients = this.output.backward(targetTokens, learningRate);
    gradients = this.finalNorm.backward(gradients, learningRate);
    for (let i = this.blocks.length - 1; i >= 0; i--) {
      gradients = this.blocks[i].backward(gradients, learningRate);
    }
    this.embedding.backward(gradients, learningRate);
  }

  /**
   * One SGD step on a batch.
   * @returns {number} mean cross-entropy of the batch before the update.
   */
  train(inputTokens, targetTokens, learningRate) {
    const probs = this.forward(inputTokens);
    const loss = this.output.computeLoss(probs, targetTokens);
    this.backward(targetTokens, learningRate);
    return loss;
  }

  /**
   * Autoregressively sample `maxLength` new tokens after `promptTokens`.
   *
   * Only the most recent `maxSeqLength` tokens are fed to the model, so
   * generation can continue past the positional embedding table instead of
   * indexing beyond it.
   * @param {number[]} promptTokens - starting tokens (not mutated).
   * @param {number} maxLength - number of tokens to generate.
   * @returns {number[]} prompt followed by the generated tokens.
   * @throws {RangeError} if tokens are requested from an empty prompt.
   */
  generate(promptTokens, maxLength) {
    if (maxLength > 0 && promptTokens.length === 0) {
      throw new RangeError("generate() needs at least one prompt token");
    }
    const tokens = promptTokens.slice();
    for (let i = 0; i < maxLength; i++) {
      // Sliding context window; slice(-n) returns everything when shorter than n.
      const context = tokens.slice(-this.maxSeqLength);
      const probs = this.forward([context]); // batch of size 1
      const lastProbs = probs[0][context.length - 1];
      tokens.push(this.#sampleFromDistribution(lastProbs));
    }
    return tokens;
  }

  // Inverse-CDF sampling; falls back to the last index if rounding leaves the
  // cumulative sum just below the random draw.
  #sampleFromDistribution(probs) {
    const random = Math.random();
    let cumulative = 0;
    for (let i = 0; i < probs.length; i++) {
      cumulative += probs[i];
      if (random < cumulative) {
        return i;
      }
    }
    return probs.length - 1;
  }
}
The Batched Training Loop
The training loop now groups sequences into batches:
// Create batches from training data
/**
 * Split a token stream into next-token-prediction batches.
 *
 * Each sequence is a window of `seqLength` input tokens; its targets are the
 * same window shifted right by one token. Windows start every `stride` tokens
 * (default `seqLength`: non-overlapping windows). The last batch may hold
 * fewer than `batchSize` sequences.
 *
 * @param {ArrayLike<number>} tokens - token ids (must support `.slice`).
 * @param {number} batchSize - max sequences per batch (positive integer).
 * @param {number} seqLength - tokens per sequence (positive integer).
 * @param {number} [stride=seqLength] - distance between window starts (positive integer).
 * @returns {{inputs: number[][], targets: number[][]}[]} the batches, in order.
 * @throws {RangeError} if batchSize, seqLength or stride is not a positive integer
 *   (a zero batchSize or stride would otherwise loop forever).
 */
function createBatches(tokens, batchSize, seqLength, stride = seqLength) {
  for (const [name, value] of Object.entries({ batchSize, seqLength, stride })) {
    if (!Number.isInteger(value) || value <= 0) {
      throw new RangeError(`${name} must be a positive integer, got ${value}`);
    }
  }
  // The targets of a window starting at `start` end at index start + seqLength.
  const lastStart = tokens.length - seqLength - 1;
  const windowStarts = [];
  for (let start = 0; start <= lastStart; start += stride) {
    windowStarts.push(start);
  }
  const batches = [];
  for (let i = 0; i < windowStarts.length; i += batchSize) {
    const chunk = windowStarts.slice(i, i + batchSize);
    batches.push({
      inputs: chunk.map((start) => tokens.slice(start, start + seqLength)),
      targets: chunk.map((start) => tokens.slice(start + 1, start + seqLength + 1)),
    });
  }
  return batches;
}
// Training loop
const batchSize = 4;
const seqLength = 32;
const batches = createBatches(tokens, batchSize, seqLength);
// With fewer than seqLength + 1 tokens there is nothing to train on, and the
// per-epoch average below would print NaN (0 / 0). Fail loudly instead.
if (batches.length === 0) {
  throw new Error(`Need at least ${seqLength + 1} tokens to build a batch, got ${tokens.length}`);
}
for (let epoch = 0; epoch < epochs; epoch++) {
  let totalLoss = 0;
  for (const batch of batches) {
    // train() returns the batch's mean cross-entropy.
    totalLoss += model.train(batch.inputs, batch.targets, learningRate);
  }
  // Average of the per-batch losses for this epoch.
  console.log(`Epoch ${epoch}: Loss = ${(totalLoss / batches.length).toFixed(4)}`);
}
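Batching gives you two benefits: faster training (less overhead per sequence) and smoother gradients (noise averages out across examples). One caveat: these layers sum gradients over every position in the batch rather than averaging them (the output layer's backward pass never divides by batchSize × seqLength, even though the reported loss is a mean). The size of each update therefore grows with batch size and sequence length, so if you increase either, you may need to lower the learning rate.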
Pre-Training
You've already been doing pre-training: the Part 4 training loop, which predicts the next token in raw text, is pre-training. Pre-training is self-supervised learning. There are no human labels. The "label" for each position is simply the next token in the sequence. The model learns by trying to predict what comes next, over and over, on massive amounts of text.
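Input:  "The cat sat on the"
Target: "cat sat on the mat"
Each target token is predicted from the input tokens up to and including its position: "The" → "cat", "The cat" → "sat", "The cat sat" → "on", "The cat sat on" → "the", and "The cat sat on the" → "mat".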
What does the model learn from this? Everything it needs to predict well:
- Grammar and syntax ("the" is often followed by a noun)
- Facts ("Paris is the capital of...")
- Reasoning patterns (if A then B, therefore...)
- Style and tone (formal text follows certain patterns)
The model doesn't "know" it's learning these things. It's just minimizing prediction error. But to predict the next token in "The capital of France is", you'd better know geography.
Pre-training is where most of the compute goes. It creates the foundation—a model that understands language. Everything after is refinement.
Supervised Fine-Tuning
Pre-trained models have a problem: they complete text, but they don't converse. Ask "What is 2+2?" and you might get "What is 2+3? What is 2+4?" because the model is pattern-matching against math worksheets it saw during pre-training. Supervised Fine-Tuning (SFT) teaches the model to follow a conversation format. It's the same training algorithm, just with different data.
Adding Conversation Tokens
First, we need special tokens to mark conversation structure:
// Add to your tokenizer.
// Role markers and the end-of-turn marker, plus <|pad|> for batching
// variable-length conversations.
for (const specialToken of ["<|user|>", "<|assistant|>", "<|end|>", "<|pad|>"]) {
  tokenizer.reserveToken(specialToken);
}
The role and end tokens will appear in every training example, so the model will learn their meaning from context (the pad token only fills out batches). SFT data is just conversations formatted as text:
// Toy SFT dataset. Each object is one single-turn exchange: the user's prompt and
// the assistant reply the model should learn to produce.
const conversations = [
{
user: "What is the capital of France?",
assistant: "The capital of France is Paris."
},
{
user: "Who wrote Romeo and Juliet?",
assistant: "William Shakespeare wrote Romeo and Juliet."
},
{
user: "What is 2 + 2?",
assistant: "2 + 2 equals 4."
}
];
/**
 * Serialize one {user, assistant} exchange into the chat template used for SFT.
 *
 * @param {{user: string, assistant: string}} conv - a single-turn exchange
 * @returns {string} "<|user|>{user}<|end|><|assistant|>{assistant}<|end|>"
 */
function formatConversation(conv) {
  const { user, assistant } = conv;
  return `<|user|>${user}<|end|><|assistant|>${assistant}<|end|>`;
}
// Convert all conversations to training data: one token-id array per conversation,
// in the same order as `conversations`.
const sftData = conversations.map((conv) => tokenizer.encode(formatConversation(conv)));
When formatted, a conversation looks like:
<|user|>What is 2 + 2?<|end|><|assistant|>2 + 2 equals 4.<|end|>
The model learns: after <|assistant|>, generate a helpful response. After <|end|>, stop.
Padding
There's a problem with batching SFT data. Each conversation has a different length:
Conversation 1: "<|user|>Hi<|end|><|assistant|>Hello!<|end|>" -> 12 tokens
Conversation 2: "<|user|>What is the capital of France?<|end|>..." -> 28 tokens
Conversation 3: "<|user|>Why?<|end|><|assistant|>Because.<|end|>" -> 10 tokens
Our batched code assumes rectangular arrays—every sequence in a batch has the same length. How do we handle this?
Padding is optional, but we will implement it anyway. JavaScript arrays can be jagged. Unlike PyTorch tensors, there's no requirement that batch[0].length === batch[1].length. We could rewrite every layer to handle variable-length sequences:
// Hypothetically, without padding:
// Sketch of the loop shape every layer would need. The inner bound comes from each
// sequence's own length instead of one shared seqLength.
for (let b = 0; b < batchSize; b++) {
const seqLength = inputs[b].length; // Each sequence has its own length
for (let t = 0; t < seqLength; t++) {
// ...
}
}
This would actually work. Each sequence would process independently with its own length.
So why bother with padding? Three reasons:
- It's what real frameworks do. PyTorch and TensorFlow require rectangular tensors. Learning padding prepares you for production tools.
- GPU parallelism. Real training runs on GPUs where rectangular batches enable efficient parallel computation. Jagged arrays can't be parallelized the same way.
- Simpler code. Handling variable lengths means checking `inputs[b].length` everywhere, which makes it easy to introduce bugs.
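We'll implement padding to match industry practice. The padding token must be in the vocabulary. We already reserved `<|pad|>` alongside the conversation tokens, so only add it here if you skipped that step: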
// NOTE(review): <|pad|> is already reserved with the conversation tokens above. Check
// whether reserveToken tolerates a duplicate before running both snippets.
tokenizer.reserveToken("<|pad|>");
Padding a Batch
This function takes variable-length sequences and pads them to match the longest:
/**
 * Right-pad variable-length token sequences so they form a rectangular batch.
 *
 * The input arrays are not mutated; fresh padded copies are returned.
 *
 * @param {number[][]} sequences - token-id arrays, possibly of different lengths
 * @param {number} padTokenId - id written into every padded slot
 * @returns {{sequences: number[][], mask: number[][]}} the padded sequences (all as
 *   long as the longest input) and a same-shaped mask: 1 = real token, 0 = padding
 */
function padBatch(sequences, padTokenId) {
  // Target width: length of the longest sequence (0 for an empty batch).
  const maxLen = sequences.reduce(
    (longest, seq) => (seq.length > longest ? seq.length : longest),
    0,
  );

  const padded = [];
  const mask = [];
  for (const seq of sequences) {
    const row = new Array(maxLen);
    const rowMask = new Array(maxLen);
    for (let pos = 0; pos < maxLen; pos++) {
      const isReal = pos < seq.length;
      row[pos] = isReal ? seq[pos] : padTokenId;
      rowMask[pos] = isReal ? 1 : 0;
    }
    padded.push(row);
    mask.push(rowMask);
  }
  return { sequences: padded, mask };
}
The function returns both the padded sequences and a mask indicating which positions are real (1) vs padding (0).
Attention with Padding Mask
Padding tokens shouldn't participate in attention. A real token shouldn't attend to padding, and padding shouldn't attend to anything. We need to modify the attention layer to accept a padding mask and combine it with the causal mask:
// In MultiHeadAttention class
// Causal self-attention with an optional padding mask. Only the masking logic is
// shown; the projection, softmax and output-projection code is elided ("...").
forward(inputs, paddingMask = null) {
// inputs: [batchSize][seqLength][embeddingDim]
// paddingMask: [batchSize][seqLength] - 1 for real, 0 for padding
const batchSize = inputs.length;
const seqLength = inputs[0].length;
// ... Q, K, V projection code stays the same ...
for (let b = 0; b < batchSize; b++) {
// ... head loop setup ...
// NOTE(review): queries, keys, headStart, scale and attentionOutputs come from the
// elided code. headStart is presumably this head's offset into the embedding.
for (let t = 0; t < seqLength; t++) {
// Skip if this position is padding
if (paddingMask && paddingMask[b][t] === 0) {
// Output zeros for padding positions.
// This clears the whole embeddingDim row, not only this head's slice. Every head
// writes the same zeros for a padded position, so the repetition is harmless.
for (let d = 0; d < this.embeddingDim; d++) {
attentionOutputs[b][t][d] = 0;
}
continue;
}
const scores = new Array(seqLength);
for (let s = 0; s < seqLength; s++) {
// Can't attend to future (causal)
if (s > t) {
scores[s] = -Infinity;
}
// Can't attend to padding
else if (paddingMask && paddingMask[b][s] === 0) {
scores[s] = -Infinity;
}
else {
// Compute attention score normally
let score = 0;
for (let d = 0; d < this.headDim; d++) {
score += queries[b][t][headStart + d] * keys[b][s][headStart + d];
}
scores[s] = score * scale;
}
}
// With right padding (padBatch appends pads at the end), a real query t always
// keeps s = t unmasked, so softmax never receives an all -Infinity row.
// ... softmax and value aggregation stay the same ...
}
}
// ... output projection ...
}
The changes:
- If the current position is padding, output zeros and skip
- When computing attention scores, set padding positions to `-Infinity` so they get zero weight after softmax
Padding Mask Through the Network
The padding mask needs to flow through the entire network. Update the transformer block and model:
// TransformerBlock
// Pre-norm block. The padding mask is handed only to attention; layerNorm1.forward
// takes just the inputs. The elided residual/MLP code is unchanged.
forward(inputs, paddingMask = null) {
const normed1 = this.layerNorm1.forward(inputs);
const attended = this.attention.forward(normed1, paddingMask); // Pass mask
// ... residual, MLP, etc ...
}
// GabGPT
forward(inputTokens, paddingMask = null) {
let hidden = this.embedding.forward(inputTokens);
for (let i = 0; i < this.blocks.length; i++) {
hidden = this.blocks[i].forward(hidden, paddingMask); // Pass mask
}
hidden = this.finalNorm.forward(hidden);
return this.output.forward(hidden);
}
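What about layer normalization? Layer norm computes its mean and variance separately for each position, across that position's embedding dimensions, so padding positions never enter the statistics of real tokens. Padding positions still get normalized and carry meaningless values. However, attention never reads them (their scores are `-Infinity`) and the loss mask ignores them, so they can't affect training. A normalization that averages across positions would need the mask; layer norm doesn't.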
Loss Masking
There's one more refinement for chat: we only care that the model generates good assistant responses. The user's text is given at inference time, so we don't need to train the model to predict it. Loss masking zeros out the loss on user tokens (and padding), focusing all learning on the assistant's output:
/**
 * Build a per-token SFT loss mask: 1 where the model should be trained to predict
 * the token, 0 everywhere else.
 *
 * An assistant turn runs from <|assistant|> up to and including the next <|end|>.
 * Only the reply body and its closing <|end|> get 1, so the model learns to answer
 * and then stop. The <|assistant|> marker itself gets 0: at inference time chat()
 * puts it in the prompt, so the model never has to generate it. This matches the
 * worked "2 + 2" example. User text and padding are always 0.
 *
 * @param {number[]} tokens - token ids (createSFTBatches passes the shifted target sequence)
 * @param {{encode(text: string): number[]}} tokenizer - used to look up the special-token ids
 *   (the first id of each encoding is taken)
 * @returns {number[]} mask with the same length as tokens
 */
function createLossMask(tokens, tokenizer) {
  // Get special token IDs
  const userTokenId = tokenizer.encode("<|user|>")[0];
  const assistantTokenId = tokenizer.encode("<|assistant|>")[0];
  const endTokenId = tokenizer.encode("<|end|>")[0];
  const padTokenId = tokenizer.encode("<|pad|>")[0];

  const mask = new Array(tokens.length).fill(0);
  let inAssistantTurn = false;
  for (let i = 0; i < tokens.length; i++) {
    const token = tokens[i];
    if (token === padTokenId) {
      // Padding tokens are never included in loss (and don't end a turn).
      continue;
    }
    if (token === assistantTokenId) {
      // A reply starts after this marker; the marker is prompt, not something to learn.
      inAssistantTurn = true;
      continue;
    }
    if (token === userTokenId) {
      inAssistantTurn = false;
      continue;
    }
    if (inAssistantTurn) {
      mask[i] = 1;
    }
    if (token === endTokenId) {
      // <|end|> closes the reply: it was trained above (if in a reply), then the turn ends.
      inAssistantTurn = false;
    }
  }
  return mask;
}
For the example "What is 2 + 2?":
Tokens: <|user|> What is 2 + 2 ? <|end|> <|assistant|> 2 + 2 equals 4 . <|end|>
Mask: 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1
Only positions with mask=1 contribute to the loss.
Training with Loss Masking
Modify the output layer to accept a mask:
// In OutputLayer class
computeLossWithMask(probs, targetTokens, mask) {
const batchSize = probs.length;
const seqLength = probs[0].length;
let totalLoss = 0;
let count = 0;
for (let b = 0; b < batchSize; b++) {
for (let t = 0; t < seqLength; t++) {
if (mask[b][t] === 1) {
const targetProb = probs[b][t][targetTokens[b][t]];
totalLoss += -Math.log(targetProb + 1e-10);
count++;
}
}
}
return count > 0 ? totalLoss / count : 0;
}
backwardWithMask(targetTokens, mask, learningRate) {
const batchSize = targetTokens.length;
const seqLength = targetTokens[0].length;
const inputGradients = new Array(batchSize);
// Weight gradients (accumulated)
const weightsGrad = new Array(this.inputDim);
for (let i = 0; i < this.inputDim; i++) {
weightsGrad[i] = new Array(this.vocabSize).fill(0);
}
const biasGrad = new Array(this.vocabSize).fill(0);
for (let b = 0; b < batchSize; b++) {
inputGradients[b] = new Array(seqLength);
for (let t = 0; t < seqLength; t++) {
if (mask[b][t] === 0) {
// No gradient for masked positions
inputGradients[b][t] = new Array(this.inputDim).fill(0);
continue;
}
// Same gradient computation as before
const grad = new Array(this.vocabSize);
for (let v = 0; v < this.vocabSize; v++) {
grad[v] = this.cachedProbs[b][t][v];
}
grad[targetTokens[b][t]] -= 1;
inputGradients[b][t] = new Array(this.inputDim).fill(0);
for (let i = 0; i < this.inputDim; i++) {
for (let v = 0; v < this.vocabSize; v++) {
inputGradients[b][t][i] += grad[v] * this.weights[i][v];
weightsGrad[i][v] += this.cachedInputs[b][t][i] * grad[v];
}
}
for (let v = 0; v < this.vocabSize; v++) {
biasGrad[v] += grad[v];
}
}
}
// Update weights
for (let i = 0; i < this.inputDim; i++) {
for (let v = 0; v < this.vocabSize; v++) {
this.weights[i][v] -= learningRate * weightsGrad[i][v];
}
}
for (let v = 0; v < this.vocabSize; v++) {
this.bias[v] -= learningRate * biasGrad[v];
}
return inputGradients;
}
The Complete SFT Training Loop
Now we combine padding and loss masking:
// Get pad token ID
const padTokenId = tokenizer.encode("<|pad|>")[0];

/**
 * Group tokenized conversations into padded SFT batches.
 *
 * Each conversation becomes an input/target pair shifted by one token (input = all
 * but the last token, target = all but the first). Both are right-padded to the
 * longest sequence in the batch.
 *
 * @param {number[][]} sftData - one token-id array per conversation
 * @param {number} batchSize - conversations per batch (the last batch may be smaller)
 * @param tokenizer - forwarded to createLossMask to look up special-token ids
 * @param {number} padTokenId - id used for padding
 * @returns {{inputs, targets, paddingMask, lossMask}[]} paddingMask feeds attention;
 *   lossMask (assistant tokens only, never padding) feeds the loss
 */
function createSFTBatches(sftData, batchSize, tokenizer, padTokenId) {
  const batches = [];
  for (let start = 0; start < sftData.length; start += batchSize) {
    const group = sftData.slice(start, start + batchSize);
    const { sequences: inputs, mask: paddingMask } = padBatch(
      group.map((tokens) => tokens.slice(0, -1)),
      padTokenId,
    );
    const { sequences: targets } = padBatch(
      group.map((tokens) => tokens.slice(1)),
      padTokenId,
    );
    // The loss mask is built on the padded targets, since that is what gets scored.
    const lossMask = targets.map((row) => createLossMask(row, tokenizer));
    batches.push({ inputs, targets, paddingMask, lossMask });
  }
  return batches;
}
const sftBatches = createSFTBatches(sftData, batchSize, tokenizer, padTokenId);
// SFT training loop
const sftEpochs = 100;
const sftLearningRate = 0.001; // Lower than pre-training!
for (let epoch = 0; epoch < sftEpochs; epoch++) {
  let totalLoss = 0;
  for (const batch of sftBatches) {
    // Forward pass with padding mask (attention ignores padded positions)
    const probs = model.forward(batch.inputs, batch.paddingMask);
    // Compute loss only on non-padded assistant tokens
    const loss = model.output.computeLossWithMask(probs, batch.targets, batch.lossMask);
    // Backward pass with loss mask. backwardWithMask *returns* the gradient w.r.t. its
    // inputs; use that value rather than reading a property it may never set.
    let gradients = model.output.backwardWithMask(batch.targets, batch.lossMask, sftLearningRate);
    // Backward through rest of network, last block first
    gradients = model.finalNorm.backward(gradients, sftLearningRate);
    for (let i = model.blocks.length - 1; i >= 0; i--) {
      gradients = model.blocks[i].backward(gradients, sftLearningRate);
    }
    model.embedding.backward(gradients, sftLearningRate);
    totalLoss += loss;
  }
  if (epoch % 10 === 0) {
    console.log(`SFT Epoch ${epoch}: Loss = ${(totalLoss / sftBatches.length).toFixed(4)}`);
  }
}
Notice the lower learning rate. SFT builds on pre-trained knowledge—we don't want to destroy what the model already learned.
The padding mask ensures attention doesn't leak into padding positions. The loss mask ensures we only train on assistant tokens (not user text, not padding). Together, they make SFT work correctly on batched, variable-length conversations.
User / Assistant Roles
The special tokens create implicit "roles." After seeing thousands of examples where helpful responses follow <|assistant|>, the model learns the pattern.
During generation, we format the prompt the same way:
/**
 * Ask the model a single-turn question and return only the assistant's reply.
 *
 * The prompt ends with <|assistant|>, mirroring the SFT format, so the model
 * continues with a reply. The reply is everything the model generated after the
 * prompt, up to (not including) the first <|end|>.
 *
 * Handles both generate() conventions. If the returned tokens start with the prompt
 * tokens, those are dropped before decoding; this stays correct even if the user's
 * text itself contains "<|assistant|>". Otherwise the output is decoded and, if an
 * <|assistant|> marker is present, the reply starts after it; if not, the whole
 * output is treated as the reply.
 *
 * @param model - object with generate(promptTokens, maxNewTokens) returning token ids
 * @param tokenizer - object with encode(text) -> ids and decode(ids) -> text
 * @param {string} userMessage - the user's question
 * @param {number} [maxNewTokens=50] - generation budget passed to model.generate
 * @returns {string} the assistant's reply text
 */
function chat(model, tokenizer, userMessage, maxNewTokens = 50) {
  const ASSISTANT = "<|assistant|>";
  const END = "<|end|>";

  const prompt = `<|user|>${userMessage}<|end|>${ASSISTANT}`;
  const promptTokens = tokenizer.encode(prompt);
  const generated = model.generate(promptTokens, maxNewTokens);

  const echoesPrompt =
    generated.length >= promptTokens.length &&
    promptTokens.every((tok, i) => generated[i] === tok);

  let replyText;
  if (echoesPrompt) {
    replyText = tokenizer.decode(generated.slice(promptTokens.length));
  } else {
    const decoded = tokenizer.decode(generated);
    const markerAt = decoded.indexOf(ASSISTANT);
    // Without the marker (generate returned only new tokens) the reply starts at 0.
    replyText = markerAt === -1 ? decoded : decoded.slice(markerAt + ASSISTANT.length);
  }

  // The model may keep going past its reply; cut at the first end-of-turn marker.
  const endAt = replyText.indexOf(END);
  return endAt === -1 ? replyText : replyText.slice(0, endAt);
}
// Usage
// The reply shown is the expected output once SFT has succeeded; an undertrained
// model will print something else.
const response = chat(model, tokenizer, "What is the capital of France?");
console.log(response); // "The capital of France is Paris."
The model sees the familiar pattern—<|user|> followed by a question, then <|assistant|>—and generates what it learned comes next: a helpful answer.
Thinking
You've probably seen models that show their reasoning in <think> tags before giving an answer. This isn't a special architecture—it's just SFT with a different format.
// Thinking-format SFT examples. Like the plain conversations, but with a `thinking`
// string that formatThinkingConversation places in a <|think|> segment before the reply.
const thinkingConversations = [
{
user: "What is 15 + 27?",
thinking: "I need to add 15 and 27. 15 + 27 = 42.",
assistant: "15 + 27 equals 42."
},
{
user: "Is a whale a fish?",
thinking: "Whales live in water, but they breathe air and nurse their young. These are mammal characteristics.",
assistant: "No, a whale is not a fish. Whales are mammals."
}
];
/**
 * Serialize a {user, thinking, assistant} example. The reasoning gets its own
 * <|think|> segment between the user turn and the assistant reply, and each
 * segment is closed with <|end|>.
 *
 * @param {{user: string, thinking: string, assistant: string}} conv
 * @returns {string} the formatted training text
 */
function formatThinkingConversation(conv) {
  const { user, thinking, assistant } = conv;
  const userTurn = `<|user|>${user}<|end|>`;
  const thinkTurn = `<|think|>${thinking}<|end|>`;
  const replyTurn = `<|assistant|>${assistant}<|end|>`;
  return userTurn + thinkTurn + replyTurn;
}
Add <|think|> as a reserved token, format your training data with thinking sections, and the model learns to "think" before responding. The thinking isn't magic cognition—it's the model generating intermediate text that helps it produce better final answers. By writing out reasoning steps, the model conditions its output on that reasoning, leading to more accurate responses.
Thinking tokens work because transformers are autoregressive. Each generated token influences the next. Generating "15 + 27 = 42" gives the model something concrete to reference when generating the final answer.
Training Tricks
Our simple gradient descent works, but there are tricks that make training faster and more stable.
Learning Rate Warmup & Decay
Using the same learning rate throughout training is suboptimal. At the start, weights are random, and large updates might send them in bad directions. At the end, we want to fine-tune carefully. Warmup starts with a tiny learning rate and increases it over the first few steps. Decay decreases it as training progresses.
/**
 * Learning-rate schedule: linear warmup, then cosine decay.
 *
 * - step < warmupSteps: rises linearly from 0 toward maxLR.
 * - warmupSteps <= step <= totalSteps: follows half a cosine from maxLR down to minLR.
 * - step > totalSteps: stays at minLR. Progress is clamped, so the cosine doesn't
 *   climb back up.
 * If totalSteps <= warmupSteps there is no decay phase, and minLR is returned after
 * warmup instead of dividing by zero.
 *
 * @param {number} step - current global step (0-based)
 * @param {number} warmupSteps - length of the warmup phase
 * @param {number} totalSteps - step at which the decay reaches minLR
 * @param {number} maxLR - peak learning rate
 * @param {number} [minLR=0] - floor reached at the end of the decay
 * @returns {number} learning rate for this step
 */
function getLearningRate(step, warmupSteps, totalSteps, maxLR, minLR = 0) {
  if (step < warmupSteps) {
    // Linear warmup: 0 -> maxLR
    return maxLR * (step / warmupSteps);
  }
  // Cosine decay: maxLR -> minLR
  const decaySteps = totalSteps - warmupSteps;
  const progress = decaySteps > 0 ? Math.min(1, (step - warmupSteps) / decaySteps) : 1;
  return minLR + (maxLR - minLR) * 0.5 * (1 + Math.cos(Math.PI * progress));
}
The cosine shape is smooth—no sudden changes that might destabilize training. Usage:
const warmupSteps = 100;
const maxLR = 0.001;
// One schedule step per batch, across all epochs.
const totalSteps = epochs * batches.length;
let step = 0;
for (let epoch = 0; epoch < epochs; epoch++) {
  for (const batch of batches) {
    // Learning rate for the current global step, then advance the counter.
    const lr = getLearningRate(step, warmupSteps, totalSteps, maxLR);
    model.train(batch.inputs, batch.targets, lr); // returns the batch loss if you want to log it
    step += 1;
  }
}
Gradient Clipping
Sometimes gradients explode. They become huge, causing massive weight updates that destroy the model. This happens especially with deep networks or long sequences.
Gradient clipping caps gradient magnitude:
/**
 * Clamp one gradient value into [-maxValue, maxValue] (per-value clipping).
 * NaN passes through unchanged.
 *
 * @param {number} gradient - raw gradient value
 * @param {number} maxValue - clipping bound (expected to be positive)
 * @returns {number} the clipped value
 */
function clipGradient(gradient, maxValue) {
  // Apply the upper bound first, then the lower bound.
  const cappedAbove = Math.min(gradient, maxValue);
  return Math.max(cappedAbove, -maxValue);
}
Apply this wherever you compute gradients. For example, in a backward pass:
// Excerpt from an MLP backward pass (b, t and geluGrad come from the enclosing code).
// Each per-element contribution to weights1Grad is clamped to [-1, 1] before it is
// accumulated.
// NOTE(review): the full MLPBlock.backward also accumulates inputGradients in this
// loop; that line is omitted from the excerpt.
for (let i = 0; i < this.inputDim; i++) {
for (let h = 0; h < this.hiddenDim; h++) {
let grad = this.cachedInputs[b][t][i] * geluGrad[h];
grad = clipGradient(grad, 1.0); // Clip to [-1, 1]
weights1Grad[i][h] += grad;
}
}
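A cleaner approach is to clip the global gradient norm: compute the total magnitude (L2 norm) of all gradients, and if it exceeds a threshold, scale every gradient down by the same factor. This keeps the gradient's direction intact, whereas per-value clipping can distort it. Still, per-value clipping is simpler and often sufficient.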
The Adam Optimizer
Gradient descent uses the same learning rate for every parameter. Adam (Adaptive Moment Estimation) adapts the learning rate per-parameter based on gradient history. Adam tracks two quantities for each parameter:
- m (first moment): Running average of gradients (momentum)
- v (second moment): Running average of squared gradients (how "big" gradients have been)
Each update is divided by the square root of v. As a result, parameters with large gradients get smaller updates relative to their gradient, while parameters with small, consistent gradients get relatively larger ones. To use Adam, each layer needs to store optimizer state. We'll add this to our layers. First, let's create an Adam helper class to keep the logic centralized:
/**
 * Adam optimizer (Adaptive Moment Estimation), meant to be shared by all layers.
 *
 * Each parameter tensor owns a state object from createState(). The state holds the
 * running first moment m, the running second moment v, and the tensor's own update
 * counter t. The counter lives on the state, not the optimizer, because layers call
 * update1D/update2D once per tensor per training step (MLPBlock.backward makes four
 * calls). A single optimizer-wide counter would advance several times per step and
 * apply the wrong bias correction, and a different one for each tensor.
 */
class AdamOptimizer {
  learningRate = 0.001;
  beta1 = 0.9; // Momentum decay
  beta2 = 0.999; // Squared gradient decay
  epsilon = 1e-8; // Prevents division by zero
  // Number of update1D/update2D calls made through this optimizer, all tensors combined.
  // Informational only: bias correction uses each state's own `t`.
  step = 0;

  constructor(learningRate = 0.001) {
    this.learningRate = learningRate;
  }

  /**
   * Create zero-initialized Adam state for one parameter tensor.
   * @param {number|number[]} shape - length of a 1D tensor, or its dimensions (any rank)
   * @returns {{m: Array, v: Array, t: number}} moments shaped like the tensor, and t = 0
   */
  createState(shape) {
    const dims = typeof shape === 'number' ? [shape] : shape;
    return { m: AdamOptimizer.#zeros(dims), v: AdamOptimizer.#zeros(dims), t: 0 };
  }

  /** Apply one Adam step to a 1D parameter array in place. */
  update1D(params, grads, state) {
    const { mCorrection, vCorrection } = this.#advance(state);
    this.#updateRow(params, grads, state.m, state.v, mCorrection, vCorrection);
  }

  /** Apply one Adam step to a 2D parameter matrix in place. */
  update2D(params, grads, state) {
    const { mCorrection, vCorrection } = this.#advance(state);
    for (let i = 0; i < params.length; i++) {
      this.#updateRow(params[i], grads[i], state.m[i], state.v[i], mCorrection, vCorrection);
    }
  }

  // Count one more update for this tensor and return its bias-correction denominators.
  #advance(state) {
    this.step++;
    state.t = (state.t ?? 0) + 1;
    return {
      mCorrection: 1 - Math.pow(this.beta1, state.t),
      vCorrection: 1 - Math.pow(this.beta2, state.t),
    };
  }

  // Adam update for one flat row of parameters and its matching moment rows.
  #updateRow(params, grads, m, v, mCorrection, vCorrection) {
    for (let j = 0; j < params.length; j++) {
      const g = grads[j];
      // Update biased first and second moment estimates
      m[j] = this.beta1 * m[j] + (1 - this.beta1) * g;
      v[j] = this.beta2 * v[j] + (1 - this.beta2) * g * g;
      // Bias correction
      const mHat = m[j] / mCorrection;
      const vHat = v[j] / vCorrection;
      params[j] -= this.learningRate * mHat / (Math.sqrt(vHat) + this.epsilon);
    }
  }

  // Nested arrays of zeros with the given dimensions.
  static #zeros(dims) {
    const [size, ...rest] = dims;
    return rest.length === 0
      ? new Array(size).fill(0)
      : Array.from({ length: size }, () => AdamOptimizer.#zeros(rest));
  }
}
Now let's modify the MLP block to use Adam. The pattern is: store optimizer state, accumulate gradients, then call the optimizer to update:
/**
 * Position-wise feed-forward block: Linear(d -> 4d) -> GELU -> Linear(4d -> d).
 *
 * Inputs are batched: [batchSize][seqLength][embeddingDim]. forward() caches what
 * backward() needs. backward() accumulates weight gradients over the whole batch and
 * applies one update at the end: Adam when an optimizer was given to the
 * constructor, otherwise plain SGD when a learning rate is passed.
 *
 * GELU uses the tanh approximation, and backward() differentiates that same
 * approximation so the gradients match the forward pass.
 */
class MLPBlock {
  weights1 = null;
  bias1 = null;
  weights2 = null;
  bias2 = null;
  inputDim = 0;
  hiddenDim = 0;
  cachedInputs = null;
  cachedHidden = null;
  // GELU inputs from the last forward(), reused by backward() instead of recomputing.
  cachedPreActivations = null;
  // Adam state
  optimizer = null;
  weights1State = null;
  bias1State = null;
  weights2State = null;
  bias2State = null;

  // Constants of the tanh GELU approximation: 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³)))
  static #GELU_SCALE = Math.sqrt(2 / Math.PI);
  static #GELU_CUBIC = 0.044715;

  /**
   * @param {number} embeddingDim - model width; the hidden layer is 4x wider
   * @param [optimizer] - shared AdamOptimizer. When omitted, backward() falls back to
   *   SGD, and only if it is given a learning rate.
   */
  constructor(embeddingDim, optimizer = null) {
    this.inputDim = embeddingDim;
    this.hiddenDim = embeddingDim * 4;
    this.optimizer = optimizer;
    const scale1 = Math.sqrt(2.0 / embeddingDim);
    const scale2 = Math.sqrt(2.0 / this.hiddenDim);
    this.weights1 = this.#initWeights(embeddingDim, this.hiddenDim, scale1);
    this.bias1 = new Array(this.hiddenDim).fill(0);
    this.weights2 = this.#initWeights(this.hiddenDim, embeddingDim, scale2);
    this.bias2 = new Array(embeddingDim).fill(0);
    // Initialize Adam state if optimizer provided
    if (optimizer) {
      this.weights1State = optimizer.createState([embeddingDim, this.hiddenDim]);
      this.bias1State = optimizer.createState(this.hiddenDim);
      this.weights2State = optimizer.createState([this.hiddenDim, embeddingDim]);
      this.bias2State = optimizer.createState(embeddingDim);
    }
  }

  // Uniform random init in [-scale, scale).
  #initWeights(inputSize, outputSize, scale) {
    const weights = new Array(inputSize);
    for (let i = 0; i < inputSize; i++) {
      weights[i] = new Array(outputSize);
      for (let j = 0; j < outputSize; j++) {
        weights[i][j] = (Math.random() * 2 - 1) * scale;
      }
    }
    return weights;
  }

  // GELU, tanh approximation.
  static #gelu(x) {
    return x * 0.5 * (1 + Math.tanh(MLPBlock.#GELU_SCALE * (x + MLPBlock.#GELU_CUBIC * x * x * x)));
  }

  // Exact derivative of #gelu (the tanh approximation, not the erf-based GELU).
  static #geluDerivative(x) {
    const th = Math.tanh(MLPBlock.#GELU_SCALE * (x + MLPBlock.#GELU_CUBIC * x * x * x));
    const dInner = MLPBlock.#GELU_SCALE * (1 + 3 * MLPBlock.#GELU_CUBIC * x * x);
    return 0.5 * (1 + th) + 0.5 * x * (1 - th * th) * dInner;
  }

  // rows x cols matrix of zeros.
  static #zeros2D(rows, cols) {
    return Array.from({ length: rows }, () => new Array(cols).fill(0));
  }

  /**
   * @param inputs [batch][seq][embeddingDim]
   * @returns [batch][seq][embeddingDim]
   */
  forward(inputs) {
    const batchSize = inputs.length;
    const seqLength = inputs[0].length;
    this.cachedInputs = inputs;
    this.cachedHidden = new Array(batchSize);
    this.cachedPreActivations = new Array(batchSize);
    const output = new Array(batchSize);
    for (let b = 0; b < batchSize; b++) {
      this.cachedHidden[b] = new Array(seqLength);
      this.cachedPreActivations[b] = new Array(seqLength);
      output[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        const preActivation = new Array(this.hiddenDim);
        const hidden = new Array(this.hiddenDim);
        for (let h = 0; h < this.hiddenDim; h++) {
          let sum = this.bias1[h];
          for (let i = 0; i < this.inputDim; i++) {
            sum += inputs[b][t][i] * this.weights1[i][h];
          }
          preActivation[h] = sum;
          hidden[h] = MLPBlock.#gelu(sum);
        }
        this.cachedPreActivations[b][t] = preActivation;
        this.cachedHidden[b][t] = hidden;
        output[b][t] = new Array(this.inputDim);
        for (let o = 0; o < this.inputDim; o++) {
          let sum = this.bias2[o];
          for (let h = 0; h < this.hiddenDim; h++) {
            sum += hidden[h] * this.weights2[h][o];
          }
          output[b][t][o] = sum;
        }
      }
    }
    return output;
  }

  /**
   * Backpropagate through the block, then update its parameters once.
   *
   * @param outputGradients [batch][seq][embeddingDim]: dLoss/dOutput
   * @param {number|null} [learningRate] - SGD step size, used only when no optimizer
   *   is configured. With neither, gradients are computed but weights are unchanged.
   * @returns [batch][seq][embeddingDim]: dLoss/dInput
   */
  backward(outputGradients, learningRate = null) {
    const batchSize = outputGradients.length;
    const seqLength = outputGradients[0].length;
    const inputGradients = new Array(batchSize);
    // Gradient accumulators, summed over every (b, t) position
    const weights2Grad = MLPBlock.#zeros2D(this.hiddenDim, this.inputDim);
    const bias2Grad = new Array(this.inputDim).fill(0);
    const weights1Grad = MLPBlock.#zeros2D(this.inputDim, this.hiddenDim);
    const bias1Grad = new Array(this.hiddenDim).fill(0);

    for (let b = 0; b < batchSize; b++) {
      inputGradients[b] = new Array(seqLength);
      for (let t = 0; t < seqLength; t++) {
        const outGrad = outputGradients[b][t];
        const hidden = this.cachedHidden[b][t];
        const input = this.cachedInputs[b][t];
        const preActivation = this.cachedPreActivations[b][t];

        // Gradient through second linear
        const hiddenGrad = new Array(this.hiddenDim).fill(0);
        for (let h = 0; h < this.hiddenDim; h++) {
          for (let o = 0; o < this.inputDim; o++) {
            hiddenGrad[h] += outGrad[o] * this.weights2[h][o];
            weights2Grad[h][o] += hidden[h] * outGrad[o];
          }
        }
        for (let o = 0; o < this.inputDim; o++) {
          bias2Grad[o] += outGrad[o];
        }

        // Gradient through GELU (derivative of the same approximation used in forward)
        const geluGrad = new Array(this.hiddenDim);
        for (let h = 0; h < this.hiddenDim; h++) {
          geluGrad[h] = hiddenGrad[h] * MLPBlock.#geluDerivative(preActivation[h]);
        }

        // Gradient through first linear (weights1 still holds pre-update values here)
        inputGradients[b][t] = new Array(this.inputDim).fill(0);
        for (let i = 0; i < this.inputDim; i++) {
          for (let h = 0; h < this.hiddenDim; h++) {
            inputGradients[b][t][i] += geluGrad[h] * this.weights1[i][h];
            weights1Grad[i][h] += input[i] * geluGrad[h];
          }
        }
        for (let h = 0; h < this.hiddenDim; h++) {
          bias1Grad[h] += geluGrad[h];
        }
      }
    }

    this.#applyUpdates(weights1Grad, bias1Grad, weights2Grad, bias2Grad, learningRate);
    return inputGradients;
  }

  // Update weights with Adam if configured, else SGD if a learning rate was given.
  #applyUpdates(weights1Grad, bias1Grad, weights2Grad, bias2Grad, learningRate) {
    if (this.optimizer) {
      this.optimizer.update2D(this.weights2, weights2Grad, this.weights2State);
      this.optimizer.update1D(this.bias2, bias2Grad, this.bias2State);
      this.optimizer.update2D(this.weights1, weights1Grad, this.weights1State);
      this.optimizer.update1D(this.bias1, bias1Grad, this.bias1State);
      return;
    }
    if (typeof learningRate !== 'number') {
      return; // No optimizer and no learning rate: leave parameters unchanged.
    }
    const sgd = (params, grads) => {
      for (let i = 0; i < params.length; i++) {
        if (Array.isArray(params[i])) {
          for (let j = 0; j < params[i].length; j++) {
            params[i][j] -= learningRate * grads[i][j];
          }
        } else {
          params[i] -= learningRate * grads[i];
        }
      }
    };
    sgd(this.weights2, weights2Grad);
    sgd(this.bias2, bias2Grad);
    sgd(this.weights1, weights1Grad);
    sgd(this.bias1, bias1Grad);
  }
}
The key changes:
- Store optimizer state alongside weights
- Accumulate gradients instead of applying them immediately
- Call optimizer.update1D() / optimizer.update2D() at the end of backward()
Apply the same pattern to other layers: embedding, attention, output layer, layer normalization.
Using Adam in Training
// Create one shared optimizer and pass it to the model components.
const optimizer = new AdamOptimizer(0.001);
const model = new GabGPT(vocabSize, embeddingDim, numHeads, numBlocks, maxSeqLength, optimizer);
// The optimizer owns the learning rate, so train() is called without one.
for (let epoch = 0; epoch < epochs; epoch++) {
  const epochLoss = batches.reduce(
    (sum, batch) => sum + model.train(batch.inputs, batch.targets), // No LR argument
    0,
  );
  console.log(`Epoch ${epoch}: Loss = ${(epochLoss / batches.length).toFixed(4)}`);
}
With Adam, you can often use a fixed learning rate (like 0.001) without warmup. But the best results usually come from combining Adam with warmup and cosine decay: before each step, set `optimizer.learningRate = getLearningRate(step, warmupSteps, totalSteps, maxLR)`.
Putting It All Together
Let's train a small model from scratch through pre-training and SFT:
// ==========================================
// Complete Training Pipeline
// ==========================================
// 1. Initialize tokenizer with every special token the chat and thinking formats use
const tokenizer = new Tokenizer();
for (const specialToken of ["<|user|>", "<|assistant|>", "<|end|>", "<|pad|>", "<|think|>"]) {
  tokenizer.reserveToken(specialToken);
}
// 2. Pre-training data
const pretrainingText = `
The cat sat on the mat. The dog sat on the log.
The quick brown fox jumps over the lazy dog.
In the morning, the sun rises in the east.
In the evening, the sun sets in the west.
Paris is the capital of France.
London is the capital of the United Kingdom.
The Earth orbits around the Sun.
Water freezes at zero degrees Celsius.
`;
// Train the tokenizer on the corpus, then encode the corpus with it.
tokenizer.train(pretrainingText, 100);
const pretrainingTokens = tokenizer.encode(pretrainingText);
console.log(`Vocabulary size: ${tokenizer.getVocabSize()}`);
console.log(`Pre-training tokens: ${pretrainingTokens.length}`);
// Pad token ID, needed later by createSFTBatches
const padTokenId = tokenizer.encode("<|pad|>")[0];
// 3. Create model with an Adam optimizer passed to its components
const embeddingDim = 64;
const numHeads = 4;
const numBlocks = 2;
const maxSeqLength = 64;
const optimizer = new AdamOptimizer(0.001);
const model = new GabGPT(tokenizer.getVocabSize(), embeddingDim, numHeads, numBlocks, maxSeqLength, optimizer);
// 4. Pre-training: plain next-token prediction over the raw text, in batches
console.log("\n=== Pre-training ===");
const batchSize = 4;
const seqLength = 32;
const pretrainingBatches = createBatches(pretrainingTokens, batchSize, seqLength);
const pretrainingEpochs = 100;
for (let epoch = 0; epoch < pretrainingEpochs; epoch++) {
  let epochLoss = 0;
  for (const { inputs, targets } of pretrainingBatches) {
    epochLoss += model.train(inputs, targets);
  }
  // Log only every 20th epoch to keep the output short.
  if (epoch % 20 !== 0) continue;
  console.log(`Epoch ${epoch}: Loss = ${(epochLoss / pretrainingBatches.length).toFixed(4)}`);
}
// 5. SFT data: single-turn Q&A pairs drawn from facts in the pre-training text
const sftConversations = [
  { user: "What is the capital of France?", assistant: "The capital of France is Paris." },
  { user: "What is the capital of the UK?", assistant: "The capital of the United Kingdom is London." },
  { user: "When does the sun rise?", assistant: "The sun rises in the morning, in the east." },
  { user: "What does water do at zero degrees?", assistant: "Water freezes at zero degrees Celsius." },
];
// Encode each exchange in the same chat template that chat() uses at inference time.
const sftData = sftConversations.map(
  ({ user, assistant }) => tokenizer.encode(`<|user|>${user}<|end|><|assistant|>${assistant}<|end|>`),
);
// 6. SFT training with padding
console.log("\n=== Supervised Fine-Tuning ===");
optimizer.learningRate = 0.0001; // Lower for fine-tuning
const sftEpochs = 200;
// Create batches with padding
const sftBatches = createSFTBatches(sftData, batchSize, tokenizer, padTokenId);
for (let epoch = 0; epoch < sftEpochs; epoch++) {
  let totalLoss = 0;
  for (const batch of sftBatches) {
    // Forward with padding mask
    const probs = model.forward(batch.inputs, batch.paddingMask);
    const loss = model.output.computeLossWithMask(probs, batch.targets, batch.lossMask);
    // Backward with loss mask. Use the gradients backwardWithMask *returns*. Pass the
    // current learning rate so an SGD-only output layer doesn't compute
    // `undefined * grad` (NaN weights); an optimizer-aware one can ignore it.
    let gradients = model.output.backwardWithMask(batch.targets, batch.lossMask, optimizer.learningRate);
    gradients = model.finalNorm.backward(gradients);
    for (let i = model.blocks.length - 1; i >= 0; i--) {
      gradients = model.blocks[i].backward(gradients);
    }
    model.embedding.backward(gradients);
    totalLoss += loss;
  }
  if (epoch % 50 === 0) {
    console.log(`Epoch ${epoch}: Loss = ${(totalLoss / sftBatches.length).toFixed(4)}`);
  }
}
// 7. Test the chatbot: one question from the SFT set and one it was not fine-tuned on
console.log("\n=== Testing ===");
const testQuestions = [
  "What is the capital of France?",
  "When does the sun set?",
];
testQuestions.forEach((question) => {
  const response = chat(model, tokenizer, question);
  console.log(`User: ${question}`);
  console.log(`Assistant: ${response}\n`);
});