• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

R语言实现混合正态分布EM最大期望估计法

武飞扬头像
拓端研究室TRL
帮助1

因为近期在分析数据时用到了EM最大期望估计法这个算法,在参数估计中也用到的比较多。然而,发现国内在R软件上实现高斯混合分布的EM的实例并不多,大多数是关于1到2个高斯混合分布的实现,不易于推广,因此这里分享一下自己编写的k个高斯混合分布的EM算法实现请大神们多多指教。并结合EMCluster包对结果进行验算。

      本文使用的密度函数为下面格式:


学新通

   对应的函数原型为 em.norm(x,means,covariances,mix.prop)

x为原数据,means为初始均值,covariances为数据的协方差矩阵,mix.prop为混合参数初始值。

使用的数据为MASS包里面的synth.te数据的前两列

x <- synth.te[,-3]

首先安装需要的包,并读取原数据。

  1.  
    install.packages("MASS")
  2.  
     
  3.  
    library(MASS)
  4.  
     
  5.  
    install.packages("EMCluster")
  6.  
     
  7.  
    library(EMCluster)
  8.  
     
  9.  
    install.packages("ggplot2")
  10.  
     
  11.  
    library(ggplot2)
  12.  
     
  13.  
    Y=synth.te[,c(1:2)]
  14.  
     
  15.  
    qplot(x=xs, y=ys, data=Y) 
学新通

然后绘制相应的变量相关图:

学新通

从图上我们可以大概估计出初始的平均点为(-0.7,0.4) (-0.3,0.8)(0.5,0.6)

当然 为了试验的严谨性,我可以从两个初始均值点的情况开始估计

首先输入初始参数:

  1.  
    mustart = rbind(c(-0.5,0.3),c(0.4,0.5))    
  2.  
     
  3.  
    covstart = list(cov(Y), cov(Y))
  4.  
     
  5.  
    probs = c(.01, .99)

然后编写em.norm函数,注意其中的clusters值需要根据不同的初始参数进行修改,

  1.  
    em.norm = function(X,mustart,covstart,probs){
  2.  
     
  3.  
     
  4.  
     
  5.  
      params = list(mu=mustart, var=covstart, probs=probs)   
  6.  
     
  7.  
      clusters = 2 
  8.  
     
  9.  
      tol=.00001
  10.  
     
  11.  
      maxits=100
  12.  
     
  13.  
      showits=T
  14.  
     
  15.  
      require(mvtnorm)
  16.  
     
  17.  
     
  18.  
     
  19.  
      N = nrow(X)
  20.  
     
  21.  
      mu = params$mu
  22.  
     
  23.  
      var = params$var
  24.  
     
  25.  
      probs = params$probs
  26.  
     
  27.  
      
  28.  
     
  29.  
      
  30.  
     
  31.  
      ri = matrix(0, ncol=clusters, nrow=N)         
  32.  
     
  33.  
      ll = 0                                        
  34.  
     
  35.  
      it = 0                                         
  36.  
     
  37.  
      converged = FALSE                            
  38.  
     
  39.  
      
  40.  
     
  41.  
      if (showits)                                 
  42.  
     
  43.  
        cat(paste("Iterations of EM:", "\n"))
  44.  
     
  45.  
      
  46.  
     
  47.  
      while (!converged & it < maxits) { 
  48.  
     
  49.  
        probsOld = probs
  50.  
     
  51.  
        
  52.  
     
  53.  
        llOld = ll
  54.  
     
  55.  
        riOld = ri
  56.  
     
  57.  
        
  58.  
     
  59.  
       
  60.  
     
  61.  
        # Compute responsibilities
  62.  
     
  63.  
        for (k in 1:clusters){
  64.  
     
  65.  
          ri[,k] = probs[k] * dmvnorm(X, mu[k,], sigma = var[[k]], log=F)
  66.  
     
  67.  
        }
  68.  
     
  69.  
        ri = ri/rowSums(ri)
  70.  
     
  71.  
        
  72.  
     
  73.  
      
  74.  
     
  75.  
        rk = colSums(ri)                             
  76.  
     
  77.  
        probs = rk/N
  78.  
     
  79.  
        for (k in 1:clusters){
  80.  
     
  81.  
          varmat = matrix(0, ncol=ncol(X), nrow=ncol(X))         
  82.  
     
  83.  
          for (i in 1:N){
  84.  
     
  85.  
            varmat = varmat ri[i,k] * X[i,]%*%t(X[i,])
  86.  
     
  87.  
          }
  88.  
     
  89.  
          mu[k,] = (t(X) %*% ri[,k]) / rk[k]
  90.  
     
  91.  
          var[[k]] =  varmat/rk[k] - mu[k,]%*%t(mu[k,])
  92.  
     
  93.  
          ll[k] = -.5 * sum( ri[,k] * dmvnorm(X, mu[k,], sigma = var[[k]], log=T) )
  94.  
     
  95.  
        }
  96.  
     
  97.  
        ll = sum(ll)
  98.  
     
  99.  
        
  100.  
     
  101.  
         
  102.  
     
  103.  
        parmlistold =  c(llOld, probsOld)            
  104.  
     
  105.  
        parmlistcurrent = c(ll, probs)             
  106.  
     
  107.  
        it = it 1
  108.  
     
  109.  
        if (showits & it == 1 | it%%5 == 0)         
  110.  
     
  111.  
          cat(paste(format(it), "...", "\n", sep = ""))
  112.  
     
  113.  
        converged = min(abs(parmlistold - parmlistcurrent)) <= tol
  114.  
     
  115.  
      }
  116.  
     
  117.  
      
  118.  
     
  119.  
      clust = which(round(ri)==1, arr.ind=T)       
  120.  
     
  121.  
      clust = clust[order(clust[,1]), 2]           
  122.  
     
  123.  
      out = list(probs=probs, mu=mu, var=var, resp=ri, cluster=clust, ll=ll)
  124.  
     
  125.  
  126.  
     
学新通

结果,可以用图像化来表示:

  1.  
    qplot(x=xs, y=ys, data=Y) 
  2.  
     
  3.  
    ggplot(aes(x=xs, y=ys), data=Y)
  4.  
     
  5.  
       geom_point(aes(color=factor(test$cluster))) 

学新通

学新通

 类似的其他情况这里不呈现了,另外r语言提供了EMCluster包可以比较方便的实现EM进行参数估计和结果的误差分析。

  1.  
    ret <- init.EM(Y, nclass = 2)
  2.  
     
  3.  
    em.aic(x=Y,emobj=list(pi = ret$pi, Mu = ret$Mu, LTSigma = ret$LTSigma))#计算结果的AIC

通过比较不同情况的AIC,我们可以筛选出适合的聚类数参数值。

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhgcbbak
系列文章
更多 icon
同类精品
更多 icon
继续加载