AI Knowledge Distillation Explained

How knowledge distillation lets a small student model learn from a large teacher model, why it works, and how to use it to ship smaller, faster models in production.

By Codeloom · Intermediate

What you'll learn

  • What knowledge distillation actually transfers
  • The teacher-student training setup
  • A concrete distillation loop in PyTorch-style code
  • Trade-offs vs quantization and pruning
  • Tips for getting a strong student model

Prerequisites

  • Basic deep learning familiarity

Big models are wonderful research artifacts and painful production dependencies. Knowledge distillation is the standard technique for turning a heavy, accurate model into a smaller, cheaper one that keeps most of the accuracy. This post walks through how it works.

What and Why

Knowledge distillation trains a small student model to imitate a large teacher model. Instead of only learning from hard labels in the dataset, the student also learns from the teacher’s soft probability distribution over outputs. Those soft targets carry richer information about the structure of the problem.

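You distill when inference cost matters: on-device models, low-latency APIs, or large fleets where every millisecond compounds. As an illustration, a 7B teacher might score 88 on a benchmark while a distilled 1B student scores 84. Giving up those four points buys a model with roughly 7x fewer parameters, and with it a large cut in latency and serving cost. The exact speedup depends on hardware, batch size, and how the model is served.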

Mental Model

A classifier’s logits before softmax encode more than just the top class. They encode how similar the model thinks each class is to the input. A picture of a husky might give high probability to husky, smaller but nontrivial probability to malamute, and tiny probability to airplane. Those relative magnitudes are dark knowledge.

The teacher exposes that knowledge as soft targets: its logits are divided by a temperature greater than one before the softmax, which flattens the distribution. The student is trained to match this softened distribution by minimizing KL divergence, alongside the regular cross-entropy on the true label.
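To make the flattening effect concrete, here is a minimal sketch in plain Python. The logit values for the husky example are made up for illustration, and `softmax_with_temperature` is a hypothetical helper, not part of any library.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for the classes [husky, malamute, airplane].
teacher_logits = [6.0, 4.0, -2.0]

hard = softmax_with_temperature(teacher_logits, T=1.0)
soft = softmax_with_temperature(teacher_logits, T=4.0)

for name, p1, p4 in zip(["husky", "malamute", "airplane"], hard, soft):
    print(f"{name:9s} T=1: {p1:.4f}   T=4: {p4:.4f}")
```

At T=4 the husky probability drops while malamute and airplane both gain mass, yet the ranking of the classes is unchanged. That extra mass on the non-top classes is the signal the student learns from.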

Hands-on Example

Here is a sketch of a distillation training step. The teacher is frozen; only the student is updated.

import torch
import torch.nn.functional as F

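def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    # Soften both distributions with temperature T; kl_div expects log-probs as input.
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    # KL(teacher || student), rescaled by T^2 to offset the gradient shrinkage at high T.
    kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (T * T)
    # Standard hard-label loss on the unsoftened student logits.
    ce = F.cross_entropy(student_logits, labels)
    # alpha weights the soft-target term against the hard-label term.
    return alpha * kd + (1 - alpha) * ce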

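# `teacher`, `student`, `loader`, and `optimizer` are assumed to be defined elsewhere.
teacher.eval()   # disable dropout and freeze batch-norm statistics in the teacher
student.train()

for batch in loader:
    with torch.no_grad():  # no gradients flow into the teacher
        t_logits = teacher(batch.x)
    s_logits = student(batch.x)
    loss = distillation_loss(s_logits, t_logits, batch.y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()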

The temperature T controls how soft the distribution becomes. Higher temperatures spread probability mass and expose more of the dark knowledge. The T * T factor keeps the gradient magnitudes comparable to the hard-label term.
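Written out, the per-example loss that `distillation_loss` computes can be summarized as follows. The notation is introduced here for illustration: $z_s$ and $z_t$ are the student and teacher logits, $\sigma$ is the softmax, and $y$ is the true label. The code additionally averages over the batch.

```latex
\begin{aligned}
p_t &= \sigma(z_t / T), \qquad p_s = \sigma(z_s / T) \\
\mathcal{L} &= \alpha \, T^2 \, \mathrm{KL}(p_t \,\|\, p_s)
             + (1 - \alpha)\, \mathrm{CE}\big(y, \sigma(z_s)\big) \\
\frac{\partial}{\partial z_{s,i}} \mathrm{KL}(p_t \,\|\, p_s)
            &= \frac{1}{T}\,\big(p_{s,i} - p_{t,i}\big)
\end{aligned}
```

The explicit $1/T$ comes from the chain rule. At high temperatures the difference $p_s - p_t$ itself shrinks roughly in proportion to $1/T$, so the soft-target gradient scales as about $1/T^2$ overall, which the $T^2$ multiplier cancels.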

           input batch
              |
     +--------+--------+
     v                 v
 [teacher]         [student]
 (frozen)          (training)
     |                 |
 logits/T          logits/T
     |                 |
 softmax          log_softmax
      \               /
       v             v
      KL divergence loss
              +
      cross entropy on true label
              |
              v
      backprop into student only
Teacher-student distillation flow

After training, the teacher is discarded. Only the small student goes to production.

Trade-offs

Distillation usually beats training the small model from scratch on the same data, especially when labels are scarce or noisy. The teacher acts as a smoother, denser supervisor.

It usually does not match the teacher’s accuracy. As a rough guide, expect the student to retain 90 to 98 percent of the teacher’s score, depending on the task and how aggressive the size reduction is. If you need parity, distillation alone is not enough.

Distillation and quantization change different things: distillation swaps in a smaller architecture, while quantization keeps the architecture and lowers the numeric precision. They compose well. Distill first to get a smaller model, then quantize for an extra speedup. Pruning, which removes weights or whole channels from a trained model, is a third axis and can be stacked too.

Training cost is real. You pay for teacher inference over every batch for every epoch. For large teachers and large datasets, pre-computing and caching teacher logits is often worthwhile.
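A minimal sketch of that caching step is below. It assumes the inputs are fixed tensors that fit in memory; the toy data, the small `teacher` network, and the helper name `cache_teacher_logits` are all stand-ins invented for the example.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def cache_teacher_logits(teacher, inputs, batch_size=256):
    """Run the frozen teacher once over `inputs` and return its logits."""
    teacher.eval()
    chunks = []
    for (x,) in DataLoader(TensorDataset(inputs), batch_size=batch_size):
        chunks.append(teacher(x).cpu())
    return torch.cat(chunks)

# Toy stand-ins (assumptions, not from the article): random features, a small teacher.
torch.manual_seed(0)
inputs = torch.randn(1000, 32)
labels = torch.randint(0, 10, (1000,))
teacher = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 10))

teacher_logits = cache_teacher_logits(teacher, inputs)

# Each training batch now carries its precomputed teacher logits, so the
# teacher never has to run inside the training loop.
cached_loader = DataLoader(
    TensorDataset(inputs, teacher_logits, labels), batch_size=64, shuffle=True
)
```

Each batch from `cached_loader` yields `(x, t_logits, y)`, which can be passed straight to `distillation_loss`. For datasets that do not fit in memory, the same idea applies with logits written to disk instead.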

Practical Tips

Use a temperature between 2 and 8 for most classification tasks, and tune it on a validation set. Lower temperatures behave more like hard-label training; very high ones flatten the teacher’s distribution toward uniform and wash out the signal.

Distill on a larger unlabeled corpus than your labeled set. The teacher provides labels for free, which is one of the most underrated benefits of the technique.
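One way to handle unlabeled examples is to drop the cross-entropy term for them and keep only the soft-target term. The sketch below assumes unlabeled examples are marked with the label `-100`; that convention and the function names `kd_only_loss` and `mixed_loss` are illustrative choices, not something the technique prescribes.

```python
import torch
import torch.nn.functional as F

def kd_only_loss(student_logits, teacher_logits, T=4.0):
    """Soft-target loss for unlabeled batches: the KL term with no cross-entropy."""
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (T * T)

def mixed_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """Distillation loss for a batch in which some labels may be missing (-100)."""
    kd = kd_only_loss(student_logits, teacher_logits, T)
    has_label = labels != -100
    if has_label.any():
        # Cross-entropy only over the examples that actually have a label.
        ce = F.cross_entropy(student_logits[has_label], labels[has_label])
    else:
        # No labels in this batch: the hard-label term contributes nothing.
        ce = student_logits.new_zeros(())
    return alpha * kd + (1 - alpha) * ce
```

The explicit mask avoids a NaN that a plain cross-entropy call would return when every label in the batch is ignored.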

Match the input distribution carefully. If the teacher saw clean images and the student will serve noisy phone camera shots, distill on the noisy distribution or you will be disappointed in production.

For LLMs, distill on outputs the teacher actually generates, not just on shared training data. Sequence-level distillation, where the student learns to imitate teacher-generated text, often beats token-level distillation.
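The two variants differ only in what the student is compared against. The sketch below is a simplified illustration: it assumes logits are already aligned with their target positions (any shifting happens upstream), that padding uses token id 0, and that both models share a vocabulary. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F

def sequence_level_loss(student_logits, teacher_token_ids, pad_id=0):
    """Cross-entropy of the student against text the teacher generated.

    student_logits:    (batch, seq_len, vocab) student predictions
    teacher_token_ids: (batch, seq_len) token ids decoded from the teacher
    """
    vocab = student_logits.size(-1)
    return F.cross_entropy(
        student_logits.reshape(-1, vocab),
        teacher_token_ids.reshape(-1),
        ignore_index=pad_id,  # padding positions do not contribute
    )

def token_level_loss(student_logits, teacher_logits, mask, T=1.0):
    """Per-token KL between teacher and student next-token distributions.

    mask: (batch, seq_len) with 1.0 for real tokens and 0.0 for padding.
    """
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # Sum the elementwise KL contributions over the vocabulary dimension.
    kl_per_token = F.kl_div(log_p_student, p_teacher, reduction="none").sum(-1)
    # Average over non-padding tokens only.
    return (kl_per_token * mask).sum() / mask.sum().clamp(min=1) * (T * T)
```

Sequence-level distillation only needs the teacher's generated token ids, so it also works when the teacher is available only through a text API. Token-level distillation needs the teacher's full logits at every position.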

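Cache teacher outputs when the training data is fixed, meaning no random augmentation changes the inputs between epochs. Caching removes the teacher forward pass from every training step, which can cut training time in half, and it lets you iterate on student architectures quickly.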

Wrap-up

Knowledge distillation is the cleanest way to get small models that punch above their weight. The technique is simple, the implementation is a few lines, and the win in inference cost is often the difference between an experiment and a shipped product. Pair it with quantization and you have a powerful compression stack that scales from edge devices to high-throughput APIs.