Interpretation of MLE in terms of KL divergence

In terms of parametric density approximation, maximum likelihood estimation (MLE) can be interpreted as minimizing the KL divergence between the true density and the model.

Suppose that the true density of a random variable $x$ is $p(x)$. Since $p$ is unknown, we try to come up with an approximation $q(x)$. The Kullback-Leibler (KL) divergence is then a natural measure of the mismatch between the distributions $p$ and $q$.

$$ \begin{align*} \text{KL divergence:}\quad KL(p||q) = \int p(x)\log \dfrac{p(x)}{q(x)}dx \end{align*} $$

From the formula, we can see that the KL divergence is a weighted average, with weights $p(x)$, of the error induced by the approximation, $\log p(x) - \log q(x)$. The KL divergence is nonnegative and equals zero exactly when $q = p$ (almost everywhere), so a good approximation $q$ is one that minimizes it. Since the KL divergence is a functional, function approximation is in essence a functional optimization problem. (A functional is a mapping that takes a function as input and returns a value, e.g. a real number, as output.)
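As a minimal sketch, assuming discrete distributions stored as Python lists (the names `kl_divergence`, `p` and `q` are hypothetical, for illustration only), the integral becomes a sum and the weighted-average reading can be written out directly:

```python
import math

def kl_divergence(p, q):
    """KL(p||q) for discrete distributions given as lists of probabilities.

    Each term is the approximation error log p(x) - log q(x), weighted by p(x).
    Terms with p(x) == 0 contribute nothing (0 * log 0 is taken to be 0).
    """
    total = 0.0
    for px, qx in zip(p, q):
        if px > 0:
            if qx == 0:
                return math.inf  # q assigns zero mass where p does not
            total += px * (math.log(px) - math.log(qx))
    return total

p = [0.5, 0.3, 0.2]  # hypothetical "true" distribution
q = [0.4, 0.4, 0.2]  # hypothetical approximation

print(kl_divergence(p, q))  # ≈ 0.0253
print(kl_divergence(p, p))  # 0.0: no mismatch
```

Setting `q = p` gives zero and any mismatch gives a positive value, which is what makes the divergence usable as an objective.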

However, since the KL divergence is NOT symmetric, the question remains whether to optimize the forward divergence $KL(p||q)$ or the reverse divergence $KL(q||p)$. The decision depends on the form of the true distribution $p(x)$: if we believe $p$ to be multimodal, we might use the reverse rather than the forward KL divergence. This will be the topic of the next blog post.
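A small numerical illustration of the asymmetry, again assuming discrete distributions. The bimodal `p` and the two candidate approximations are made up for illustration, and the sketch only compares two fixed candidates rather than optimizing anything:

```python
import math

def kl_divergence(p, q):
    """Discrete KL(p||q). Assumes q(x) > 0 wherever p(x) > 0; terms with p(x) = 0 are skipped."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

# Hypothetical bimodal "true" distribution over 5 states (modes at states 0 and 4).
p = [0.45, 0.04, 0.02, 0.04, 0.45]

# Two hypothetical candidate approximations:
q_cover = [0.2, 0.2, 0.2, 0.2, 0.2]      # spreads mass over both modes
q_mode = [0.90, 0.04, 0.02, 0.02, 0.02]  # concentrates on the first mode only

for name, q in [("cover", q_cover), ("mode", q_mode)]:
    print(name,
          "forward KL(p||q) =", round(kl_divergence(p, q), 3),
          "reverse KL(q||p) =", round(kl_divergence(q, p), 3))
```

Forward KL prefers `q_cover` (about 0.555 versus 1.117), while reverse KL prefers `q_mode` (about 0.548 versus 0.780), so the two directions can rank the same candidates differently.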

For now, let us assume that we minimize $KL(p||q)$. We also assume that $q(x\mid \theta)\in Q(\theta)$; that is, we restrict the approximating functions to a parametric family. Since each member of this family is completely determined by its parameters $\theta$, the functional optimization problem reduces to an ordinary optimization problem over $\theta$.
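To see the reduction concretely, here is a sketch that assumes a Gaussian family $Q(\theta)$ with $\theta = (\mu, \sigma)$ and, purely for illustration, a true density $p$ that is also Gaussian and known to us. For two Gaussians the KL divergence has a closed form, so minimizing it becomes a search over two numbers; the crude grid search is a stand-in for a proper optimizer.

```python
import math

def kl_gaussian(mu_p, sigma_p, mu_q, sigma_q):
    """Closed-form KL(N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2))."""
    return (math.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
            - 0.5)

# Hypothetical true density p = N(1.0, 2.0^2), assumed known here only for illustration.
MU_P, SIGMA_P = 1.0, 2.0

# Parametric family Q(theta), theta = (mu, sigma): searching a grid of theta values
# replaces the search over all possible densities q.
grid_mu = [round(0.1 * i, 10) for i in range(-30, 31)]      # -3.0, ..., 3.0
grid_sigma = [round(0.5 + 0.1 * i, 10) for i in range(31)]  # 0.5, ..., 3.5
best_theta = min(((mu, sigma) for mu in grid_mu for sigma in grid_sigma),
                 key=lambda theta: kl_gaussian(MU_P, SIGMA_P, *theta))
print(best_theta)  # (1.0, 2.0)
```

The minimizer coincides with the parameters of $p$ only because $p$ happens to lie in the family; otherwise the search would return the member of $Q(\theta)$ closest to $p$ in forward KL.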

The catch is that, in order to calculate $KL(p(x)||q(x\mid \theta))$, we need to know the true density $p$ in the first place. This problem can be circumvented easily. Suppose we have a finite sample $x_1, x_2, \dots, x_N$ of iid draws. Since each sample is drawn from the true density $p$, by the law of large numbers we can approximate the expectation of a statistic, $E_p[f(x)]$, by the sample average $\frac{1}{N}\sum_{n=1}^N f(x_n)$. Similarly, the KL divergence can be approximated by
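The law-of-large-numbers step can be checked with a short sketch using only the Python standard library. The true density $p = \mathcal{N}(1, 2^2)$ and the helper name `mc_expectation` are hypothetical choices; the estimator itself only sees the samples.

```python
import random

def mc_expectation(f, samples):
    """Approximate E_p[f(x)] by the sample average (1/N) * sum_n f(x_n)."""
    return sum(f(x) for x in samples) / len(samples)

rng = random.Random(0)  # fixed seed for reproducibility
# Hypothetical true density p = N(1.0, 2.0^2); we only use its samples.
samples = [rng.gauss(1.0, 2.0) for _ in range(100_000)]

# For this p, E_p[x^2] = Var + mean^2 = 4 + 1 = 5.
estimate = mc_expectation(lambda x: x * x, samples)
print(estimate)  # close to 5
```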

$$ \begin{align*} KL(p(x)||q(x\mid \theta)) &=\int \big[p(x)\log p(x)-p(x)\log q(x\mid \theta)\big]\, dx\\
&=E_p[\log p(x) - \log q(x\mid \theta)]\\
&\approx \dfrac{1}{N}\sum_{n=1}^N \Big[\log p(x_n) - \log q(x_n\mid \theta)\Big] \end{align*} $$
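The sample-average approximation of the KL divergence can be compared against the closed-form Gaussian KL in a sketch like the one below. Both densities are hypothetical Gaussians, and evaluating $\log p(x_n)$ is only possible here because we chose $p$ ourselves; in practice this term is unavailable.

```python
import math
import random

def log_normal_pdf(x, mu, sigma):
    """log N(x | mu, sigma^2)."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

def kl_gaussian(mu_p, sigma_p, mu_q, sigma_q):
    """Closed-form KL(N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2)), used as a reference value."""
    return (math.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2) - 0.5)

rng = random.Random(0)
mu_p, sigma_p = 0.0, 1.0  # hypothetical true density p
mu_q, sigma_q = 1.0, 1.5  # hypothetical model q(x | theta), theta = (1.0, 1.5)
xs = [rng.gauss(mu_p, sigma_p) for _ in range(100_000)]

# (1/N) * sum_n [log p(x_n) - log q(x_n | theta)]
kl_mc = sum(log_normal_pdf(x, mu_p, sigma_p) - log_normal_pdf(x, mu_q, sigma_q)
            for x in xs) / len(xs)
print(kl_mc, kl_gaussian(mu_p, sigma_p, mu_q, sigma_q))
```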

Note that in the last line of the KL approximation, the first term $\log p(x_n)$ does not depend on $\theta$, so it is irrelevant for the purpose of optimization. The second term, $\frac{1}{N}\sum_n \log q(x_n \mid \theta)$, is the average log likelihood of the data $x_1, x_2, \dots, x_N$ under the model $q(x\mid \theta)$, and it enters the approximation with a minus sign. Therefore, choosing $\theta_{mle}$ to maximize this log likelihood is equivalent to minimizing the sample approximation of $KL(p(x)||q(x\mid \theta))$.
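A sketch of the equivalence for a Gaussian model (a hypothetical choice; the argument applies to any parametric family): the closed-form Gaussian MLE is the sample mean and the $1/N$ standard deviation, and ranking candidate parameters by average log likelihood or by the sample KL estimate gives the same winner, because the two criteria differ only by a $\theta$-independent constant. The constant is computed from the hypothetical $p$ only to make the comparison visible; MLE itself never needs it.

```python
import math
import random

def log_normal_pdf(x, mu, sigma):
    """log N(x | mu, sigma^2)."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

def avg_log_likelihood(xs, mu, sigma):
    """(1/N) * sum_n log q(x_n | theta) for the Gaussian model theta = (mu, sigma)."""
    return sum(log_normal_pdf(x, mu, sigma) for x in xs) / len(xs)

rng = random.Random(0)
MU_P, SIGMA_P = 1.0, 2.0  # hypothetical true density, used to draw data and for the constant
xs = [rng.gauss(MU_P, SIGMA_P) for _ in range(10_000)]

# Gaussian MLE in closed form: sample mean and the 1/N (biased) standard deviation.
mu_mle = sum(xs) / len(xs)
sigma_mle = math.sqrt(sum((x - mu_mle) ** 2 for x in xs) / len(xs))

# Sample KL estimate = (1/N) sum_n log p(x_n) - avg_log_likelihood; the first part is
# constant in theta, so both criteria rank candidate thetas identically.
const = sum(log_normal_pdf(x, MU_P, SIGMA_P) for x in xs) / len(xs)
candidates = [(mu_mle, sigma_mle), (MU_P, SIGMA_P), (0.0, 1.0), (2.0, 3.0)]
by_likelihood = max(candidates, key=lambda t: avg_log_likelihood(xs, *t))
by_kl = min(candidates, key=lambda t: const - avg_log_likelihood(xs, *t))
print(mu_mle, sigma_mle)
print(by_likelihood == by_kl == (mu_mle, sigma_mle))
```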

References

  1. Pattern Recognition and Machine Learning, Bishop, 2006