期望最大化技术的直观解释是什么?

期望最大化(EM)是一种概率数据分类方法。如果我错了,请纠正我,如果它不是一个分类器。

这种 EM 技术的直观解释是什么? 这里的 expectationmaximized是什么?

45297 次浏览

EM 用于最大化潜变量 Z 的模型 Q 的可能性。

这是一个迭代优化。

theta <- initial guess for hidden parameters
while not converged:
#e-step
Q(theta'|theta) = E[log L(theta|Z)]
#m-step
theta <- argmax_theta' Q(theta'|theta)

电子步骤: 给定 Z 的电流估计,计算期望的对数似然函数

步骤: 找到使这个 Q 最大化的 θ

GMM Example:

E-step: 根据当前的 gmm 参数估计估计每个数据点的标签分配

M-step: 最大化给定新标签分配的新 θ

K- 均值也是一种 EM 算法,有很多关于 K- 均值的解释动画。

EM 是一种当模型中的一些变量未被观察到时(例如,当你有潜在变量时)使似然函数最大化的算法。

你可能会问,如果我们只是想最大化一个函数,为什么不用现有的机制来最大化一个函数呢。如果你尝试通过求导把它们设置为零来最大化这个问题,你会发现在很多情况下,一阶条件没有解。这里存在一个先有鸡还是先有蛋的问题,为了解决模型参数,你需要知道未观测数据的分布; 但是未观测数据的分布是模型参数的函数。

E-M 试图通过迭代猜测未观测数据的分布,然后通过最大化实际似然函数的下界来估计模型参数,并且重复直到收敛来绕过这个问题:

电磁算法

从猜测模型参数的值开始

E-step: 对于每个缺失值的数据点,使用你的模型方程来解决缺失数据的分布,给出你当前对模型参数的猜测和观测数据(注意,你解决的是每个缺失值的分布,而不是期望值)。现在我们有了每个缺失值的分布,我们可以计算似然函数相对于未观测变量的 期望。如果我们对模型参数的猜测是正确的,那么这个预期的可能性将是我们观测数据的实际可能性; 如果参数不正确,它将只是一个下限。

M 步: 现在我们已经得到了一个期望的似然函数,其中没有未观察到的变量,最大化的函数,因为你会在完全观察的情况下,得到一个新的估计,你的模型参数。

Repeat until convergence.

从技术上来说,“ EM”这个术语有点含糊不清,但我假设你指的是高斯混合模型数据聚类技术,这是一般 EM 原理的 例子

其实是 电磁数据聚类不是分类器。我知道有些人认为集群是“无监督分类”,但实际上数据聚类是完全不同的。

关键的区别在于,人们对数据聚类的分类总是有一个很大的误解,那就是: 在聚类分析中,没有“正确的解决方案”。这是一个知识 发现的方法,它实际上是为了找到 新的的东西!这使得评估非常棘手。通常使用已知的分类作为参考进行评估,但这并不总是合适的: 您所拥有的分类可能反映或可能不反映数据中的内容。

让我给你举个例子: 你有一个庞大的客户数据集,包括性别数据。将此数据集划分为“男性”和“女性”的方法在与现有类进行比较时是最佳的。从“预测”的角度来看,这是好事,至于新用户,你现在可以预测他们的性别。在一种“知识发现”的思维方式中,这实际上是不好的,因为您希望在数据中发现一些 新的结构。一种方法,例如将数据分成老年人和儿童,然而将得分 越糟越好相对于男性/女性班级。然而,这将是一个非常好的集群结果(如果没有给出年龄的话)。

现在回到 EM。本质上,它假设您的数据是由多个多元正态分布组成的(请注意,这是一个 非常强假设,特别是当您修复集群数量时!).然后通过 交替改进模型和模型的对象分配试图找到一个局部最优模型。

For best results in a classification context, choose the number of clusters 更大 than the number of classes, or even apply the clustering to 单身 classes only (to find out whether there is some structure within the class!).

假设你想训练一个分类器来区分“汽车”、“自行车”和“卡车”。假设数据恰好由3个正态分布组成是没有什么用的。然而,你可以假设 有不止一种车(和卡车和自行车)。因此,不需要为这三个类训练一个分类器,而是将汽车、卡车和自行车分成10个集群(或者可能是10辆汽车、3辆卡车和3辆自行车,等等) ,然后训练一个分类器来区分这30个类,然后将类结果合并回原始类。您可能还会发现有一个集群特别难以分类,例如 Trikes。有点像汽车,有点像自行车。或者送货卡车,它们更像是超大型汽车而不是卡车。

