
A couple of really phenomenal recent papers have brought up new approaches to self-distillation in language models and reinforcement learning, and motivated me to spend some time thinking about self-distillation as a general bootstrapping concept. I spend a lot of time working with DINO- and JEPA-adjacent models, where the concept is fairly well understood, but I was still quite surprised by the application to language modeling, and wanted to dig a bit deeper into the underlying principles.

Unfortunately, the results of my dig are a little complicated. BYOL is typically cited as the origin of self-distillation, and that’s certainly the case for image representation learning. However, I want to argue that it isn’t actually the first self-distillation technique in modern AI research. Broadly, I’d say learning is self-distillation if the loss signal comes from a grad-stopped version of the network, possibly subject to different augmentations, or given extra compute to improve the teacher’s signal. As I argue below, essentially all bootstrapped reinforcement learning methods do something like this. More recent RL papers bring similar ideas to actor networks, in a pretty complicated way.

In this post I will (non-chronologically) go through a few important papers that performed some form of self-distillation, and then discuss what I found to be the major themes. We begin with the computer vision setting, where very extensive analyses and ablations of self-distillation for Joint-Embedding Predictive Architectures (JEPAs) have been conducted. I then briefly touch on SFT by self-distillation, before an extended discussion of reinforcement learning and self-distillation, touching on bootstrapping, AlphaZero, and more recent work.

BYOL, DINO, SWaV and co: self-distillation for computer vision

One of the big themes through this research track is asymmetries between the teacher and student models. In general, the problem to keep in mind is dimensional collapse: if our network outputs an \(n\)-dimensional vector, we want the entire space to be used for representing different images. Pathologies are things like every image being represented by a single point (total collapse) or the space of representations being a small subspace of the available representation space. This paper suggested that pretty much all of the asymmetries in these bootstrapped designs correspond to preventing dimensional collapse. It includes a deep dive on the math, and a great discussion of why the asymmetric projection heads of BYOL and SimSiam work, neither of which I plan to go into. The overarching theme, however, is that we want to give the teacher desirable characteristics that the student must then attempt to replicate.

BYOL

Bootstrap Your Own Latent was the first paper to bring teacher/student self-distillation to the fore, and brought with it many important innovations, notably the momentum encoder, or exponential moving average (EMA) teacher model, and the asymmetric projection head. BYOL is a form of self-supervised learning, meaning the goal is to obtain good image representations without using image labels. The general approach to (JEPA) SSL is to generate multiple augmentations of an image (e.g. changing the exposure and the crop) and use a loss function that incentivizes their outputs to be similar: even though the values of all of the pixels are different, the images have the same semantic content and therefore should have very similar representations under our neural network. BYOL’s goal was to avoid representation collapse without using negative, or contrastive, pairs.

In more detail, BYOL’s big ideas were:

  • Use self distillation as a training technique for JEPA: their loss function is very similar to SimCLR’s, except without any negative pairs. The fundamental idea is breaking SimCLR’s symmetry in two ways: a student/teacher model, and an asymmetric projection head.
  • We want to create a teacher model which integrates “good” behaviors of the student model, but we don’t really have good gradients for the teacher. BYOL updates the teacher at time step \(t+1\) as \(T_{t+1} = \alpha T_t + (1 - \alpha) S_{t+1}\), where \(\alpha\) is generally set pretty close to 1. This is called the momentum encoder, or EMA teacher model. Importantly, compared to something like SimCLR, gradients only flow through the student branch, and the student’s updates in turn keep the teacher moving. This helps prevent the model from settling into a single fixed state.
  • Also interestingly, BYOL used an asymmetric projection head \(P\). Instead of comparing output representations directly, given two augmentations \(x, x'\) of an input image, BYOL’s loss function is \(\operatorname{sim}(P(S(x)), T(x')) + \operatorname{sim}(P(S(x')), T(x))\). Compare this with SimCLR, which calculates similarity as \(\operatorname{sim}(P(S(x)), P(S(x')))\). (The similarity operator throughout is generally set to be cosine similarity.)
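The EMA update and the symmetrized objective from the bullets above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not BYOL’s implementation: weights and representations are flat lists of floats, the names `ema_update`, `cosine_sim`, and `byol_objective` are hypothetical, and the teacher outputs are assumed to be treated as constants (no gradient flows into them).

```python
import math

def ema_update(teacher, student, alpha=0.99):
    """T_{t+1} = alpha * T_t + (1 - alpha) * S_{t+1}, applied elementwise."""
    return [alpha * t + (1.0 - alpha) * s for t, s in zip(teacher, student)]

def cosine_sim(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def byol_objective(pred_x, pred_xp, target_x, target_xp):
    """sim(P(S(x)), T(x')) + sim(P(S(x')), T(x)).

    pred_*   : student outputs after the projection head P
    target_* : teacher outputs, treated as constants (stop-grad)
    """
    return cosine_sim(pred_x, target_xp) + cosine_sim(pred_xp, target_x)

# With a frozen student, the teacher drifts toward it geometrically:
# after n updates it has covered a fraction (1 - alpha**n) of the gap.
teacher = [0.0, 0.0]
student = [1.0, -1.0]
for _ in range(3):
    teacher = ema_update(teacher, student, alpha=0.9)
print(teacher)
```

With \(\alpha\) close to 1 the teacher lags the student by many steps, which is what makes it a slowly moving target rather than a copy.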

These asymmetries are the fundamental ideas BYOL introduced to stabilize self-distillation, but as we will see, neither idea is truly necessary: SimSiam showed that EMA is unnecessary, and SWaV and DINO symmetrize the projection head.

SimSiam

SimSiam was a very interesting follow-up to the BYOL paper. It adopts the full BYOL architecture, and performs a wide range of ablations, particularly on the two components of BYOL we isolated above.

The biggest surprise of SimSiam was that the momentum encoder, which sets a moving target for the student, is not necessary for convergence, even good convergence, of the set-up. The pivotal feature of the EMA teacher, according to SimSiam, is that it breaks gradients, so gradients only flow through the student. Replacing BYOL’s teacher \(T\) with \(\operatorname{stop-grad} \circ S\) is fully functional, even though the model \(S\) is being compared with its own outputs, through a projection head.

This should be surprising for a couple of reasons. People generally use teacher models, as they outperform student models in applications, which suggests the EMA is really achieving something special. The fact that the student and teacher don’t have identical weights in BYOL is also appealing: when \(S\) and \(T\) share weights, the best map the projection head \(P\) can learn is, in expectation, the identity mapping, which is not the case when \(S\) and \(T\) have different weights!

On the other hand, SimSiam validated the projection head. Their ablations here are a great example of experiments I wish more papers would include when they propose architectures: SimSiam tried removing the projection head, as well as fixing it to its initialization weights instead of training it. In both cases, training fails, but for different reasons.

  • Without a projection head, the loss function \(\operatorname{sim}(S(x), \operatorname{stop-grad} \circ S(x')) + \operatorname{sim}(S(x'), \operatorname{stop-grad} \circ S(x))\) has the same gradients as \(\frac12 (\operatorname{sim}(S(x), S(x')) + \operatorname{sim}(S(x'), S(x)))\), since the similarity is symmetric in its arguments. Without SimCLR’s negative pairs, the model simply collapses to a point, achieving perfect loss and giving zero information.
  • With a fixed, randomly initialized projection head, the model doesn’t collapse. However, neither does it converge, with loss simply remaining high. The authors conclude that the projection must be allowed to adapt to the representations.

In later work it was suggested that the projection head functions as a low-pass filter on the student’s feature representations, meaning it dampens the low eigenvalues of the student’s representation space. (These are the eigenvalues of the feature correlation matrix of the data.) This causes the bootstrap to increase the effective dimension in order to match the teacher’s outputs.

SWaV

SWaV takes an online-clustering approach to representation learning. The original paper doesn’t present this as a student/teacher set-up, nor as self-distillation, but it is! My exposition is, I believe, correct, but quite different from the original presentation. Like SimSiam, SWaV uses a grad-stopped version of the student model as the teacher, rather than a momentum encoder. Instead of an asymmetric prediction head, SWaV maintains a pool of cluster prototypes, or equivalently a linear projection head applied to both student and teacher representations. The student model’s output is computed as \(\operatorname{softmax}(P \circ S(x))\). This can suffer from dimensional collapse, though, so to enrich the representation, the teacher model \(T=S\) uses the Sinkhorn-Knopp algorithm to make the minibatch output \(\operatorname{Sinkhorn-Knopp}(P \circ T(B))\) doubly stochastic (up to scaling), meaning that each image \(x\) is assigned a probability distribution over the set of prototypes (rows of \(P\)), but also that each prototype is assigned equal total weight across the minibatch.

The loss is then computed as cross entropy of the student output \(\operatorname{softmax}(P \circ S(B))\) and the teacher output \(\operatorname{stop-grad} \circ \operatorname{Sinkhorn-Knopp}(P \circ T(B))\). Importantly, the projection head \(P\) is still used, but now is applied to both the student and teacher branches. The only asymmetry is that the teacher model uses minibatch statistics to reweight outputs with the Sinkhorn-Knopp algorithm.
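To make the teacher-side balancing concrete, here is a rough sketch of Sinkhorn-Knopp normalization on a tiny batch, in plain Python. It is an illustration under assumptions rather than the paper’s code: `scores` stands in for \(P \circ T(B)\) as a list of rows (one per image), and the function names, the temperature `eps`, and the iteration count are illustrative choices.

```python
import math

def softmax(row):
    """Student-side distribution over prototypes for one image."""
    m = max(row)
    e = [math.exp(v - m) for v in row]
    z = sum(e)
    return [v / z for v in e]

def sinkhorn_knopp(scores, eps=0.5, n_iters=3):
    """Balance a batch of prototype scores (B rows, K columns).

    Returns a B x K matrix in which every row sums to 1 and every
    column receives approximately B / K total mass.
    """
    B, K = len(scores), len(scores[0])
    m = max(max(row) for row in scores)  # shift for numerical stability
    q = [[math.exp((s - m) / eps) for s in row] for row in scores]
    for _ in range(n_iters):
        # Give every prototype (column) total mass 1 / K.
        col = [sum(q[i][k] for i in range(B)) for k in range(K)]
        q = [[q[i][k] / (col[k] * K) for k in range(K)] for i in range(B)]
        # Give every image (row) total mass 1 / B.
        row = [sum(r) for r in q]
        q = [[v / (row[i] * B) for v in q[i]] for i in range(B)]
    return [[v * B for v in r] for r in q]  # rescale so each row sums to 1

def cross_entropy(teacher_probs, student_probs):
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))

# Every image prefers prototype 0 under a plain softmax ...
scores = [[2.0, 1.0], [1.9, 1.0], [0.5, 0.4], [0.6, 0.2]]
student = [softmax(r) for r in scores]
# ... but the balanced teacher targets spread mass evenly over prototypes.
teacher = sinkhorn_knopp(scores, n_iters=50)
loss = sum(cross_entropy(t, s) for t, s in zip(teacher, student)) / len(scores)
```

The teacher targets would be treated as constants during backpropagation; only the student softmax receives gradients.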

For all that I find SWaV one of the most confusing papers in a confusing field, it has really wonderful ideas!

  • Asymmetric projections probably are necessary, but that asymmetry doesn’t need to happen at the level of the projection head. Instead, we can use test-time compute and batch statistics (here in the form of Sinkhorn-Knopp) to improve the signal that the teacher has for the student.
  • SWaV’s method of preventing dimensional collapse is much more explicit than BYOL/SimSiam’s, and Sinkhorn-Knopp directly optimizes for higher-dimensional representations through architecture decisions, whereas BYOL uses implicit training dynamics of the projection head.
  • The SWaV paper also introduced new multi-crop augmentations. In DINO, these augmentations would create a new asymmetry in the student/teacher set-up, but SWaV I believe maintained data augmentation symmetry.

I’ve deliberately set up SWaV in an evocative way: by describing it through the self-distillation lens, several natural questions come up: can we replace the teacher model with a momentum encoder? Can we use a nonlinear projection head \(P\), instead of a linear one? Along with simplifying Sinkhorn-Knopp, these are the key ideas that contributed to creating DINO.

DINO

DINO was the spiritual successor to SWaV, while bringing in several innovations from BYOL, and is probably the most famous/successful project in this line of work. DINO introduces the momentum encoder from BYOL, while maintaining SWaV’s symmetric projection head. We can quickly summarize the big ideas of DINO:

  • A nonlinear projection head for both student and teacher: instead of having the projection head only be applied to the student like in BYOL, an EMA teacher projection head is calculated along with the EMA teacher backbone. Loss is then calculated by \(\operatorname{sim}(P \circ S(x), \hat{P} \circ \hat{S}(x'))\), where \(\hat{A}\) denotes an EMA of the values of \(A\). Unlike SWaV, \(P\) can now be an arbitrary network, rather than a single matrix.
  • The loss function is modified to prevent collapse of this set-up: DINO applies centering and sharpening to the teacher’s outputs. That is, the ideal output distribution should be zero-centered, so DINO directly subtracts the mean teacher representation from \(\hat{P} \circ \hat{S}(x)\). Additionally, like SWaV, DINO replaces cosine similarity with cross entropy, and uses a higher temperature for the student distribution than the teacher distribution. The idea is that the model should be confident about the features of the image, and not collapse to the uninformative 0 point. (In fact, some related ideas had been used in label smoothing in supervised learning.) This is a computationally simpler tool than SWaV’s Sinkhorn-Knopp algorithm that accomplishes a similar goal of keeping the teacher distribution from collapsing to a few dimensions/prototypes.
  • The augmentation distribution is changed: in both SimSiam and BYOL (and I think SWaV, but I’d have to look at the code to be sure) the same distribution of augmentations was used for both the student and teacher branch of the network. DINO makes the argument that the teacher model should be proposing as good a representation as possible, an idealized representation, that the student network tries to approximate. They therefore give the teacher network higher resolution training images, with lower distortion, and the student network an array of smaller, more distorted images.
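The centering and sharpening step in the list above can be sketched as follows, in plain Python. This is a simplified illustration, not DINO’s code: the function names are hypothetical, the temperatures and center momentum are illustrative values, and logits are plain lists standing in for \(P \circ S(x)\) and \(\hat{P} \circ \hat{S}(x')\).

```python
import math

def log_softmax(logits, temp):
    """Numerically stable log-softmax of logits / temp."""
    scaled = [l / temp for l in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(v - m) for v in scaled))
    return [v - log_z for v in scaled]

def teacher_probs(logits, center, tau_t=0.04):
    """Centering (subtract the running mean), then sharpening (low temperature)."""
    centered = [l - c for l, c in zip(logits, center)]
    return [math.exp(v) for v in log_softmax(centered, tau_t)]

def dino_loss(student_logits, teacher_logits, center, tau_s=0.1, tau_t=0.04):
    """Cross entropy of the student against the centered, sharpened teacher.

    The teacher distribution is treated as a constant (stop-grad).
    """
    t = teacher_probs(teacher_logits, center, tau_t)
    log_s = log_softmax(student_logits, tau_s)
    return -sum(ti * ls for ti, ls in zip(t, log_s))

def update_center(center, teacher_batch, momentum=0.9):
    """EMA of the mean teacher logits over a batch."""
    n = len(teacher_batch)
    mean = [sum(col) / n for col in zip(*teacher_batch)]
    return [momentum * c + (1.0 - momentum) * m for c, m in zip(center, mean)]

logits = [0.3, 0.1, -0.2]
center = [0.0, 0.0, 0.0]
sharp = teacher_probs(logits, center)                   # teacher temperature
soft = [math.exp(v) for v in log_softmax(logits, 0.1)]  # student temperature
loss = dino_loss(logits, logits, center)
```

Centering pushes the teacher toward a uniform output while the low temperature pushes it toward a peaked one, so the two operations pull against each other.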

The difference in augmentations can also be thought of as a justification of the sharpening of the teacher model: given a better image, the network should be more confident about its data, and therefore have sharper outputs.

As a result, we were able to remove the asymmetric projection head—but with some clear asterisks! We had to introduce two new forms of asymmetry to the system to eke out high performance: the centering/sharpening of DINO makes the failure modes of self-distillation much more explicit, and actively controls the output distribution shape. Additionally, we use asymmetric data distributions to continue to promote diversity in the teacher branch. This has further benefits, allowing us to decrease the cost of running the student network on large batches, by running them on smaller crops.

Self-distillation in SFT and regular distillation

I’ll be brief here, as I’m not an expert on language model distillation/supervised fine tuning, but one repeated observation throughout the field has been the necessity of distilling on the student’s generated text, instead of just matching distributions on arbitrary text. Smaller models are less expressive, and therefore can’t fully mimic large models outside of their primary distribution. Therefore, the best approach to distillation is to have the student generate text, and the teacher add corrections. In this way, the student’s output distribution is refined in its most useful domain.

This can be viewed as one of the general problems with SFT for language models: if we fine tune on text that’s outside the distribution of the model we can get messy results that interfere with its standard generation, rather than just accumulating knowledge as desired. In RL terms, standard SFT is off-policy, which is always harder. (This paper that I really enjoyed discusses using RL techniques for off-policy learning in SFT, and is worth checking out for a deeper dive on this topic.) Recently it was proposed that there’s a natural way to make SFT exist in-distribution: by having the network itself paraphrase sample text, or “say it in your own words.”

The set-up here is quite simple: a model is prompted with “Here’s a sample answer to query Q: […] Now give your own answer to Q:,” or a similar prompt. Then the student model, which is prompted only with Q, is distilled to the teacher’s distribution, which is conditioned on the sample data but remains fully in the model’s own distribution.
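As a rough sketch of the asymmetry in this set-up, the snippet below builds the two prompts and scores the student against the teacher. Everything here is an assumption made for illustration: the helper names are hypothetical, the per-token distributions are toy lists over a shared vocabulary rather than real model outputs, and the direction of the divergence (cross entropy with the teacher as target) is my choice, not something the paper is confirmed to use.

```python
import math

def student_prompt(query):
    """The student only ever sees the query."""
    return query

def teacher_prompt(query, sample_answer):
    """The teacher (same weights, no gradients) also sees a sample answer."""
    return (
        f"Here's a sample answer to query {query}: {sample_answer} "
        f"Now give your own answer to {query}:"
    )

def distill_loss(teacher_dists, student_dists):
    """Mean per-token cross entropy of the student against the teacher.

    Each argument is a list (one entry per token position) of probability
    lists over the same vocabulary. Teacher values are constants.
    """
    total = 0.0
    for t_dist, s_dist in zip(teacher_dists, student_dists):
        total -= sum(t * math.log(s) for t, s in zip(t_dist, s_dist) if t > 0.0)
    return total / len(teacher_dists)

# Toy example: two token positions, vocabulary of size 2.
teacher_dists = [[0.9, 0.1], [0.2, 0.8]]
student_dists = [[0.6, 0.4], [0.5, 0.5]]
loss = distill_loss(teacher_dists, student_dists)
```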

It could definitely be argued that this is sort of a trick to deal with small models being less expressive, and certainly, it’s hard to know how necessary this approach is once models are trained on trillions of tokens and less and less text is out of distribution. But for small models, it’s nice, and I think we shouldn’t dismiss it out of hand. As a self-distillation technique, it integrates ideas from both SimSiam and DINO: the teacher is just a grad-stopped version of the student, and the augmentation strategy is similar to DINO’s, giving better input data to the teacher than to the student.

We can also view this as a hint about the success of self-distillation in computer vision: by bootstrapping representations, the model’s approximation goals are always achievable/in domain, which helps stabilize the training process. I wouldn’t invest too much in this explanation though—ultimately, the generative/discriminative gap is quite large, and we shouldn’t necessarily expect insight to travel easily between them.

Reinforcement learning: an origin story for self-distillation

At the risk of being controversial, I want to argue that self-distillation is quite old in reinforcement learning! I’m taking a broad definition here, letting self-distillation refer to any learning set-up in which the learning signal is a grad-stopped version of the network’s, or a related (e.g. EMA) network’s, output.

In classical reinforcement learning algorithms like \(Q\)-learning or TD-learning some variant of Bellman backups is employed to create a value table or value network. In its simplest form, the value of an action \(a\) at state \(s\) must satisfy \(Q(s, a) = r(a) + \gamma \max_{a'} Q(s', a')\), where \(\gamma\) is a discount factor, \(r(a)\) is the reward associated to taking action \(a\), and \(s'\) is the resultant state of taking action \(a\) at state \(s\).[^1] There’s a natural backward-inductive solution to approximating \(Q\) directly from this equation: iteratively update the approximated value \(\hat{Q}(s, a)\) by setting it equal to \(r(a) + \gamma \max_{a'} \hat{Q}(s', a')\). When we’re using neural nets for \(\hat{Q}\), we instead treat this as gradient information, with our loss function \((\hat{Q}(s, a) - \operatorname{stop-grad}(r(a) + \gamma \max_{a'}\hat{Q}(s', a')))^2\).

In other words, in \(Q\)-learning, the teacher network \(\operatorname{stop-grad} \circ \hat{Q}\) updates the student network \(\hat{Q}\) by observing the outcome of the evaluated action in order to refine its approximation of the action value. The teacher network in this original setting shares identical weights with the student network, very much like SimSiam (see above). DQN changed the set-up to keep a fixed target network for the teacher network, updated periodically (e.g. every 1000 steps) in order to stabilize learning and prevent large oscillations. This can be viewed as a rough version of the momentum encoder used by BYOL and DINO—and in fact, under the name Polyak averaging, momentum encoders are now a preferred way to implement the teacher network in this setting! (See e.g. SAC, which in turn took the momentum encoder value from this paper.)
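The teacher/student reading of the Bellman backup can be written out as a tabular sketch. This is a minimal illustration assuming a deterministic environment and dictionary-based tables; the function names are hypothetical, terminal states are not handled, and the constant factor from differentiating the squared loss is folded into the learning rate.

```python
def td_target(reward, next_q_values, gamma=0.99):
    """Teacher signal: r + gamma * max_a' Q_target(s', a'), held constant."""
    return reward + gamma * max(next_q_values)

def q_update(q, q_target, s, a, reward, s_next, lr=0.1, gamma=0.99):
    """One gradient step on (q[s][a] - stop_grad(target))**2, in place.

    q        : student table, {state: [value per action]}
    q_target : teacher table; pass q itself for the shared-weights case
    """
    target = td_target(reward, q_target[s_next], gamma)
    q[s][a] += lr * (target - q[s][a])
    return target

def hard_sync(q, q_target):
    """DQN-style teacher: copy the student every so many steps."""
    for s in q:
        q_target[s] = list(q[s])

def polyak_sync(q, q_target, alpha=0.995):
    """Polyak averaging: the same EMA rule as a momentum encoder."""
    for s in q:
        q_target[s] = [alpha * t + (1.0 - alpha) * v
                       for t, v in zip(q_target[s], q[s])]

# Two states, two actions; action 0 in state 0 leads to state 1 with reward 1.
q = {0: [0.0, 0.0], 1: [0.0, 0.0]}
q_target = {s: list(v) for s, v in q.items()}
q_update(q, q_target, s=0, a=0, reward=1.0, s_next=1, lr=0.5, gamma=0.9)
polyak_sync(q, q_target, alpha=0.9)
```

Passing the same table as both `q` and `q_target` recovers the shared-weights (SimSiam-like) case, while `hard_sync` and `polyak_sync` give the two delayed-teacher variants discussed above.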

The only remaining difference between the RL set-up and something like DINO or BYOL is the difference in projection heads (RL doesn’t need them) and the form of data asymmetry used: DINO provides different augmentations of images to the teacher and student, and performs normalizations on the teacher output, while RL uses environment signals to refine the teacher’s value estimate.

One thing I really like about this analogy is that it clarifies the role EMA has in DINO and BYOL’s set-up: SimSiam and SWaV had revealed that the EMA wasn’t necessary, but it still seemed helpful. Stability of representations was the natural guess, but this is much more evident in the reinforcement learning context. Given that set-ups like BYOL and DINO face real threats of representation collapse if things go wrong, the EMA likely helps stabilize the gradients for a small boost in performance.

AlphaGo & AlphaZero: the actor network can have a bit of self-distillation. As a treat.

Typically, the actor network in actor-critic set-ups like PPO is not trained with something like self-distillation: the critic function is used to calculate the advantage of each action, and the actor network updates in proportion to the advantage. This is necessary to some extent: the actor network needs to get loss signals from the critic in order to learn non-random actions. However, it is possible to do something that interpolates between advantage estimates and self-distillation!

The first AI paper that I read was AlphaZero’s explanation of how they surpassed humans at Go. It was an incredible moment that has influenced a huge amount of subsequent research. AlphaZero was trained with Monte Carlo Tree Search (MCTS) and leaf expansion following the Upper Confidence Bound (UCB) calculation, to rapidly explore/plan sequences of moves. But the training algorithm for AlphaZero wasn’t to have the planner purely learn to assign higher probabilities to nodes with higher value estimations.[^2] Rather, the planner was trained with cross entropy loss to match the distribution of states sampled by the MCTS algorithm.

There are lots of reasons this worked out well: it gives dense signals about all moves considered, and interlaces perfectly with MCTS at test time. It allows values to feed into the planner, but still incentivizes exploration. But if you squint, it looks like self-distillation! AlphaZero’s form of self-distillation looks a bit different from our computer vision and SFT examples, but carries ideas from both of them: instead of presenting different augmentations, the teacher model is given access to test-time compute, and uses that to create a better solution, similar to generative SFT above. But we aren’t really worrying about perfectly matching the starting distribution, like SFT is; rather, we’re using features of the data space (turn order, and the way the board evolves when different moves are played) in order to abstractly gain structural information about planning sequences.

To be fully explicit, in AlphaZero the student and teacher share the same weights. For a given move, the teacher model \(P\) iteratively selects, from the set of available actions, the one maximizing the upper confidence bound \(Q(s, a) + \alpha \frac{P(a \mid s)\sqrt{N(s)}}{N(s, a) + 1}\), where \(\alpha\) is a constant, \(Q(s, a)\) is the value model’s approximation of the value of \(a\), \(P(a \mid s)\) is the teacher’s probability for action \(a\) at state \(s\), and \(N(s)\) (resp. \(N(s, a)\)) is the visit count of the state (resp. state/action pair), used to balance exploration. As test time compute scales, the contribution to the UCB from \(P\) decreases while the \(Q\) values become more accurate. But the objective function for \(P\) is the cross entropy \(H(P; N)\) of \(P\) with the visit distribution. If we ignore the contribution from \(Q\), \(N\) would just be a (slightly higher entropy) version of \(P\), which can be viewed as a kind of entropy regularization. (Note this is inverse to the sharpening that we discussed in DINO.) On the other hand, if we just used \(Q(s, a)\) (appropriately normalized to be a stable advantage estimate), we would return to the standard actor-critic update, with a slightly more sophisticated variation of \(Q\)-learning. So AlphaZero can pretty reasonably be thought of as an interpolation between self-distillation and classic RL updates! The gradients of the teacher are stopped with the relative sledgehammer of using \(N\) as the teacher signal instead of \(P\).
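The two limiting cases described above can be checked with a toy version of the selection rule. This sketch deliberately strips out the tree: it repeats selection at a single state with fixed \(Q\) and \(P\) and no value backups, so it illustrates only how the visit counts \(N\) relate to \(P\) and \(Q\). The function names and the constant `c` (playing the role of \(\alpha\)) are assumptions for illustration.

```python
import math

def puct_select(q, p, n_sa, c=1.0):
    """Pick argmax_a of Q(s,a) + c * P(a|s) * sqrt(N(s)) / (N(s,a) + 1)."""
    n_s = sum(n_sa)
    scores = [q[a] + c * p[a] * math.sqrt(n_s) / (n_sa[a] + 1)
              for a in range(len(p))]
    return max(range(len(p)), key=lambda a: scores[a])

def visit_counts(q, p, n_sims, c=1.0):
    """Run n_sims selections at one state with q and p held fixed."""
    n_sa = [0] * len(p)
    for _ in range(n_sims):
        n_sa[puct_select(q, p, n_sa, c)] += 1
    return n_sa

def policy_loss(p, n_sa):
    """Cross entropy of the policy P against the normalized visit counts."""
    total = sum(n_sa)
    return -sum((n / total) * math.log(pa) for n, pa in zip(n_sa, p))

prior = [0.5, 0.3, 0.2]

# With Q ignored, the visit distribution tracks the prior P closely.
flat = visit_counts([0.0, 0.0, 0.0], prior, n_sims=1000)

# With a strong value signal, visits concentrate on the high-value action
# even though the prior ranks it last.
valued = visit_counts([0.0, 0.0, 1.0], prior, n_sims=100)

loss = policy_loss(prior, valued)
```

Training \(P\) toward `valued` pulls the policy toward the high-value move, while training toward `flat` would change it very little.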

We again see the big difference between the RL and the computer vision approach is that RL uses compute to make improved estimates for the teacher, while computer vision uses augmentations to make the student’s task more difficult—and it’s not clear that either approach translates productively to the other domain.

RLVR for language models: fully self-distilled actors

The most common approaches to RLVR in the literature are PPO and GRPO. PPO of course is an actor-critic method, but GRPO successfully removed the self-distillation of the critic network by simply not having a critic network. It’s quite neat, therefore, that new approaches to actor-only RL have arisen that are fully self-distilled!

Sharing several authors with the highlighted SFT/continual learning paper above, this paper starts from the same basic idea: use a language model’s in-context learning to provide supervision signals. When rich feedback exists, such as an error message or a failed unit test when code is run, the teacher model is given both the error message and the original attempt transcript, and its probabilities for the tokens of that attempt are evaluated. The ratio between the student and teacher probabilities for each token can be viewed as a surrogate advantage function, in which the teacher will assign lower probabilities than the student will to mistakes, since the teacher is prompted with the error message.
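A toy version of this surrogate advantage is sketched below. The numbers are made up, the helper names and the layout of the teacher’s context are hypothetical, and the sign convention (log of teacher probability over student probability, so that mistakes come out negative) is my assumption rather than the paper’s stated definition.

```python
import math

def teacher_context(prompt, attempt, feedback):
    """Hypothetical layout: the teacher sees the attempt plus its feedback."""
    return f"{prompt}\n\nPrevious attempt:\n{attempt}\n\nFeedback:\n{feedback}\n"

def surrogate_advantages(student_probs, teacher_probs):
    """Per-token log(p_teacher / p_student) over the tokens of one attempt."""
    return [math.log(t / s) for s, t in zip(student_probs, teacher_probs)]

# Probabilities each model assigns to the tokens the student actually produced.
tokens = ["return", "x", "+", "1"]
student_probs = [0.9, 0.8, 0.6, 0.7]
# After reading the error message, the teacher finds "+" much less plausible.
teacher_probs = [0.9, 0.8, 0.1, 0.7]

advantages = surrogate_advantages(student_probs, teacher_probs)
worst = tokens[min(range(len(tokens)), key=lambda i: advantages[i])]
```

Tokens the teacher agrees with get an advantage near zero, so the learning signal concentrates on the positions the feedback implicates.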

This is very tidy and very clever, of course! The authors did find that an EMA teacher was necessary to avoid collapse (though I hypothesize DQN’s fixed teacher epochs would likely also work), supporting our earlier interpretation of the momentum encoder stabilizing targets. Like the variety of Bellman-backup methods for value estimation, the self distillation signal comes from providing environment feedback as an extra signal to the teacher network. This is fairly different from the AlphaZero/DINO set-ups, which as summarized above scale teacher compute and student augmentations.

My biggest surprise from this work was that using the external/GRPO rewards in combination with the student/teacher rewards stopped being useful as model size grew. My intuition was that providing the relative advantage information to the network would also be useful, and some linear combination of GRPO with in-context advantage estimates would outperform either case alone. It turns out this is true only for small models, where the in-context learning teacher is on the weaker side. The authors found that as model size grew, GRPO started providing a less useful signal.

Conclusions

We summarized a wide array of different self-distillation frameworks (and also more RL than I intended to discuss when I started writing!) and a few themes exist throughout:

  • Asymmetry between the teacher and the student is important—and whenever possible, it’s good to give the teacher advantages. The form of advantage given to the teacher is highly context-dependent. Bringing self-distillation to new contexts will involve finding new ways to deploy advantages to the teacher network—or disadvantages to the student network.
  • Stability is a big concern—and momentum encoders are a powerful tool to preserve it. AlphaZero’s version of stabilization, by using observed samples rather than network outputs, is also really interesting, and sort of out there as self-distillation frameworks go. It really lacks obvious analogy outside the MCTS world, but I can’t help but wonder if I’m missing something.
  • Self distillation is also fundamentally different between RL and vision research. RL integrates true signals from the environment as signal to the teacher network, which are generally considered unavailable in (JEPA) vision research. I’m skeptical of the amount of cross-talk the two methods can have here: making reinforcement learning harder with student augmentations seems extremely problematic in many scenarios—RL is already hard enough! Curating good augmentations to improve the RL setting is probably quite difficult. On the other hand, JEPA research generally doesn’t assume you have access to anything more than the ability to do augmentations. I’m not sure what other signals you could use, for vision tasks!
  • However, the two are also in communication, and many ideas should translate nicely between them. One thing I’m particularly intrigued about is scaling compute for teachers in JEPA models, similar to the AlphaZero use of scaled planning in the teacher. To do this on something like ImageNet, you would need to use an architecture that allows flexible compute, which is not at all standard in the vision literature. However, candidates do exist: recent recurrent-layer transformers could be integrated with ViTs to give the teacher more compute than the student. A reasonable first question is whether the use of deeper networks powered by recurrence increases or decreases the effective dimension of the space. A continuous reasoning model could also be interesting here.

Footnotes

[^1] For simplicity, we assume a deterministic system.

[^2] Though this does happen implicitly through the UCB calculation.
