Expectation Maximization is a kind of probabilistic method to classify data. Please correct me if I am wrong and it is not a classifier.
What is an intuitive explanation of this EM technique? What is expectation here and what is being maximized?
The accepted answer references the Chuong EM paper, which does a decent job of explaining EM. There is also a YouTube video that explains the paper in more detail.
To recap, here is the scenario: there are two coins, A and B, each with an unknown bias towards heads. We observe several sets of ten tosses, but we are not told which coin produced which set, and we want to estimate both biases.
In the case of the first trial, intuitively we'd think coin B generated it, since the proportion of heads matches B's bias very well... but that bias was just a guess, so we can't be sure.
With that in mind, I like to think of the EM solution like this: start with a guess for each coin's bias; use those guesses to work out, for each trial, how likely it is that coin A or coin B produced it; use those probabilities as weights to re-estimate the two biases; and repeat until the estimates stop changing.
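As a quick sanity check of that first-trial intuition (my own illustration, not taken from the paper's text): assuming the initial guesses are thetaA = 0.6 and thetaB = 0.5 and the trial showed 5 heads out of 10 tosses, the weighting step looks like this:

```python
from scipy.stats import binom

theta_A, theta_B = 0.6, 0.5  # current guesses for each coin's bias towards heads
heads, tosses = 5, 10        # the trial in question: 5 heads out of 10 flips

# Likelihood of this trial under each coin's guessed bias
like_A = binom.pmf(heads, tosses, theta_A)
like_B = binom.pmf(heads, tosses, theta_B)

# Normalise so the two "responsibilities" sum to 1
weight_A = like_A / (like_A + like_B)
weight_B = like_B / (like_A + like_B)

print(round(weight_A, 2), round(weight_B, 2))  # about 0.45 and 0.55
```

So coin B is favoured for this trial, but not overwhelmingly, which is exactly why the trial contributes weighted counts to both coins rather than being assigned outright to one of them.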
This may be an oversimplification (or even fundamentally wrong on some levels), but I hope this helps on an intuitive level!
Note: the code behind this answer can be found here.
Suppose we have some data sampled from two different groups, red and blue:
Here, we can see which data point belongs to the red or blue group. This makes it easy to find the parameters that characterise each group. For example, the mean of the red group is around 3, the mean of the blue group is around 7 (and we could find the exact means if we wanted).
This is, generally speaking, known as maximum likelihood estimation. Given some data, we compute the value of a parameter (or parameters) that best explains that data.
Now imagine that we cannot see which value was sampled from which group. Everything looks purple to us:
Here we have the knowledge that there are two groups of values, but we don't know which group any particular value belongs to.
Can we still estimate the means for the red group and blue group that best fit this data?
Yes, often we can! Expectation Maximisation gives us a way to do it. The very general idea behind the algorithm is this:
1. Start with an initial guess for each group's parameters.
2. Compute the likelihood of each data point under each group's current parameters.
3. Turn those likelihoods into weights, so that each data point's weights across the groups sum to 1.
4. Use the weighted data points to compute improved estimates of each group's parameters.
5. Repeat steps 2 to 4 until the estimates converge (or a maximum number of iterations is reached).
These steps need some further explanation, so I'll walk through the problem described above.
Example: estimating mean and standard deviation
I'll use Python in this example, but the code should be fairly easy to follow even if you're not familiar with this language.
Suppose we have two groups, red and blue, with the values distributed as in the image above. Specifically, each group contains a number of values drawn from a normal distribution with its own mean and standard deviation.
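The exact values used to generate the plot aren't reproduced here, but as a sketch, data like this could be drawn as follows (the means, standard deviations, seed and group size below are assumptions chosen to match the description of the plot):

```python
import numpy as np

np.random.seed(110)  # for reproducible results

# Assumed parameters for illustration: red centred near 3, blue near 7
red = np.random.normal(3, 0.8, size=20)
blue = np.random.normal(7, 2, size=20)

# The "purple" view: all values pooled together with their colours hidden
both_colours = np.sort(np.concatenate((red, blue)))
```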
Here is an image of these red and blue groups again (to save you from having to scroll up):
When we can see the colour of each point (i.e. which group it belongs to), it's very easy to estimate the mean and standard deviation for each group. We just pass the red and blue values to the built-in functions in NumPy. For example:
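Continuing the sketch above (using the red and blue arrays defined there):

```python
# With group membership known, the maximum likelihood estimates are just
# the sample mean and standard deviation of each group
np.mean(red), np.std(red)    # estimated mean and spread of the red group
np.mean(blue), np.std(blue)  # estimated mean and spread of the blue group
```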
But what if we can't see the colours of the points? That is, instead of red or blue, every point has been coloured purple.
To try and recover the mean and standard deviation parameters for the red and blue groups, we can use Expectation Maximisation.
Our first step (step 1 above) is to guess at the parameter values for each group's mean and standard deviation. We don't have to guess intelligently; we can pick any numbers we like:
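Continuing the sketch, the initial guesses might be written like this (the particular numbers are arbitrary, and deliberately poor):

```python
# Step 1: arbitrary starting guesses for each group's parameters
red_mean_guess = 1.1
red_std_guess = 2

blue_mean_guess = 9
blue_std_guess = 1.7
```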
These parameter estimates produce bell curves that look like this:
These are bad estimates. Both means (the vertical dotted lines) look far off any kind of "middle" for sensible groups of points, for instance. We want to improve these estimates.
The next step (step 2) is to compute the likelihood of each data point appearing under the current parameter guesses:
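Continuing the sketch (using the guesses and the pooled both_colours array from above), this can be done with SciPy's normal distribution; the variable names are my own:

```python
from scipy import stats

# Likelihood of every pooled ("purple") point under the current red and blue guesses
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
```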
Here, we have simply put each data point into the probability density function for a normal distribution using our current guesses at the mean and standard deviation for red and blue. This tells us, for example, that with our current guesses the data point at 1.761 is much more likely to be red (0.189) than blue (0.00003).
For each data point, we can turn these two likelihood values into weights (step 3) so that they sum to 1 as follows:
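A sketch of that normalisation, using the likelihood arrays above:

```python
likelihood_total = likelihood_of_red + likelihood_of_blue

# Each point's responsibility is split between red and blue; the two weights sum to 1
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
```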
With our current estimates and our newly-computed weights, we can now compute new estimates for the mean and standard deviation of the red and blue groups (step 4).
We twice compute the mean and standard deviation using all data points, but with the different weightings: once for the red weights and once for the blue weights.
The key bit of intuition is that the greater the weight of a colour on a data point, the more the data point influences the next estimates for that colour's parameters. This has the effect of "pulling" the parameters in the right direction.
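In code, the weighted re-estimation might look like this (the helper names are my own; note that the new standard deviation is computed around the previous guess for the mean, a point discussed further below):

```python
def estimate_mean(data, weight):
    """Weighted mean: heavily weighted points pull the estimate towards themselves."""
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """Weighted standard deviation of the data around a given centre."""
    variance = np.sum(weight * (data - mean) ** 2) / np.sum(weight)
    return np.sqrt(variance)

# Step 4: new estimates, using the weights from step 3
new_red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
new_blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)

new_red_mean_guess = estimate_mean(both_colours, red_weight)
new_blue_mean_guess = estimate_mean(both_colours, blue_weight)
```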
We have new estimates for the parameters. To improve them again, we can jump back to step 2 and repeat the process. We do this until the estimates converge, or after some number of iterations have been performed (step 5).
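Putting steps 2 to 5 together (reusing the arrays, guesses and helper functions from the sketches above), the whole loop might look like this:

```python
for i in range(20):
    # Step 2: likelihood of every point under the current guesses
    likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
    likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

    # Step 3: per-point weights that sum to 1
    likelihood_total = likelihood_of_red + likelihood_of_blue
    red_weight = likelihood_of_red / likelihood_total
    blue_weight = likelihood_of_blue / likelihood_total

    # Step 4: weighted re-estimation (standard deviation around the previous mean guess)
    red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
    blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
    red_mean_guess = estimate_mean(both_colours, red_weight)
    blue_mean_guess = estimate_mean(both_colours, blue_weight)
```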
For our data, the first five iterations of this process look like this (more recent iterations are drawn more strongly):
We see that the means are already converging on some values, and the shapes of the curves (governed by the standard deviation) are also becoming more stable.
If we continue for 20 iterations, we end up with the following:
The EM process has converged to values which turn out to be very close to the actual ones, i.e. the values we would estimate if we could see the colours and there were no hidden variables.
In the code above you may have noticed that the new estimation for standard deviation was computed using the previous iteration's estimate for the mean. Ultimately it does not matter if we compute a new value for the mean first as we are just finding the (weighted) variance of values around some central point. We will still see the estimates for the parameters converge.
Using the same article by Do and Batzoglou cited in Zhubarb's answer, I implemented EM for that problem in Java. The comments on that answer show that the algorithm gets stuck at a local optimum, which also occurs with my implementation if the parameters thetaA and thetaB are the same.
Below is the standard output of my code, showing the convergence of the parameters.
Below is my Java implementation of EM to solve the problem in (Do and Batzoglou, 2008). The core part of the implementation is the loop to run EM until the parameters converge.
Below is the entire code.
EM is an algorithm for maximizing a likelihood function when some of the variables in your model are unobserved (i.e. when you have latent variables).
You might fairly ask: if we're just trying to maximize a function, why don't we just use the existing machinery for maximizing functions? Well, if you try to maximize this by taking derivatives and setting them to zero, you find that in many cases the first-order conditions don't have a solution. There's a chicken-and-egg problem: to solve for your model parameters you need to know the distribution of your unobserved data, but the distribution of your unobserved data is a function of your model parameters.
E-M tries to get around this by iteratively guessing a distribution for the unobserved data, then estimating the model parameters by maximizing something that is a lower bound on the actual likelihood function, and repeating until convergence:
The EM algorithm
Start with a guess for the values of your model parameters.
E-step: For each datapoint that has missing values, use your model equation to solve for the distribution of the missing data given your current guess of the model parameters and given the observed data (note that you are solving for a distribution for each missing value, not for the expected value). Now that we have a distribution for each missing value, we can calculate the expectation of the likelihood function with respect to the unobserved variables. If our guess for the model parameter was correct, this expected likelihood will be the actual likelihood of our observed data; if the parameters were not correct, it will just be a lower bound.
M-step: Now that we've got an expected likelihood function with no unobserved variables in it, maximize the function as you would in the fully observed case, to get a new estimate of your model parameters.
Repeat until convergence.
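In symbols (standard notation, not taken from the answer above): writing $X$ for the observed data, $Z$ for the unobserved variables and $\theta$ for the parameters, the two steps at iteration $t$ are

$$
\begin{aligned}
\text{E-step:}\quad & Q(\theta \mid \theta^{(t)}) = \mathbb{E}_{Z \mid X, \theta^{(t)}}\big[\log p(X, Z \mid \theta)\big], \\
\text{M-step:}\quad & \theta^{(t+1)} = \arg\max_{\theta} \; Q(\theta \mid \theta^{(t)}).
\end{aligned}
$$

A standard consequence (via Jensen's inequality) is that the observed-data log-likelihood $\log p(X \mid \theta^{(t)})$ never decreases from one iteration to the next.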
The other answers being good, I will try to provide another perspective and tackle the intuitive part of the question.
The EM (Expectation-Maximization) algorithm belongs to a class of iterative algorithms that use duality.
Excerpt (emphasis mine):
Usually a dual B of an object A is related to A in some way that preserves some symmetry or compatibility. For example, AB = const.
There are several examples of iterative algorithms that employ duality in this sense.
In a similar fashion, the EM algorithm can also be seen as two dual maximization steps: the E-step maximizes the objective with respect to the distribution over the latent variables (holding the parameters fixed), and the M-step maximizes it with respect to the parameters (holding that distribution fixed).
In an iterative algorithm using duality there is an explicit (or implicit) assumption of an equilibrium (or fixed) point of convergence (for EM this is proved using Jensen's inequality).
So the outline of such algorithms is: fix (or estimate) one set of quantities, optimise the other set with respect to it, then swap roles, and repeat until convergence.
Note that when such an algorithm converges to a (global) optimum, it has found a configuration which is best in both senses (i.e. in both the x domain/parameters and the y domain/parameters). However, the algorithm may only find a local optimum rather than the global optimum.
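One standard way to make the two dual maximizations explicit (this is the variational, or free-energy, view of EM in the spirit of Neal and Hinton; it is not quoted from the answer above) is to define, for any distribution $q(Z)$ over the latent variables,

$$
F(q, \theta) \;=\; \mathbb{E}_{q}\big[\log p(X, Z \mid \theta)\big] + H(q) \;=\; \log p(X \mid \theta) - \mathrm{KL}\big(q(Z) \,\|\, p(Z \mid X, \theta)\big).
$$

The E-step maximizes $F$ over $q$ with $\theta$ fixed (the maximizer is $q(Z) = p(Z \mid X, \theta)$), and the M-step maximizes $F$ over $\theta$ with $q$ fixed. Both steps can only increase $F$, which is the sense in which they are dual to each other and why the iteration settles at a fixed point.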
I would say this is the intuitive description of the outline of the algorithm.
For the statistical arguments and applications, other answers have given good explanations (see also the references in this answer).
Here is a straightforward recipe to understand the Expectation Maximisation algorithm:
1- Read this EM tutorial paper by Do and Batzoglou.
2- If you still have question marks in your head, have a look at the explanations on this maths Stack Exchange page.
3- Look at this code that I wrote in Python that explains the example in the EM tutorial paper of item 1 (a sketch of what such code can look like is included below):
Warning: the code may be messy/suboptimal, since I am not a Python developer. But it does the job.
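The original listing isn't reproduced above, but a minimal sketch of EM for the two-coin problem in that paper could look like the following (the head/tail counts and initial guesses are the ones from the paper's worked example as I recall them; treat them as illustrative):

```python
from scipy.stats import binom

# Each tuple is one set of 10 tosses: (heads, tails), as in Do & Batzoglou (2008)
observations = [(5, 5), (9, 1), (8, 2), (4, 6), (7, 3)]

theta_A, theta_B = 0.6, 0.5  # initial guesses for the two coins' biases

for step in range(10):
    heads_A = tails_A = heads_B = tails_B = 0.0

    # E-step: for each set, how likely is it that coin A (vs coin B) produced it?
    for heads, tails in observations:
        like_A = binom.pmf(heads, heads + tails, theta_A)
        like_B = binom.pmf(heads, heads + tails, theta_B)
        weight_A = like_A / (like_A + like_B)
        weight_B = 1.0 - weight_A

        # Accumulate the expected head/tail counts attributed to each coin
        heads_A += weight_A * heads
        tails_A += weight_A * tails
        heads_B += weight_B * heads
        tails_B += weight_B * tails

    # M-step: re-estimate each coin's bias from its expected counts
    theta_A = heads_A / (heads_A + tails_A)
    theta_B = heads_B / (heads_B + tails_B)

    print(step, round(theta_A, 3), round(theta_B, 3))
```

With these numbers the estimates settle after a handful of iterations, with theta_A noticeably larger than theta_B, matching the behaviour described in the paper.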