下面是理解期望最大化算法的一个简单方法:

1- 阅读 Do 和 Batzoglou 的 电磁学教学论文

2- You may have question marks in your head, have a look at the explanations on this maths stack exchange 呼叫.

3- 看看我用 Python 编写的代码,它解释了项目1的 EM 教程论文中的示例:

Warning : The code may be messy/suboptimal, since I am not a Python developer. But it does the job.

import numpy as np
import math


#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* ####


def get_mn_log_likelihood(obs,probs):
""" Return the (log)likelihood of obs, given the probs"""
# Multinomial Distribution Log PMF
# ln (pdf)      =             multinomial coeff            *   product of probabilities
# ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]


multinomial_coeff_denom= 0
prod_probs = 0
for x in range(0,len(obs)): # loop through state counts in each observation
multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
prod_probs = prod_probs + obs[x]*math.log(probs[x])


multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
likelihood = multinomial_coeff + prod_probs
return likelihood


# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45


# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)


# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50


# E-M begins!
delta = 0.001
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
expectation_A = np.zeros((5,2), dtype=float)
expectation_B = np.zeros((5,2), dtype=float)
for i in range(0,len(experiments)):
e = experiments[i] # i'th experiment
ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B


weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A
weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B


expectation_A[i] = np.dot(weightA, e)
expectation_B[i] = np.dot(weightB, e)


pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A));
pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B));


improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
j = j+1

使用 Do 和 Zhubarb 答案中引用的 Batzoglou 的同一篇文章,我在 爪哇咖啡中针对这个问题实现了 EM。对他的回答的评论表明,算法陷入了局部最优,如果参数 thetaA 和 thetaB 相同,我的实现也会出现这种情况。

下面是我的代码的标准输出,显示了参数的收敛性。

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

下面是我在 EM 中用于解决这个问题的 Java 实现(Do and Batzoglou,2008)。实现的核心部分是运行 EM 直到参数收敛的循环。

private Parameters _parameters;


public Parameters run()
{
while (true)
{
expectation();


Parameters estimatedParameters = maximization();


if (_parameters.converged(estimatedParameters)) {
break;
}


_parameters = estimatedParameters;
}


return _parameters;
}

下面是整个代码。

import java.util.*;


/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
double _thetaA = 0.0; // Probability of heads for coin A.
double _thetaB = 0.0; // Probability of heads for coin B.


double _delta = 0.00001;


public Parameters(double thetaA, double thetaB)
{
_thetaA = thetaA;
_thetaB = thetaB;
}


/*************************************************************************
Returns true if this parameter is close enough to another parameter
(typically the estimated parameter coming from the maximization step).
*************************************************************************/
public boolean converged(Parameters other)
{
if (Math.abs(_thetaA - other._thetaA) < _delta &&
Math.abs(_thetaB - other._thetaB) < _delta)
{
return true;
}


return false;
}


public double getThetaA()
{
return _thetaA;
}


public double getThetaB()
{
return _thetaB;
}


public String toString()
{
return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
}


}




/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
double _numHeads = 0;
double _numTails = 0;


public Observation(String s)
{
for (int i = 0; i < s.length(); i++)
{
char c = s.charAt(i);


if (c == 'H')
{
_numHeads++;
}
else if (c == 'T')
{
_numTails++;
}
else
{
throw new RuntimeException("Unknown character: " + c);
}
}
}


public Observation(double numHeads, double numTails)
{
_numHeads = numHeads;
_numTails = numTails;
}


public double getNumHeads()
{
return _numHeads;
}


public double getNumTails()
{
return _numTails;
}


public String toString()
{
return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
}


}


/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
// Current estimated parameters.
private Parameters _parameters;


// Observations from the trials. These observations are set once.
private final List<Observation> _observations;


// Estimated observations per coin. These observations are the output
// of the expectation step.
private List<Observation> _expectedObservationsForCoinA;
private List<Observation> _expectedObservationsForCoinB;


private static java.io.PrintStream o = System.out;


