
Grokking has two phases, and you can see the boundary

A component-freezing ablation study on modular division reveals that grokking isn't one process. It's two: infrastructure setup, then computational reorganization. The boundary between them is sharp and measurable.

Grokking is one of the strangest things neural networks do. First described by Power et al. [1], the phenomenon is this: a model memorizes its training data perfectly - 100% accuracy - and then sits there for thousands of steps, doing nothing on the test set. Zero generalization. And then, suddenly, it generalizes. Not gradually. Suddenly. Test accuracy jumps from near-zero to near-perfect in a few hundred steps.

The standard explanation is that training has multiple phases: memorization, then circuit formation, then cleanup (weight decay compresses the memorized solution into a generalizing one). This was established by Nanda et al. [2] and is well-accepted.

But I wanted to ask a more operational question. Not "what phases exist?" but "which parts of the network are needed during which phase?" Suppose freezing a component at one point in training kills grokking, while freezing the same component later leaves it intact. Where exactly is that boundary? And what does it tell us?


The setup

I trained a small transformer (4 layers, 128 hidden, 820K parameters) on modular division: given a and b, predict (a / b) mod 97. The training set is 50% of all valid pairs. AdamW optimizer, learning rate 3e-4, weight decay 0.3, batch size 64.
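
For concreteness, "division" here means multiplication by the modular inverse: since 97 is prime, Fermat's little theorem gives b^-1 = b^(p-2) mod p. A quick worked example of how a label is computed (the training script below uses the same trick):

p = 97
a, b = 12, 5
inv_b = pow(b, p - 2, p)       # Fermat: b^(p-2) ≡ b^(-1) (mod p)
c = (a * inv_b) % p            # the target the model must predict: 80
assert (c * b) % p == a % p    # sanity check: c * b ≡ a (mod p)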

This is a standard grokking setup. The model memorizes the training set by around step 1,500 and doesn't generalize until roughly step 15,000. Here's what that looks like:

[Chart: train and test accuracy vs. training steps (0-30K). Train accuracy sits at 100% while test accuracy stays near 2%, a 98-point gap, until generalization arrives suddenly.]

By step 2,000, the model has perfectly memorized every training example. But it has learned nothing generalizable. Test accuracy is 2%. The gap between what it knows and what it understands is 98 percentage points.

That gap persists for over ten thousand steps. Then it collapses.

[Chart: train-minus-test accuracy gap vs. training steps (0-30K). The gap peaks near 100% (perfect train, near-zero test) and collapses to 0 around step ~19K.]

The question is: what is the network doing during those 13,000 steps of apparent silence? The test accuracy says "nothing." But that can't be true. Something must be changing internally for generalization to suddenly appear.

The experiment: freeze and continue

I saved checkpoints at two moments during training:

  • Step 7,000 — deep in the memorization plateau. Train accuracy is 100%, test accuracy is 2%. Generalization has not started.
  • Step 11,000 — still in the plateau, but closer to the grokking onset. Train accuracy still 100%, test accuracy still around 3-5%.

From each checkpoint, I loaded the model, froze one component (set its parameters to requires_grad = False), and continued training for 10,000 more steps with everything else unchanged. Same optimizer, same learning rate, same weight decay. Then I checked: did the model still grok?

I tested 11 interventions: a baseline with nothing frozen; freezing the entry layer, the middle layers, the exit layer, the exit attention, the exit MLP, all attention, all MLP, the output head, or the embeddings; and removing weight decay.

The full experiment code (820K parameter model, ~20 minutes on any GPU):

grokking_ablation.py:
"""
Ablation study: which components does grokking need?

Train a 4-layer transformer on modular division (a/b mod 97).
At two checkpoints (7K and 11K steps), freeze each component
and continue training to see if generalization still happens.

Runs in ~20 minutes on any GPU or Apple MPS.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json

DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
P = 97  # prime for modular arithmetic


def mod_inverse(b, p):
    return pow(b, p - 2, p)


class ModularDivisionDataset:
    def __init__(self, p, train=True, seed=42):
        self.p = p
        np.random.seed(seed)
        all_pairs = [(a, b) for a in range(p) for b in range(1, p)]
        np.random.shuffle(all_pairs)
        split = int(len(all_pairs) * 0.5)
        self.pairs = all_pairs[:split] if train else all_pairs[split:]
        self.pairs = [(a, b, (a * mod_inverse(b, p)) % p) for a, b in self.pairs]

    def get_batch(self, batch_size):
        indices = np.random.choice(len(self.pairs), min(batch_size, len(self.pairs)), replace=False)
        batch = [self.pairs[i] for i in indices]
        x = torch.tensor([[a, b, P] for a, b, c in batch], dtype=torch.long)  # token P (97) acts as the "=" separator
        y = torch.tensor([c for a, b, c in batch], dtype=torch.long)
        return x.to(DEVICE), y.to(DEVICE)


class TransformerLayer(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.mlp_fc1 = nn.Linear(hidden_size, 4 * hidden_size)
        self.mlp_fc2 = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, x):
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        h = self.ln2(x)
        x = x + self.mlp_fc2(F.gelu(self.mlp_fc1(h)))
        return x


class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size=99, hidden_size=128, num_layers=4, num_heads=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = nn.Embedding(16, hidden_size)
        self.layers = nn.ModuleList([
            TransformerLayer(hidden_size, num_heads) for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        B, T = x.shape
        h = self.embed(x) + self.pos_embed(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        h = self.ln_f(h)
        return self.head(h[:, -1])


def apply_intervention(model, intervention_name):
    """Freeze specific components. Returns (trainable_params, total_params)."""

    if intervention_name in ("baseline", "no_weight_decay"):
        pass  # nothing to freeze; weight decay is handled in run_ablation
    elif intervention_name == "freeze_head":
        for p in model.head.parameters():
            p.requires_grad = False
    elif intervention_name == "freeze_embed":
        for p in model.embed.parameters():
            p.requires_grad = False
        for p in model.pos_embed.parameters():
            p.requires_grad = False
    elif intervention_name == "freeze_attn_all":
        for layer in model.layers:
            for p in layer.attn.parameters():
                p.requires_grad = False
            for p in layer.ln1.parameters():
                p.requires_grad = False
    elif intervention_name == "freeze_mlp_all":
        for layer in model.layers:
            for p in layer.mlp_fc1.parameters():
                p.requires_grad = False
            for p in layer.mlp_fc2.parameters():
                p.requires_grad = False
            for p in layer.ln2.parameters():
                p.requires_grad = False
    elif intervention_name == "freeze_exit_layer":
        for p in model.layers[-1].parameters():
            p.requires_grad = False
    elif intervention_name == "freeze_entry_layer":
        for p in model.layers[0].parameters():
            p.requires_grad = False
    elif intervention_name == "freeze_middle_layers":
        for p in model.layers[1].parameters():
            p.requires_grad = False
        for p in model.layers[2].parameters():
            p.requires_grad = False
    elif intervention_name == "freeze_exit_attn":
        for p in model.layers[-1].attn.parameters():
            p.requires_grad = False
    elif intervention_name == "freeze_exit_mlp":
        for p in model.layers[-1].mlp_fc1.parameters():
            p.requires_grad = False
        for p in model.layers[-1].mlp_fc2.parameters():
            p.requires_grad = False

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


def run_ablation(checkpoint_step, intervention_name, num_steps=10000, wd=0.3):
    """Load from checkpoint, apply intervention, train, return trajectory."""

    ckpt_path = f"checkpoints/checkpoint_step_{checkpoint_step}.pt"
    ckpt = torch.load(ckpt_path, map_location=DEVICE)

    model = SimpleTransformer(num_layers=4).to(DEVICE)
    model.load_state_dict(ckpt["model_state"])

    trainable, total = apply_intervention(model, intervention_name)
    actual_wd = 0.0 if intervention_name == "no_weight_decay" else wd

    # Only the model weights are restored from the checkpoint; AdamW's
    # moment estimates start fresh here rather than being carried over.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4, weight_decay=actual_wd,
    )

    train_data = ModularDivisionDataset(P, train=True)
    test_data = ModularDivisionDataset(P, train=False)
    x_test, y_test = test_data.get_batch(256)  # fixed 256-example subset of the test set

    trajectory = []
    for step in range(num_steps):
        model.train()
        x, y = train_data.get_batch(64)
        loss = F.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 500 == 0:
            model.eval()
            with torch.no_grad():
                test_acc = (model(x_test).argmax(-1) == y_test).float().mean().item()
                train_acc = (model(x).argmax(-1) == y).float().mean().item()
            trajectory.append({
                "step": checkpoint_step + step,
                "train_acc": train_acc,
                "test_acc": test_acc,
                "loss": loss.item(),
            })
            if step % 2000 == 0:
                print(f"  {intervention_name} @ {checkpoint_step}+{step}: "
                      f"train={train_acc:.3f} test={test_acc:.3f}")

    return {
        "intervention": intervention_name,
        "checkpoint": checkpoint_step,
        "trainable_params": trainable,
        "total_params": total,
        "final_test_acc": trajectory[-1]["test_acc"],
        "max_test_acc": max(t["test_acc"] for t in trajectory),
        "trajectory": trajectory,
    }


if __name__ == "__main__":
    checkpoints = [7000, 11000]
    interventions = [
        "baseline", "no_weight_decay",
        "freeze_head", "freeze_embed",
        "freeze_exit_layer", "freeze_exit_attn", "freeze_exit_mlp",
        "freeze_entry_layer", "freeze_middle_layers",
        "freeze_attn_all", "freeze_mlp_all",
    ]

    results = []
    for ckpt in checkpoints:
        for intervention in interventions:
            print(f"\n[ckpt {ckpt}] {intervention}")
            result = run_ablation(ckpt, intervention)
            results.append(result)

    with open("grokking_ablation_results.json", "w") as f:
        json.dump(results, f, indent=2)

    # Summary table
    print(f"\n{'Intervention':<25} {'From 7K':>12} {'From 11K':>12}")
    print("-" * 50)
    for intervention in interventions:
        r7 = next((r for r in results if r["intervention"] == intervention and r["checkpoint"] == 7000), None)
        r11 = next((r for r in results if r["intervention"] == intervention and r["checkpoint"] == 11000), None)
        print(f"{intervention:<25} {r7['final_test_acc']:>11.1%} {r11['final_test_acc']:>11.1%}")

The results split into three clean categories.

What always blocks grokking

Some interventions kill generalization regardless of when you apply them.

Removing weight decay: test accuracy stays at 3% from step 7,000 and 9% from step 11,000. This confirms what's already established - weight decay is the compressive force that drives grokking [1], [3]. Without it, the network stays in its memorized solution forever.

Freezing the middle layers (layers 1 and 2): test accuracy stays at 4% and 9%. The middle layers are where the actual computation happens. Freeze them and the network cannot reorganize its internal algorithm.

Freezing all attention layers or all MLP layers: also blocked. Both component types are needed in the middle layers for the generalizing circuit to form.

What never blocks grokking

Freezing the embeddings: the model groks from both checkpoints. 77% from step 7,000, 99% from step 11,000. The embeddings are already "done" before the model even finishes memorizing. They don't need to change for generalization to happen.

This aligns with recent work by Xu et al. [4], who showed that transferring embeddings from a weaker model eliminates the grokking delay entirely. Separately, AlQuabeh et al. [5] showed that MLPs without embeddings generalize immediately - it's the embedding-weight coupling that causes the delay in the first place. Embeddings lock in early. They're infrastructure.

The phase-dependent result

This is the finding I didn't expect. Three components — the entry layer, the exit MLP, and the output head — show completely different behavior depending on when you freeze them.

Freeze the entry layer at step 7,000: blocked. Test accuracy reaches only 32%. Freeze the entry layer at step 11,000: grokked. Test accuracy reaches 96%.

Same component. Same model. Same training. Different checkpoint. Opposite outcome.

Here's what this looks like for all interventions from checkpoint 7,000:

[Chart: test-accuracy trajectories continued from checkpoint 7K (steps 7K-17K). Final values: baseline 92%, embed 77%, exit_attn 67%, head 41%, entry 32%, middle 4%; flat throughout: middle layers, all attention, all MLP, no weight decay.]

From step 7,000, most freezing interventions kill grokking. The trajectories cluster at the bottom. Only the baseline, frozen embeddings, and frozen exit attention reach meaningful test accuracy.

Now the same interventions from step 11,000:

[Chart: the same interventions continued from checkpoint 11K (steps 11K-21K). Final values: baseline 97%, embed 99%, entry 96%, head 95%, exit_mlp 86%, middle 9%; still flat: middle layers, all attention, all MLP, no weight decay. Entry, exit MLP, and head all grok; the infrastructure was already done.]

The picture flips. From step 11,000, entry, exit MLP, and the output head can all be frozen and grokking still happens. The lines that were stuck at the bottom from 7K are now climbing to 86-96%.

But the middle layers and weight decay are still essential. Freeze the middle layers from step 11,000 and you get 9% — same as from step 7,000. The computation needs to continue reorganizing regardless.

Between step 7,000 and step 11,000, something completes. The entry layer, exit MLP, and output head finish their job. After that, they're dispensable. Only the middle layers still matter.

The phase boundary

Here's the same data as a direct comparison. Each pair of bars shows the same intervention applied at the two different checkpoints:

[Chart: final test accuracy after 10K further steps, paired by checkpoint (7K vs. 11K). Entry layer 32% → 96%, exit MLP 21% → 86%, output head 41% → 95% (phase-dependent: blocked → grokked); middle layers 4% → 9%, no weight decay 3% → 9% (always essential); frozen embeddings 77% → 99%.]

The pairing is the finding. Entry layer, exit MLP, and output head all go from blocked to grokked between the two checkpoints, while middle layers and weight decay remain essential at both.

The full picture:

Intervention             From step 7K   From step 11K   Category
-----------------------  -------------  --------------  ---------------
baseline                 92%            97%
remove weight decay       3%             9%             always blocks
freeze middle layers      4%             9%             always blocks
freeze all attention      3%            39%             always blocks
freeze all MLP            6%            13%             always blocks
freeze entry layer       32%            96%             phase-dependent
freeze exit MLP          21%            86%             phase-dependent
freeze output head       41%            95%             phase-dependent
freeze embeddings        77%            99%             never blocks
freeze exit attention    67%            65%             never blocks

Two phases

The data tells a clean story. Grokking isn't one process. It's two sequential processes with a measurable boundary between them.

[Diagram: Phase 1, infrastructure (steps 0 → ~11,000): entry layer learns input routing, exit layer builds the output interface, output head calibrates predictions, embeddings are already done (freezable); after this phase these components can be frozen without affecting generalization. Phase 2, computation (steps ~11,000 → grokking): middle layers reorganize circuits, weight decay compresses the solution, memorization becomes generalization; only middle layers plus weight decay are still needed.]

Phase 1: Infrastructure (steps 0 to ~11,000). The entry layer learns how to route inputs. The exit layer and output head learn how to format outputs. This is happening silently during the memorization plateau — test accuracy doesn't move, but the network is building the scaffolding that generalization will need. The embeddings finish even earlier than this.

Phase 2: Computation (steps ~11,000 to grokking). The infrastructure is locked in. Now the middle layers reorganize their internal circuits from a memorizing algorithm to a generalizing one. Weight decay provides the compressive force. This is the phase where the actual algorithmic insight gets encoded.

The memorization plateau isn't silence. It's construction. The network is building infrastructure that it will later need.

What this doesn't tell us

This experiment has clear limitations.

It's one model, one task, one set of hyperparameters. Modular division on a 4-layer transformer. I haven't tested whether the same two-phase structure appears in larger models, different tasks, or different architectures. It might. It might not.

The phase boundary at step ~11,000 is specific to this run. On a different random seed or with different weight decay, it would land somewhere else. The claim isn't "step 11,000 is special." The claim is "there exists a boundary where component necessity changes, and you can find it by freezing and continuing."
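
With denser checkpoints (say one every 1,000 steps; an assumption, since the run above only kept two), locating the boundary for a single component reduces to bisection. A sketch reusing run_ablation from the script above, with an arbitrary 80% "grokked" threshold:

def find_boundary(intervention, lo=7000, hi=11000, step=1000, thresh=0.8):
    # Invariant: freezing at lo blocks grokking; freezing at hi does not.
    while hi - lo > step:
        mid = ((lo + hi) // 2) // step * step  # round down to a saved checkpoint
        acc = run_ablation(mid, intervention)["final_test_acc"]
        if acc >= thresh:
            hi = mid  # still groks when frozen here: boundary is earlier
        else:
            lo = mid  # blocked: boundary is later
    return lo, hi  # the boundary lies in (lo, hi]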

I also haven't explained what the entry layer learns during Phase 1. I can see that it's needed before step 11,000 and not after. But I don't have a mechanistic account of what it's doing. The next experiment would be to analyze the entry layer's attention patterns and MLP activations at the two checkpoints to understand what "infrastructure complete" actually looks like in terms of learned representations.
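
A cheap first probe in that direction (a sketch I haven't run, under the assumption that components which have finished their job also stop moving): measure each module's relative weight drift between the two checkpoints.

import torch

def component_drift(path_a="checkpoints/checkpoint_step_7000.pt",
                    path_b="checkpoints/checkpoint_step_11000.pt"):
    sa = torch.load(path_a, map_location="cpu")["model_state"]
    sb = torch.load(path_b, map_location="cpu")["model_state"]
    groups = {}
    for name, wa in sa.items():
        delta = (sb[name].float() - wa.float()).norm().item()
        scale = wa.float().norm().item()
        key = ".".join(name.split(".")[:2])  # e.g. embed.weight, layers.0, head.weight
        d, s = groups.get(key, (0.0, 0.0))
        groups[key] = (d + delta, s + scale)
    for key, (d, s) in sorted(groups.items()):
        print(f"{key:<16} relative drift {d / max(s, 1e-8):.3f}")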

The three-phase model from Nanda et al. [2] (memorization, circuit formation, cleanup) is related but defined differently. Their phases come from continuous progress measures on Fourier components. Mine come from a binary operational test: freeze it and see if grokking survives. Recent theoretical work [6] frames grokking as a norm-driven representational phase transition, which may explain why the infrastructure phase completes when it does. The two framings are compatible but I haven't formally connected them.

Why this matters

The practical implication is for anyone studying or trying to accelerate grokking. If you know which phase the network is in, you know which parameters to focus on. During Phase 1, gradient updates to the entry and exit layers are doing critical work. During Phase 2, you could freeze them and save compute.
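
For the second point, here is a sketch of what that would look like, reusing the script's definitions, with a loud caveat: the ablations above froze one component at a time, so freezing all the infrastructure jointly is an untested extrapolation.

# Untested extrapolation: past the boundary, freeze all "infrastructure"
# and keep training only the middle layers (plus exit attention).
model = SimpleTransformer(num_layers=4).to(DEVICE)
ckpt = torch.load("checkpoints/checkpoint_step_11000.pt", map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
for module in (model.embed, model.pos_embed, model.head, model.layers[0],
               model.layers[-1].mlp_fc1, model.layers[-1].mlp_fc2):
    for p in module.parameters():
        p.requires_grad = False
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=3e-4, weight_decay=0.3,  # keep weight decay: it stays essential
)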

The theoretical implication is more speculative. If transformers consistently build infrastructure before reorganizing computation, that's a principle about how these networks learn. Not "everything changes at once" but "scaffolding first, algorithm second." That pattern — if it holds beyond this one experiment — would be a useful inductive bias for understanding training dynamics more generally.

The full experiment code is embedded above, and the summary results (per-step trajectories omitted) are below.

grokking_ablation_results.json:
[
  {
    "intervention": "baseline",
    "checkpoint": 7000,
    "trainable_params": 820736,
    "total_params": 820736,
    "final_test_acc": 0.918,
    "max_test_acc": 0.918,
    "grokked": true
  },
  {
    "intervention": "no_weight_decay",
    "checkpoint": 7000,
    "trainable_params": 820736,
    "total_params": 820736,
    "final_test_acc": 0.0273,
    "max_test_acc": 0.0469,
    "grokked": false
  },
  {
    "intervention": "freeze_head",
    "checkpoint": 7000,
    "trainable_params": 808064,
    "total_params": 820736,
    "final_test_acc": 0.4102,
    "max_test_acc": 0.4336,
    "grokked": false
  },
  {
    "intervention": "freeze_embed",
    "checkpoint": 7000,
    "trainable_params": 806016,
    "total_params": 820736,
    "final_test_acc": 0.7734,
    "max_test_acc": 0.7734,
    "grokked": true
  },
  {
    "intervention": "freeze_exit_layer",
    "checkpoint": 7000,
    "trainable_params": 622464,
    "total_params": 820736,
    "final_test_acc": 0.0703,
    "max_test_acc": 0.0703,
    "grokked": false
  },
  {
    "intervention": "freeze_exit_attn",
    "checkpoint": 7000,
    "trainable_params": 754688,
    "total_params": 820736,
    "final_test_acc": 0.6719,
    "max_test_acc": 0.6719,
    "grokked": true
  },
  {
    "intervention": "freeze_exit_mlp",
    "checkpoint": 7000,
    "trainable_params": 689024,
    "total_params": 820736,
    "final_test_acc": 0.2109,
    "max_test_acc": 0.2109,
    "grokked": false
  },
  {
    "intervention": "freeze_entry_layer",
    "checkpoint": 7000,
    "trainable_params": 622464,
    "total_params": 820736,
    "final_test_acc": 0.3203,
    "max_test_acc": 0.3359,
    "grokked": false
  },
  {
    "intervention": "freeze_middle_layers",
    "checkpoint": 7000,
    "trainable_params": 424192,
    "total_params": 820736,
    "final_test_acc": 0.0352,
    "max_test_acc": 0.043,
    "grokked": false
  },
  {
    "intervention": "freeze_attn_all",
    "checkpoint": 7000,
    "trainable_params": 555520,
    "total_params": 820736,
    "final_test_acc": 0.0312,
    "max_test_acc": 0.0391,
    "grokked": false
  },
  {
    "intervention": "freeze_mlp_all",
    "checkpoint": 7000,
    "trainable_params": 292864,
    "total_params": 820736,
    "final_test_acc": 0.0586,
    "max_test_acc": 0.0625,
    "grokked": false
  },
  {
    "intervention": "baseline",
    "checkpoint": 11000,
    "trainable_params": 820736,
    "total_params": 820736,
    "final_test_acc": 0.9727,
    "max_test_acc": 0.9766,
    "grokked": true
  },
  {
    "intervention": "no_weight_decay",
    "checkpoint": 11000,
    "trainable_params": 820736,
    "total_params": 820736,
    "final_test_acc": 0.0859,
    "max_test_acc": 0.0938,
    "grokked": false
  },
  {
    "intervention": "freeze_head",
    "checkpoint": 11000,
    "trainable_params": 808064,
    "total_params": 820736,
    "final_test_acc": 0.9531,
    "max_test_acc": 0.9531,
    "grokked": true
  },
  {
    "intervention": "freeze_embed",
    "checkpoint": 11000,
    "trainable_params": 806016,
    "total_params": 820736,
    "final_test_acc": 0.9883,
    "max_test_acc": 0.9883,
    "grokked": true
  },
  {
    "intervention": "freeze_exit_layer",
    "checkpoint": 11000,
    "trainable_params": 622464,
    "total_params": 820736,
    "final_test_acc": 0.4492,
    "max_test_acc": 0.6992,
    "grokked": false
  },
  {
    "intervention": "freeze_exit_attn",
    "checkpoint": 11000,
    "trainable_params": 754688,
    "total_params": 820736,
    "final_test_acc": 0.6484,
    "max_test_acc": 0.9727,
    "grokked": true
  },
  {
    "intervention": "freeze_exit_mlp",
    "checkpoint": 11000,
    "trainable_params": 689024,
    "total_params": 820736,
    "final_test_acc": 0.8633,
    "max_test_acc": 0.8984,
    "grokked": true
  },
  {
    "intervention": "freeze_entry_layer",
    "checkpoint": 11000,
    "trainable_params": 622464,
    "total_params": 820736,
    "final_test_acc": 0.957,
    "max_test_acc": 0.957,
    "grokked": true
  },
  {
    "intervention": "freeze_middle_layers",
    "checkpoint": 11000,
    "trainable_params": 424192,
    "total_params": 820736,
    "final_test_acc": 0.0859,
    "max_test_acc": 0.0859,
    "grokked": false
  },
  {
    "intervention": "freeze_attn_all",
    "checkpoint": 11000,
    "trainable_params": 555520,
    "total_params": 820736,
    "final_test_acc": 0.3906,
    "max_test_acc": 0.3906,
    "grokked": false
  },
  {
    "intervention": "freeze_mlp_all",
    "checkpoint": 11000,
    "trainable_params": 292864,
    "total_params": 820736,
    "final_test_acc": 0.1289,
    "max_test_acc": 0.1289,
    "grokked": false
  }
]

References

1. Power, A., Burda, Y., Edwards, H., Babuschkin, I., & Misra, V. (2022). Grokking: Generalization beyond overfitting on small algorithmic datasets. arXiv:2201.02177

2. Nanda, N., Chan, L., Lieberum, T., Smith, J., & Steinhardt, J. (2023). Progress measures for grokking via mechanistic interpretability. arXiv:2301.05217

3. Liu, Z., Michaud, E. J., & Tegmark, M. (2022). Omnigrok: Grokking beyond algorithmic data. arXiv:2210.01117

4. Xu, Z., et al. (2025). Let me grok for you: Accelerating grokking via embedding transfer from a weaker model. arXiv:2504.13292

5. AlQuabeh, H., et al. (2025). Mechanistic insights into grokking from the embedding layer. arXiv:2505.15624

6. Lyu, K., et al. (2025). The norm-separation delay law of grokking: A first-principles theory of delayed generalization. arXiv:2603.13331