March 22, 2020

Kullback-Leibler divergence

Quick, what's better than one probability distribution?

That's right, two probability distributions.

Once you have two probability distributions, though, it is likely only a matter of time before you find yourself comparing them to each other. Are they similar, or different, you might ask, and to what extent? If one of them represents ground truth, how far is the other one from it? Since these are common questions, there are several widely used tools to help answer them.

In this post and an upcoming one, we will explore two widely used ways of quantifying how one distribution differs from another: Kullback-Leibler (KL) divergence (also called relative entropy), and the Wasserstein metric. This post discusses KL divergence. We will define KL divergence, explore it graphically, then look deeper into the definition with a few different interpretations.

The KL divergence of two continuous probability distributions P and Q, with density functions p and q, is given by:

$$ D_{\mathrm{KL}}(P \,\|\, Q) = \int_X p(x) \log \frac{p(x)}{q(x)} \, dx. $$

For discrete distributions P and Q, we replace the integral with a summation over the sample space X.
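To make the definition concrete, here is a small numerical sketch in Python (my own illustration, using numpy and scipy rather than anything from the original post): the discrete sum computed directly, and the continuous integral approximated on a grid for two Gaussians I chose arbitrarily.

```python
import numpy as np
from scipy import stats

# Discrete case: sum over x of P(x) * log(P(x) / Q(x)).
P = np.array([0.25, 0.5, 0.25])
Q = np.array([0.5, 0.25, 0.25])
print(np.sum(P * np.log(P / Q)))   # direct sum, in nats
print(stats.entropy(P, Q))         # scipy's relative entropy gives the same value

# Continuous case: approximate the integral of p(x) log(p(x)/q(x)) dx on a grid.
# (Fine here because both densities are effectively supported on this interval.)
x = np.linspace(-10, 10, 10001)
p = stats.norm.pdf(x, loc=0.0, scale=1.0)
q = stats.norm.pdf(x, loc=1.0, scale=2.0)
print(np.trapz(p * np.log(p / q), x))
```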

Let's make three preliminary observations from this equation, then explore visually.

  1. First, let's note that the KL divergence of a distribution with itself is zero — because log(p(x)/p(x)) = log(1) = 0, so D_KL(P || P) = 0 for any P. That's a satisfying observation if we're trying to use KL divergence to understand how probability distributions differ, because a distribution doesn't differ from itself.
  2. Second, the KL divergence between two distributions is always nonnegative. We can prove this by applying Jensen's inequality. Jensen's inequality, in its simplest form, tells us that the line connecting two points on a convex curve lies above the curve itself, and more generally tells us that the expectation of a convex function is greater than the convex function of the expectation. We'll use the formulation of Jensen's inequality that tells us that for a probability density function p, a real-valued, measurable function g, and a convex function c, we have:

    $$ c\left( \int_X g(x)\, p(x)\, dx \right) \le \int_X c(g(x))\, p(x)\, dx. $$

    With c as the negative logarithm function (it's convex!) and g(x) = q(x)/p(x), we conclude that the KL divergence is nonnegative as follows:

    $$ D_{\mathrm{KL}}(P \,\|\, Q) = \int_X p(x) \left( -\log \frac{q(x)}{p(x)} \right) dx \;\ge\; -\log \int_X p(x)\, \frac{q(x)}{p(x)}\, dx = -\log \int_X q(x)\, dx = -\log(1) = 0. $$

  3. Third, note that the two distributions P and Q play different roles in the equation above, so we might correctly expect that, usually, D_KL(P || Q) ≠ D_KL(Q || P) (the short numerical check after this list illustrates the point). This lack of equality is important, because we may talk casually of KL divergence as measuring the "distance" between two probability distributions, but we need to be aware that its asymmetry in P and Q makes it quite different from our common intuition of distance. KL divergence satisfies some of our intuition of what a distance is (always nonnegative, and zero when we compare a distribution with itself), but its lack of symmetry (among other properties) means that it fails to meet the mathematical definition of a metric — and metrics more closely match our intuition of a "distance" between two mathematical objects.
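
As a quick illustration of that asymmetry, here is a minimal check in Python (the two example distributions are my own choices, not anything from the post):

```python
import numpy as np

# Two arbitrary discrete distributions over three outcomes (illustrative choices).
P = np.array([0.1, 0.2, 0.7])
Q = np.array([0.3, 0.3, 0.4])

def kl(a, b):
    """Discrete KL divergence: sum_x a(x) * log(a(x) / b(x)), in nats."""
    return np.sum(a * np.log(a / b))

print(kl(P, Q))  # one direction...
print(kl(Q, P))  # ...and the other; the two values generally differ
```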

Let's now build our intuition for the KL divergence visually. First, we start with a plot of two example probability density functions, p and q. You can experiment with redrawing them, if you'd like, using the 'redraw' button and then clicking and dragging on the graph.

(Under construction: draw slowly for best results, and note that your sketch will be rescaled vertically to make the total probability equal one.)

Now, we consider the log of each of the probability density functions. They are helpful to visualize, because we are working up to the integral of p(x) log (p(x) / q(x)). Let's recall that log(p(x)/q(x)) = log(p(x)) - log(q(x)), so we eventually want to consider the difference of the two log pdfs.

Lastly, we show the difference of the log probabilities, log (p(x) / q(x)). We'll multiply this function by p(x), and then sum up the area under the resulting curve, in gray, with area above the x-axis counting as positive, and below the x-axis counting as negative. The total signed area is the KL divergence!
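
If you'd like to reproduce this breakdown without the interactive figure, here is a rough Python sketch (the two densities are stand-ins I chose, and matplotlib is my choice of plotting tool, not the post's):

```python
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

x = np.linspace(-6.0, 8.0, 2001)
p = stats.norm.pdf(x, loc=0.0, scale=1.0)                                   # a single-bump p
q = 0.5 * stats.norm.pdf(x, 0.0, 1.0) + 0.5 * stats.norm.pdf(x, 3.0, 1.0)  # a two-bump q

log_diff = np.log(p) - np.log(q)   # log(p(x) / q(x))
integrand = p * log_diff           # weight the log difference by p(x)

fig, axes = plt.subplots(3, 1, sharex=True, figsize=(6, 8))
axes[0].plot(x, p, label="p(x)")
axes[0].plot(x, q, label="q(x)")
axes[0].legend()
axes[1].plot(x, log_diff)
axes[1].set_title("log p(x) - log q(x)")
axes[2].fill_between(x, integrand, color="gray")
axes[2].set_title("p(x) (log p(x) - log q(x)): signed area = KL divergence")
plt.tight_layout()
plt.show()

print("KL(P || Q) ≈", np.trapz(integrand, x))   # the total signed area
```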

There are a number of ways to conceptualize KL divergence. Next, we will consider three.

The approach taken in the graphical breakdown above visualizes the following breakdown of KL divergence:

$$ D_{\mathrm{KL}}(P \,\|\, Q) = \int_X p(x) \big( \log p(x) - \log q(x) \big) \, dx. $$

(This rearrangement of the KL divergence expression above comes from remembering from algebra that log(a/b) = log(a) - log(b), of course!)

Thus, KL divergence is the expected difference of the log probabilities of P and Q, where the expectation is taken with respect to P. If you prefer, it is a weighted difference of the log probabilities of P and Q, where the weights come from P. That means that the KL divergence will be large if the log probabilities of P and Q differ in regions of the sample space that are highly probable under P. On the other hand, large differences in the log probabilities of P and Q make little impact on the KL divergence if those differences occur over regions or values that have low probability under P.

We should be able to see this property showing up for the (default) P and Q graphed above. Note how there is a great deal of shaded gray area in the bottom curve (the weighted log difference curve) in the region where P and Q are different and P is large — in the region of the second bump of P. However, if we swap P and Q, this same region of the graph contributes little to the shaded area (i.e., it contributes little to the KL divergence), because while P and Q still differ on this region, now the distribution that provides the weighting is small on this region.

This first view of KL divergence can be a helpful way to get a handle on the ideas at play, especially if you're already used to working with log probability density functions (which you may well be, especially if you work with distributions computationally, since log-pdfs are often easier to compute with than the pdfs themselves).
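
To make the "expectation under P" reading concrete, here is a small Monte Carlo check in Python (again a sketch of my own, with two Gaussians chosen for convenience because their KL divergence has a closed form):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# KL(P || Q) = E_{x ~ P}[ log p(x) - log q(x) ]: sample from P, average the log difference.
samples = rng.normal(loc=0.0, scale=1.0, size=200_000)   # draws from P = N(0, 1)
mc_estimate = np.mean(
    stats.norm.logpdf(samples, loc=0.0, scale=1.0)       # log p(x)
    - stats.norm.logpdf(samples, loc=1.0, scale=2.0)     # log q(x), with Q = N(1, 2^2)
)

# Closed form for the KL divergence between two Gaussians, for comparison.
mu1, s1, mu2, s2 = 0.0, 1.0, 1.0, 2.0
exact = np.log(s2 / s1) + (s1**2 + (mu1 - mu2)**2) / (2 * s2**2) - 0.5

print(mc_estimate, exact)   # these should agree to a couple of decimal places
```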

Now let's consider a different way to conceptualize KL divergence, using ideas from coding theory. It will be helpful to consider a discrete example here. Let's say P gives a distribution that assigns probabilities to the three letters 'p', 'e', 't' of 0.25, 0.5, and 0.25 respectively, as shown in the table below. Now we're going to assign each letter a binary code, trying to make efficient choices based on how common the letters are. Using a Huffman code, which has the nice property that no codeword is a prefix of any other, I encode 'p' with '10', 'e' with '0', and 't' with '11'.

x     P(x)   code based on P   -log2 P(x)
'p'   0.25   '10'              2
'e'   0.5    '0'               1
't'   0.25   '11'              2

Thus, for example, a sequence like p-e-t-e would be encoded 100110. Note that using 0 for 'e' is a nice economical choice, if I think I'm going to be encoding a bunch of these letters and transmitting that code. It is helpful to use the shortest code, '0', with the most common letter, 'e'. If I'm communicating a bunch of letters letter-by-letter, and each letter is independent and identically distributed according to P, the expected number of digits (bits) per letter is:

$$ \sum_x P(x) \left( -\log_2 P(x) \right) = 0.25 \cdot 2 + 0.5 \cdot 1 + 0.25 \cdot 2 = 1.5. $$
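
As a sanity check on that arithmetic, here is the same sum in Python (a tiny sketch built directly from the table above):

```python
P = {'p': 0.25, 'e': 0.5, 't': 0.25}
code_from_P = {'p': '10', 'e': '0', 't': '11'}   # the code built from P, as in the table

expected_bits = sum(P[x] * len(code_from_P[x]) for x in P)
print(expected_bits)   # 1.5 bits per letter
```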

Knowing the true distribution of letters P helped us make an efficient coding choice. But what if we didn't know P, and just had a guess at it, that we'll call Q? For concreteness, suppose Q gives the letters 'p', 'e', 't' probabilities of 0.5, 0.25, and 0.25 respectively. Then, we might create our code differently, giving '0' to 'p', since we think that 'p' is the most common, '10' to 'e', and '11' to 't'. Extending the previous table to now include Q, we have:

x     P(x)   code based on P   -log2 P(x)   Q(x)   code based on Q   -log2 Q(x)
'p'   0.25   '10'              2            0.5    '0'               1
'e'   0.5    '0'               1            0.25   '10'              2
't'   0.25   '11'              2            0.25   '11'              2

If we used the code based on Q to encode letters that were actually distributed according to P, the expected number of bits per letter required would be:

$$ \sum_x P(x) \left( -\log_2 Q(x) \right) = 0.25 \cdot 1 + 0.5 \cdot 2 + 0.25 \cdot 2 = 1.75. $$

Thus, not knowing the true distribution P, and only approximating it with Q, we used 1.75 - 1.5 = 0.25 extra bits per letter. That is, the difference in expected bits per letter is given by:

$$ -\sum_x P(x) \log_2 Q(x) - \left( -\sum_x P(x) \log_2 P(x) \right) = \sum_x P(x) \log_2 \frac{P(x)}{Q(x)} = 0.25. $$

Wow! This last expression is the KL divergence we saw before — up to a scaling factor for using log base 2 instead of the natural log (to speak in bits instead of nats) and with a sum replacing the integral because we are working with discrete probabilities.
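Here is the same bookkeeping in Python, verifying that the extra cost of using the Q-based code matches the base-2 KL divergence (a small sketch that mirrors the tables above):

```python
import numpy as np

P = {'p': 0.25, 'e': 0.5, 't': 0.25}
Q = {'p': 0.5, 'e': 0.25, 't': 0.25}
code_from_P = {'p': '10', 'e': '0', 't': '11'}
code_from_Q = {'p': '0', 'e': '10', 't': '11'}

# Expected bits per letter when letters are actually drawn from P.
bits_using_P_code = sum(P[x] * len(code_from_P[x]) for x in P)   # 1.5
bits_using_Q_code = sum(P[x] * len(code_from_Q[x]) for x in P)   # 1.75

# KL divergence in base 2 (i.e., in bits).
kl_bits = sum(P[x] * np.log2(P[x] / Q[x]) for x in P)            # 0.25

print(bits_using_Q_code - bits_using_P_code, kl_bits)            # both 0.25 bits
```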

To summarize, the KL divergence captures the idea of the number of extra bits required to encode elements truly distributed under P if we make a code for them based on Q instead of P. Pretty neat!

This may not deserve its own section, but in case you often work with entropies and cross entropies, it is worth pointing out that, expanding the KL divergence as we did in the previous section* as:

$$ D_{\mathrm{KL}}(P \,\|\, Q) = -\sum_x p(x) \log q(x) - \left( -\sum_x p(x) \log p(x) \right), $$

we see that the KL divergence is the cross entropy between P and Q minus the entropy of P.

(* with hope for understanding in moving back and forth rather cavalierly between sums and integrals)
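
For completeness, here's that decomposition checked numerically on the discrete letter example from earlier (another small sketch, with everything measured in bits):

```python
import numpy as np

P = np.array([0.25, 0.5, 0.25])
Q = np.array([0.5, 0.25, 0.25])

entropy_P = -np.sum(P * np.log2(P))         # H(P)
cross_entropy_PQ = -np.sum(P * np.log2(Q))  # H(P, Q)
kl = np.sum(P * np.log2(P / Q))             # D_KL(P || Q)

print(kl, cross_entropy_PQ - entropy_P)     # both 0.25 bits
```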

Thanks for making it this far! There are a number of ways to compare distributions, and KL divergence is just one. It has some helpful interpretations that can make it more intuitive as a quantity. We should all make sure to remember properties like its asymmetry as well, so that we don't get carried away in our intuition of it as a dissimilarity — let's remember that it's not a metric on distributions.

If you want to see KL divergence in action, take a look at variational Bayes approaches, where we seek to approximate a posterior distribution, and can use KL divergence as the relevant criterion in that approximation. Bayesian Data Analysis has helpful material on the role of KL divergence in variational Bayes (as well as, of course, helpful information on SO MANY other topics!).

As another resource, I found Bishop's introduction to KL divergence in Pattern Recognition and Machine Learning extremely helpful in building up an intuition for it, and this blog post reflects that. It also looks at the role of KL divergence in variational inference.

That's it for now! I'm having fun writing these posts, and hope they might occasionally come at the right time to be helpful to someone reading. Feel free to let me know on Twitter, if you'd like to see a particular statistical topic show up. Or just let me know that you like them, and I'll take that as encouragement to write more. Take care out there.