/*************************************************************************
Principal constructor.
@param observations The observations from the trial.
@param parameters The initial guessed parameters.
*************************************************************************/
public EM(List<Observation> observations, Parameters parameters)
{
_observations = observations;
_parameters = parameters;
}


/*************************************************************************
Run EM until parameters converge.
*************************************************************************/
public Parameters run()
{


while (true)
{
expectation();


Parameters estimatedParameters = maximization();


o.printf("%s\n", estimatedParameters);


if (_parameters.converged(estimatedParameters)) {
break;
}


_parameters = estimatedParameters;
}


return _parameters;


}


/*************************************************************************
Given the observations and current estimated parameters, compute new
estimated completions (distribution over the classes) and observations.
*************************************************************************/
private void expectation()
{


_expectedObservationsForCoinA = new ArrayList<Observation>();
_expectedObservationsForCoinB = new ArrayList<Observation>();


for (Observation observation : _observations)
{
int numHeads = (int)observation.getNumHeads();
int numTails = (int)observation.getNumTails();


double probabilityOfObservationForCoinA=
binomialProbability(10, numHeads, _parameters.getThetaA());


double probabilityOfObservationForCoinB=
binomialProbability(10, numHeads, _parameters.getThetaB());


double normalizer = probabilityOfObservationForCoinA +
probabilityOfObservationForCoinB;


// Compute the completions for coin A and B (i.e. the probability
// distribution of the two classes, summed to 1.0).


double completionCoinA = probabilityOfObservationForCoinA /
normalizer;
double completionCoinB = probabilityOfObservationForCoinB /
normalizer;


// Compute new expected observations for the two coins.


Observation expectedObservationForCoinA =
new Observation(numHeads * completionCoinA,
numTails * completionCoinA);


Observation expectedObservationForCoinB =
new Observation(numHeads * completionCoinB,
numTails * completionCoinB);


_expectedObservationsForCoinA.add(expectedObservationForCoinA);
_expectedObservationsForCoinB.add(expectedObservationForCoinB);
}
}


/*************************************************************************
Given new estimated observations, compute new estimated parameters.
*************************************************************************/
private Parameters maximization()
{


double sumCoinAHeads = 0.0;
double sumCoinATails = 0.0;
double sumCoinBHeads = 0.0;
double sumCoinBTails = 0.0;


for (Observation observation : _expectedObservationsForCoinA)
{
sumCoinAHeads += observation.getNumHeads();
sumCoinATails += observation.getNumTails();
}


for (Observation observation : _expectedObservationsForCoinB)
{
sumCoinBHeads += observation.getNumHeads();
sumCoinBTails += observation.getNumTails();
}


return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));


//o.printf("parameters: %s\n", _parameters);


}


/*************************************************************************
Since the coin-toss experiment posed in this article is a Bernoulli trial,
use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
*************************************************************************/
private static double binomialProbability(int n, int k, double p)
{
double q = 1.0 - p;
return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
}


private static long nChooseK(int n, int k)
{
long numerator = 1;


for (int i = 0; i < k; i++)
{
numerator = numerator * n;
n--;
}


long denominator = factorial(k);


return (long)(numerator / denominator);
}


private static long factorial(int n)
{
long result = 1;
for (; n >0; n--)
{
result = result * n;
}


return result;
}


/*************************************************************************
Entry point into the program.
*************************************************************************/
public static void main(String argv[])
{
// Create the observations and initial parameter guess
// from the (Do and Batzoglou, 2008) article.


List<Observation> observations = new ArrayList<Observation>();
observations.add(new Observation("HTTTHHTHTH"));
observations.add(new Observation("HHHHTHHHHH"));
observations.add(new Observation("HTHHHHHTHH"));
observations.add(new Observation("HTHTTTHHTT"));
observations.add(new Observation("THHHTHHHTH"));


Parameters initialParameters = new Parameters(0.6, 0.5);


EM em = new EM(observations, initialParameters);


Parameters finalParameters = em.run();


o.printf("Final result:\n%s\n", finalParameters);
}
}

Other answers being good, i will try to provide another perspective and tackle the intuitive part of the question.

EM (期望最大化)算法 是使用 < a href = “ http://en.wikipedia.org/wiki/Duality _% 28數学% 29”rel = “ nofollow”> 对偶性的一类迭代算法的变体

