Learning Dynamics of LLM Fine-Tuning

Unlocking the Black Box: How LLMs Really Learn During Fine-Tuning§

Large Language Models (LLMs) have revolutionized AI, yet how they truly "learn" during fine-tuning often feels like a black box. We see the impressive results, but the step-by-step process of how a model refines its knowledge remains opaque.

This post dives into a groundbreaking paper, "Learning Dynamics of LLM Fine-Tuning" by Yi Ren and Danica J. Sutherland. Their work, validated on models from the Pythia (e.g., Pythia-2.8B) and Qwen1.5 families using standard alignment datasets like Anthropic-HH, provides a powerful mathematical lens to observe the intricate dance of learning, one training step at a time.

The Core Question: How Does Learning Ripple Through the Model?§

Most research focuses on the final state of a fine-tuned model. This paper, however, investigates the journey. It asks a fundamental question: When we teach a pre-trained model something new with one example, how does that single update ripple through the model to affect its predictions on other, seemingly unrelated examples?

The authors' investigation yields several key contributions:

  • A Mathematical Framework: A precise way to decompose and analyze the learning process for methods like Supervised Fine-Tuning (SFT) and Direct Preference Optimization (DPO).
  • A New Explanation for Hallucination: Moving beyond simple "context mixing" to a more fundamental, feature-based explanation.
  • Discovery of the "Squeezing Effect": Uncovering an unexpected and sometimes harmful side effect of teaching models what not to say.
  • A Practical Mitigation Strategy: Offering a solution to counteract this negative effect.

The Mathematical Heart of Learning§

To understand the learning dynamics, the authors developed a core equation. It might look intimidating, but it elegantly breaks down how a single training update influences the model's predictions.

$$ \underbrace{\Delta \log \pi^t(\mathbf{y} | \mathbf{x}_o)}_{\text{Change in Output}} = -\eta \, \underbrace{\mathcal{A}^t(\mathbf{x}_o)}_{\text{State Term}} \, \underbrace{\mathcal{K}^t(\mathbf{x}_o, \mathbf{x}_u)}_{\text{Similarity Term}} \, \underbrace{\mathcal{G}^t(\mathbf{x}_u, \mathbf{y}_u)}_{\text{Energy Term}} + \mathcal{O}(\eta^2) $$

Let's break this down. Imagine we're fine-tuning a model on the MNIST dataset of handwritten digits. We show it an image of a '4' (the updating example, $x_u$) and want to see how that affects its prediction for an image of a '9' (the observing example, $x_o$).

1. The Energy Term ($\mathcal{G}$): The Learning Signal§

$$ \mathcal{G}^{t}(\mathbf{x}_{u}, \mathbf{y}_{u}) = \nabla_{\mathbf{z}}\mathcal{L}(\mathbf{x}_{u}, \mathbf{y}_{u})\big|_{\mathbf{z}^{t}} $$

This is the engine of the update. It measures how wrong the model's current prediction for '4' is and determines the direction and intensity of the correction. If the model was very confident it was a '2', the energy is high, signaling a large adjustment is needed. If it was already leaning toward '4', the energy is low. It's the model's internal "error compass," pointing where to go and how hard to push.
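To make this concrete, here is a minimal NumPy sketch of the energy term, assuming the standard cross-entropy loss, for which the gradient with respect to the logits is simply softmax(z) minus the one-hot label. The ten classes mirror the MNIST analogy, and the logit values are made-up numbers chosen only to contrast a confidently wrong prediction with an almost correct one.

```python
# Sketch of the energy term G for cross-entropy loss: grad wrt logits = softmax(z) - one_hot(y).
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())   # subtract max for numerical stability
    return e / e.sum()

def energy_term(logits, label):
    """G^t(x_u, y_u): gradient of cross-entropy w.r.t. the logits of the updating example."""
    one_hot = np.zeros_like(logits)
    one_hot[label] = 1.0
    return softmax(logits) - one_hot

# Case 1: the model is confident the '4' is a '2' -> large error signal.
confidently_wrong = np.array([0., 0., 6., 0., 1., 0., 0., 0., 0., 0.])
# Case 2: the model already leans toward '4' -> small error signal.
almost_right = np.array([0., 0., 1., 0., 5., 0., 0., 0., 0., 0.])

for name, z in [("confidently wrong", confidently_wrong), ("almost right", almost_right)]:
    print(f"{name}: |G| = {np.linalg.norm(energy_term(z, label=4)):.3f}")
```

Running this prints a much larger gradient norm for the confidently wrong case, which is exactly the "how hard to push" intuition.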

2. The Similarity Term ($\mathcal{K}$): The Ripple Effect§

$$ \mathcal{K}^{t}(\mathbf{x}_{o}, \mathbf{x}_{u}) = \left(\nabla_{\theta}\mathbf{z}(\mathbf{x}_{o})\big|_{\theta^{t}}\right)\left(\nabla_{\theta}\mathbf{z}(\mathbf{x}_{u})\big|_{\theta^{t}}\right)^{\top} $$

This is the Empirical Neural Tangent Kernel (eNTK), and it's a crucial piece of the puzzle. It measures the similarity between the updating example ('4') and the observing example ('9') from the model's internal perspective. If the model's internal representations for '4' and '9' are similar (perhaps due to their visual curves), the kernel value will be high. This means a lesson learned about '4' will have a strong influence on how the model sees '9'. It’s the mechanism for generalization.

A key assumption the authors make is that this internal sense of similarity remains stable throughout fine-tuning. A '4' and a '9' are similar at the beginning and stay similarly related throughout. They provide strong empirical validation for this in Figures 7 and 8 of the paper's appendix, showing the eNTK's structure remains remarkably consistent from the start to the end of training.
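Below is a minimal sketch of the kernel computation. The tiny randomly initialized MLP, the 784-dimensional random inputs standing in for digits, and the `entk` helper are all illustrative assumptions, not the paper's setup; the point is only to show that the model-internal "similarity" between a near-duplicate pair is typically much larger than between two unrelated inputs, so an update on one would ripple more strongly onto the other.

```python
# Sketch of the empirical NTK K^t(x_o, x_u) for a tiny MLP classifier.
import torch
from torch.func import functional_call, jacrev

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(784, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10)
)
params = {k: v.detach() for k, v in model.named_parameters()}

def logits(p, x):
    return functional_call(model, p, (x,))   # z(x) under parameters p

def entk(x_o, x_u):
    """K^t(x_o, x_u) = (d z(x_o)/d theta)(d z(x_u)/d theta)^T, a 10x10 matrix."""
    j_o = jacrev(logits)(params, x_o)        # per-parameter Jacobians of the logits
    j_u = jacrev(logits)(params, x_u)
    return sum(j_o[n].reshape(10, -1) @ j_u[n].reshape(10, -1).T for n in params)

x = torch.randn(784)
x_similar = x + 0.05 * torch.randn(784)      # stand-in for a visually similar digit
x_unrelated = torch.randn(784)

# The kernel for the similar pair is typically far larger than for the unrelated pair.
print(f"similar:   {entk(x, x_similar).norm():.1f}")
print(f"unrelated: {entk(x, x_unrelated).norm():.1f}")
```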

3. The State Term ($\mathcal{A}$): The Mediator§

$$ \mathcal{A}^{t}(\mathbf{x}_{o}) = \mathbf{I} - \mathbf{1}\,\pi_{\theta^{t}}^{\top}(\mathbf{x}_{o}) $$

This term acts as a mediator, translating the changes from the energy and similarity terms into the final output probabilities. It considers the model's current confidence in its prediction for the observing example ('9'). It adjusts the update so that the resulting probabilities across all possible digits remain coherent and sum to one.
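Here is a quick sketch of the property this term enforces, on a made-up 10-class prediction: whatever raw update arrives, multiplying it by $\mathcal{A} = \mathbf{I} - \mathbf{1}\pi^{\top}$ yields a change in log-probabilities whose first-order effect on the total probability is zero.

```python
# Sketch of the state term A^t(x_o) = I - 1 pi^T and its normalization-preserving property.
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

pi = softmax(rng.normal(size=10))            # current prediction pi(. | x_o) for the '9'
A = np.eye(10) - np.outer(np.ones(10), pi)   # I - 1 pi^T

v = rng.normal(size=10)                      # an arbitrary raw update signal
delta_log_pi = A @ v                         # mediated change in log-probabilities

# First-order change in total probability: sum_j pi_j * delta_log_pi_j is exactly zero.
print(np.isclose(pi @ delta_log_pi, 0.0))    # True
```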

In essence, the model learns something from one example (Energy), this learning spreads to other examples based on how similar the model perceives them to be (Similarity), and a final adjustment is made based on the model's current state of knowledge (State).
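The sketch below puts the three terms together. It again uses a toy MLP, cross-entropy loss, plain SGD, and random tensors standing in for a '4' and a '9' (all assumptions for illustration, not the paper's models or data). It takes one real gradient step on the updating example and checks that $-\eta\,\mathcal{A}\,\mathcal{K}\,\mathcal{G}$ predicts the measured change in $\log\pi(\cdot \mid x_o)$ up to the $\mathcal{O}(\eta^2)$ remainder.

```python
# Sketch: compare the decomposition's first-order prediction against one actual SGD step.
import copy
import torch
import torch.nn.functional as F
from torch.func import functional_call, jacrev

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(784, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10)
)
x_u, y_u = torch.randn(784), torch.tensor(4)   # updating example: a '4'
x_o = torch.randn(784)                         # observing example: a '9'
eta = 1e-3                                     # learning rate

params = {k: v.detach() for k, v in model.named_parameters()}

def logits(p, x):
    return functional_call(model, p, (x,))

# Energy term G: gradient of the cross-entropy loss w.r.t. the logits of x_u.
G = F.softmax(logits(params, x_u), dim=-1) - F.one_hot(y_u, 10).float()

# Similarity term K: empirical NTK between x_o and x_u (a 10x10 matrix).
j_o, j_u = jacrev(logits)(params, x_o), jacrev(logits)(params, x_u)
K = sum(j_o[n].reshape(10, -1) @ j_u[n].reshape(10, -1).T for n in params)

# State term A: I - 1 pi^T, evaluated at the observing example.
pi_o = F.softmax(logits(params, x_o), dim=-1)
A = torch.eye(10) - torch.ones(10, 1) @ pi_o[None, :]

predicted = -eta * A @ K @ G   # first-order prediction of the change in log pi(. | x_o)

# Measure the actual change after one SGD step on (x_u, y_u).
before = F.log_softmax(logits(params, x_o), dim=-1)
stepped = copy.deepcopy(model)
F.cross_entropy(stepped(x_u[None]), y_u[None]).backward()
with torch.no_grad():
    for p in stepped.parameters():
        p -= eta * p.grad
after = F.log_softmax(stepped(x_o[None]).squeeze(0), dim=-1)

# The gap should be tiny: it is the O(eta^2) remainder in the decomposition.
print(f"max |measured - predicted| = {(after - before - predicted).abs().max():.2e}")
```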

Explaining Puzzling Phenomena§

This framework provides powerful explanations for behaviors we've long observed in LLMs.

A New Angle on Hallucinations§

Hallucination is often described as the model "making things up." This framework suggests a more nuanced cause: feature-level association.

When a model is trained on a data point (e.g., prompt $x_u$ and answer $y_u$), the learning doesn't just link that specific prompt and answer. The eNTK ($\mathcal{K}$) causes the update to generalize to other prompts ($x_o$) that the model considers similar. If the underlying features of $x_u$ and $x_o$ are strongly linked in the model's "mind," a strong pull-up pressure on answer $y_u$ can create an association between $x_o$ and $y_u$.

Later, when prompted with $x_o$, the model might generate $y_u$ not because it's confused, but because the fine-tuning process taught it that the features of these prompts lead to that answer. It's less about mixing up context and more about over-generalizing from a learned feature association.

The "Squeezing Effect" in DPO§

Direct Preference Optimization (DPO) teaches a model to prefer one response ($y^+$) over another ($y^-$). Its gradient is effectively a "pull-up" pressure on the good response and a "push-down" pressure on the bad one. While DPO successfully widens the gap between the preferred and rejected answers, it has a strange side effect: the absolute probability of generating both answers often decreases. Figure 4 in the paper visualizes this perfectly, showing the log-probabilities of both chosen and rejected answers declining over time.

The authors name this The Squeezing Effect.

When DPO applies negative pressure to a rejected answer ($y^-$), it squeezes the probability mass out of it. Because the total probability must sum to 1, this mass has to go somewhere. Counter-intuitively, it doesn't just flow to the preferred answer ($y^+$). Instead, it can disproportionately flow to the single most probable token in the entire vocabulary at that step; call it the "winner-takes-all" token.

This effect is most damaging when the rejected response ($y^-$) already had a very low probability. The paper explains in Appendix E that this happens because the gradient update for any token is influenced by the probabilities of all other tokens. When the rejected token's probability is already minuscule, its influence on the gradient calculation is weakened. As a result, the "pull" from the most likely token in the vocabulary becomes dominant. The model overcompensates to satisfy the gradient update, taking probability mass not just from the rejected token but from all other options—including the desired one ($y^+$)—to feed the winner. The result is that the preferred answer becomes less likely, while the "winner-takes-all" token's probability skyrockets.
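A toy simulation makes the mechanism tangible. In the sketch below the "model" is nothing more than the logits of a single softmax over a six-token vocabulary, only the push-down part of the gradient is applied, and the step size is deliberately exaggerated so one update is visible. None of this reproduces the paper's experiments, but the softmax coupling that drives the squeezing is the same.

```python
# Toy, single-step illustration of the squeezing effect under a direct logit parametrization.
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Token 0 is the current "winner"; the rejected token 5 is already very unlikely,
# and the chosen token 3 is only modestly likely.
z = np.array([4.0, 2.0, 1.0, 0.0, -2.0, -6.0])
chosen, rejected, eta = 3, 5, 2.0
pi_before = softmax(z)

# Push-down pressure: one gradient-descent step on log pi(rejected).
# Its gradient is e_rejected - pi, so the step adds eta * pi_j to every other logit j.
grad = np.eye(len(z))[rejected] - pi_before
pi_after = softmax(z - eta * grad)

for tok, name in [(0, "winner"), (chosen, "chosen y+"), (rejected, "rejected y-")]:
    print(f"{name:12s}: {pi_before[tok]:.4f} -> {pi_after[tok]:.4f}")
# Outcome: the rejected token drops as intended, but the already-dominant winner
# absorbs the freed mass and the chosen token becomes *less* likely too.
```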

The Fix: The authors propose a simple and effective mitigation: perform a gentle SFT update on the rejected response ($y^-$) before the DPO step. This "pulls it up" from the low-probability valley, ensuring the DPO push-down is less extreme and avoids the catastrophic squeezing effect. Figure 5 of the paper shows this fix in action, demonstrating that the probability of the chosen answer now correctly increases during DPO.
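The same toy setup can be extended with this ordering (again a sketch under the same simplifying assumptions, not the paper's training recipe): pull the rejected token up with an SFT-style step first, then apply the push-down, and compare how the chosen token fares.

```python
# Toy comparison: push-down alone vs. SFT-style pull-up on y- followed by the push-down.
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def push_down(z, token, eta):
    """One gradient-descent step on log pi(token): lowers that token's probability."""
    return z - eta * (np.eye(len(z))[token] - softmax(z))

def pull_up(z, token, eta):
    """One SFT-style gradient-ascent step on log pi(token): raises its probability."""
    return z + eta * (np.eye(len(z))[token] - softmax(z))

z0 = np.array([4.0, 2.0, 1.0, 0.0, -2.0, -6.0])
chosen, rejected, eta = 3, 5, 2.0

dpo_only = push_down(z0, rejected, eta)
with_fix = push_down(pull_up(z0, rejected, eta), rejected, eta)

print(f"chosen before:          {softmax(z0)[chosen]:.4f}")
print(f"chosen, push-down only: {softmax(dpo_only)[chosen]:.4f}")   # squeezed down
print(f"chosen, SFT-then-push:  {softmax(with_fix)[chosen]:.4f}")   # preserved
```

In this toy, pulling the rejected token out of its low-probability valley before pushing it down leaves the chosen token's probability intact, qualitatively matching the paper's observation that the fix restores the expected DPO behavior.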

Questions for the Road Ahead§

Like any great discovery, this paper answers some questions while raising fascinating new ones. It leaves us with a richer, more complex picture of LLM alignment and pushes us to think more deeply about how we build and guide these models:

  • Probing the Squeezing Effect: Can the "winner-takes-all" token be predicted ahead of time? Could identifying examples in the probability "valleys" allow us to create dynamic training sets—skipping harmful updates or focusing on more productive ones? This shifts the view of a dataset from a static resource to a dynamic partner in the training process.
  • Feature-Level Association and Societal Bias: The paper explains that models link underlying features, not just concrete examples. Could this be a root cause of bias amplification? A model might not explicitly learn "doctor = man," but could it develop strong feature-level associations between the concepts, leading to subtle but persistent biases in its responses?
  • The Risk of Catastrophic Forgetting: Could the squeezing effect lead to a form of catastrophic forgetting? When probability is siphoned away from the long tail of the distribution, we risk erasing valuable knowledge about rare events, niche topics, or minority cultures that happen to lie in those "valleys." How do we protect this knowledge while still guiding the model's preferences?