节选(重点是我的) :

在数学中,一般来说,二元性翻译概念, theorems or mathematical structures into other concepts, theorems or 以一对一的方式,经常(但不总是)通过 对合运算的对偶: 如果 A 的对偶是 B,那么 B 的对偶 is A. Such involutions 有时有固定点, so that the dual 就是 A 本身

通常 对象A 的 dualB 与 A 有某种关系,这种关系保留了一些 对称性或兼容性。例如 AB = 康斯特

Examples of iterative algorithms, employing duality (in the previous sense) are:

  1. 辗转相除法最大公约数及其变种
  2. Gram-Schmidt 矢量基算法及其变体
  3. 算术平均-几何平均不等式及其变体
  4. 期望-最大化算法及其变体 (另见 这里有一个信息几何视图)
  5. (. . 其他类似的算法. .)

In a similar fashion, EM 算法也可以看作是两个对偶最大化步骤:

. . [ EM ]被视为最大化的联合函数的参数和 未观测变量的分布. . E 步最大化 这个函数关于未被观测的分布 变量; 关于参数的 M 步。

在一个使用对偶的迭代算法中,存在一个显式(或隐式)的收敛点的平衡(或不动)假设(对于 EM,这是用 Jensen 不等式证明的)

因此,这些算法的大纲是:

  1. 类 E 步骤: 找到关于给定 保持不变的最佳解 X
  2. M-like step (dual): Find best solution with respect to X (as computed in previous step) being held constant.
  3. 终止/收敛的判据步骤: 重复步骤1,2,更新 x的值,直到收敛(或达到指定的迭代次数)

注意: 当这种算法收敛到一个(全局)最优时,它已经找到了一个 在两种意义上都是最好的配置(即在 X域/参数和 域/参数中)。然而,该算法只能找到一个 本地的最佳,而不是 全球性的的最佳。

我认为这是算法的直观描述

对于统计论据和应用,其他的答案已经给出了很好的解释(请检查这个答案中的参考文献)

公认的答案参考了 中电子纸,它很好地解释了 EM。还有一个 youtube video,解释了更详细的文件。

简而言之,情况是这样的:

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails


Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.


We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

In the case of the first trial's question, intuitively we'd think B generated it since the proportion of heads matches B's bias very well... but that value was just a guess, so we can't be sure.

考虑到这一点,我认为 EM 解决方案是这样的:

  • 每次掷硬币的试验都要“投票”选出它最喜欢的硬币
    • 这取决于每个硬币的分布情况
    • 或者,从硬币的角度来看,有很高的 expectation看到这个试验相对于其他硬币(基于 记录可能性)。
  • 根据每个试验对每个硬币的喜爱程度,它可以更新对该硬币参数的猜测(偏差)。
    • 一个试验越喜欢一个硬币,它就越能更新硬币的偏见来反映它自己!
    • 本质上,硬币的偏差是通过在所有试验中结合这些加权更新来更新的,这个过程被称为(最大化) ,它指的是在一系列试验中试图对每个硬币的偏差获得最佳猜测。

This may be an oversimplification (or even fundamentally wrong on some levels), but I hope this helps on an intuitive level!

注意: 这个答案背后的代码可以找到 here


假设我们有一些来自两个不同组的数据,红色和蓝色:

enter image description here

在这里,我们可以看到哪个数据点属于红色或蓝色组。这使得很容易找到描述每个群体的参数。例如,红色组的平均值在3左右,蓝色组的平均值在7左右(如果需要,我们可以找到确切的平均值)。

这通常被称为 极大似然估计极大似然估计。给定一些数据,我们计算一个参数(或参数)的值,以最好地解释该数据。

现在想象一下,我们 cannot看看哪个值是从哪个组中抽取的,所有东西在我们看来都是紫色的:

enter image description here

这里我们知道有 值组,但是我们不知道任何特定的值属于哪个组。

Can we still estimate the means for the red group and blue group that best fit this data?

是的,我们经常可以!期望最大化为我们提供了一种方法。这个算法的基本思想是这样的:

  1. Start with an initial estimate of what each parameter might be.
  2. 计算每个参数生成数据点的 可能性
  3. 根据参数生成数据点的可能性,计算每个数据点的权重,指示它是更红还是更蓝。将权重与数据结合起来(期望)。
  4. 使用权重调整数据(最大化)对参数进行更好的估计。
  5. 重复步骤2到4,直到参数估计收敛(过程停止生成不同的估计)。

这些步骤需要进一步的解释,因此我将详细介绍上面描述的问题。

例如: 估计均值和标准差

在这个示例中,我将使用 Python,但是如果您不熟悉这种语言,那么代码应该相当容易理解。

假设我们有两组,红色和蓝色,其值分布如上图所示。具体来说,每个组包含一个从 正态分布中提取的值,该值具有以下参数:

import numpy as np
from scipy import stats


np.random.seed(110) # for reproducible results


# set parameters
red_mean = 3
red_std = 0.8


blue_mean = 7
blue_std = 2


# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)


both_colours = np.sort(np.concatenate((red, blue))) # for later use...

下面是这些红色和蓝色组的图片(为了避免你不得不向上滚动) :

enter image description here

当我们看到每个点的颜色(即它属于哪个组) ,就很容易估计每个组的平均值和标准差。我们只是将红色和蓝色的值传递给 NumPy 中的内置函数。例如:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

但是,如果我们 不行看到的点的颜色?也就是说,不是红色或蓝色,而是每一个点都被染成了紫色。

为了尝试恢复红色和蓝色组的平均值和标准差参数,我们可以使用期望最大化。

我们的第一步(上面的 步骤1)是猜测每个组的平均值和标准差的参数值。我们不必聪明地猜测; 我们可以选择任何我们喜欢的数字:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9


# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

这些参数估计产生的钟形曲线如下:

enter image description here

这些都是错误的估计。例如,这两种表示(垂直虚线)看起来都远离任何类型的“中间”点。我们希望改进这些估计。

下一步(第二步)是计算出现在当前参数猜测下的每个数据点的可能性:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

在这里,我们只是简单地将每个数据点输入到 概率密度函数中,使用我们目前对红色和蓝色的平均值和标准差的猜测,得到一个正态分布。例如,这告诉我们,根据我们目前的猜测,1.761处的数据点 很多更可能是红色(0.189)而不是蓝色(0.00003)。

For each data point, we can turn these two likelihood values into weights (step 3) so that they sum to 1 as follows:

likelihood_total = likelihood_of_red + likelihood_of_blue


red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

根据我们目前的估计和新计算的权重,我们现在可以计算红色和蓝色组(步骤4)的平均值和标准差的 新的估计值。

我们使用 all数据点两次计算平均值和标准差,但加了不同的权重: 一次是红色权重,一次是蓝色权重。

直觉的关键点在于,一种颜色在数据点上的权重越大,数据点对该颜色参数的下一个估计值的影响就越大。这就产生了将参数“拉向”正确方向的效果。

def estimate_mean(data, weight):
"""
For each data point, multiply the point by the probability it
was drawn from the colour's distribution (its "weight").


Divide by the total weight: essentially, we're finding where
the weight is centred among our data points.
"""
return np.sum(data * weight) / np.sum(weight)


def estimate_std(data, weight, mean):
"""
For each data point, multiply the point's squared difference
from a mean value by the probability it was drawn from
that distribution (its "weight").


Divide by the total weight: essentially, we're finding where
the weight is centred among the values for the difference of
each data point from the mean.


This is the estimate of the variance, take the positive square
root to find the standard deviation.
"""
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)


# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)


# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

我们对参数有了新的估计。为了再次改进它们,我们可以跳回到步骤2并重复这个过程。我们这样做直到估计收敛,或者在执行了一些迭代之后(步骤5)。

对于我们的数据,这个过程的前五个迭代如下所示(最近的迭代有更强的外观) :

enter image description here

我们看到平均值已经在某些值上趋于一致,曲线的形状(由标准差控制)也变得更加稳定。

如果我们继续进行20次迭代,我们得到的结果如下:

enter image description here

EM 过程已经收敛到以下值,结果非常接近实际值(在这里我们可以看到颜色-没有隐藏的变量) :

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

在上面的代码中,你可能已经注意到新的标准差估计是使用前一次迭代的平均值估计来计算的。最终,我们是否首先计算平均值的新值并不重要,因为我们只是找到围绕某个中心点的值的(加权)方差。我们仍将看到参数的估计收敛。