This code fits a Gaussian mixture with a Bayesian approach. As in the frequentist approach, the group indicator is treated as a latent variable, so it also receives a prior. Since the group indicator is a categorical variable with multiple categories, the Dirichlet distribution, which is conjugate to the categorical likelihood, is used as its prior. Because there are several parameters to estimate (the group indicators, plus the mu and sigma of each component distribution), it is hard to sample them all at once with a single MCMC move. We therefore use a Gibbs sampler, a common choice when there are many parameters, and draw each parameter from its full conditional distribution.
With the Bayesian approach, estimating each parameter is straightforward given the MCMC samples, which resolves many of the derivations that are analytically awkward in the frequentist approach. However, if an inappropriate prior is specified, the model may fail to produce good estimates.
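For reference, the model the sampler below assumes can be written out explicitly. Reading the priors off the code (a $N(0, 100^2)$ prior on each $\mu_k$, an Inverse-Gamma$(3, 100)$ prior on each $\sigma_k^2$, and a $\mathrm{Dirichlet}(1/K, \ldots, 1/K)$ prior on $\pi$), the full conditionals sampled in each Gibbs step are:

$$y_i \mid z_i = k \sim N(\mu_k, \sigma_k^2), \qquad P(z_i = k) = \pi_k$$
$$\pi \mid z \sim \mathrm{Dirichlet}(n_1 + 1/K, \ldots, n_K + 1/K), \qquad n_k = \#\{i : z_i = k\}$$
$$\mu_k \mid \cdot \sim N\!\left(\frac{\sum_{i:z_i=k} y_i / \sigma_k^2}{n_k/\sigma_k^2 + 1/100^2},\ \frac{1}{n_k/\sigma_k^2 + 1/100^2}\right)$$
$$\sigma_k^2 \mid \cdot \sim \mathrm{IG}\!\left(3 + \frac{n_k}{2},\ 100 + \frac{1}{2}\sum_{i:z_i=k}(y_i - \mu_k)^2\right), \qquad P(z_i = k \mid \cdot) \propto \pi_k\, N(y_i;\, \mu_k, \sigma_k^2)$$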
Library
options(repr.plot.width=5, repr.plot.height=5)
library(MCMCpack) #for rdirichlet and rinvgamma (also loads coda for mcmc, gelman.diag)
set.seed(1013)
rm(list = ls())
setwd('C:/Users/admin/내파일/대학원1학기/베이즈/HW2')
data=read.table('Pset2data.txt',header = T)
#Y=data$Gene2
Y=data$Gene1
hist(Y,nclass=100,main="Normal Mixture")
Function
gibbs_sampler=function(K=2, n_iter=1e3, Y, no_library=FALSE){
#K=2; n_iter=2e3; Y=data$Gene1;no_library=TRUE
# create variable for saving each trace-----------------------------------
z_trace=matrix(rep(NA,length(Y)*n_iter),ncol = n_iter)
pi_trace=matrix(rep(NA,K*n_iter),ncol = n_iter)
mu_trace=matrix(rep(NA,K*n_iter),ncol = n_iter)
sig2_trace=matrix(rep(NA,K*n_iter),ncol = n_iter)
llikelihood_trace=matrix(rep(NA,1*n_iter),ncol = n_iter)
# initialize values-----------------------------------
z_trace[,1] = ceiling(runif(n = length(Y),0,K)) #random group assignment in 1..K
pi_trace[,1] = rep(1/K,K)#rdirichlet(1,alpha = rep(1/K,K))
pi_save=pi_trace[,1]
mu_trace[,1] = rnorm(K,0,100)
sig2_trace[,1] = rep(0.3,K)#rinvgamma(K,3,100)
#iterate using Gibbs Sampling-----------------------------------
for(t in 1:n_iter){
if(t==1){
pi_lv = pi_trace[,t]
sig2_lv = sig2_trace[,t]
mu_lv = mu_trace[,t]
z_lv=z_trace[,t]
eta=c(table(z_lv),pi_lv,mu_lv,sig2_lv)
cat('z_lv,pi_lv,mu_lv,sig2_lv is : ',eta,'\n')
next}
#pi step------------
#lv for latest_value
mu_lv = mu_trace[,t-1]
sig2_lv = sig2_trace[,t-1]
z_lv = z_trace[,t-1]
nk=rep(NA,K)
for(i in 1:K){
nk[i]=sum(z_lv==i)
}
#sampling pi, and save (accept prob for Gibbs is 1)
#pi_trace[,t] = rdirichlet(1,nk+1)
#when you cannot use rdirichlet---------------------------------------
if(no_library==TRUE){
stopifnot(K==2)
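#for K=2, a Dirichlet(a1,a2) draw reduces to (B, 1-B) with B ~ Beta(a1,a2);
#the +1/2 terms correspond to the Dirichlet(1/K,...,1/K) prior with K=2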
n1=rbeta(n = 1,shape1 = nk[1]+1/2,shape2 = nk[2]+1/2)
n2=1-n1
pi_trace[,t]=c(n1,n2)
}
#when you cannot use rdirichlet---------------------------------------
else{
pi_trace[,t]= rdirichlet(1,alpha = (nk+1/K))
}
pi_lv = pi_trace[,t]
#mu step----------
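#normal-normal conjugacy: with a N(0,100^2) prior, mu_k | rest is normal with
#precision n_k/sig2_k + 1/100^2 and mean (sum of Y in group k / sig2_k) / precision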
condi_mean = rep(NA,K)
for(i in 1:K){
condi_mean[i]=(sum(Y[z_lv==i])/sig2_lv[i])/(nk[i]/sig2_lv[i]+1/100^2)
stopifnot(sum(z_lv==i)==nk[i])
}
condi_sig2 = rep(NA,K)
for(i in 1:K){
condi_sig2[i]=1/(nk[i]/sig2_lv[i]+1/100^2)
}
for(i in 1:K){
#sampling mu, and save (accept prob for Gibbs is 1)
mu_trace[i,t]=rnorm(1, mean = condi_mean[i], sd = sqrt(condi_sig2[i]))
}
mu_lv = mu_trace[,t]
#zstep------------
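#each z_i is categorical: P(z_i = k | rest) is proportional to pi_k * N(y_i; mu_k, sig2_k)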
cat_prob=matrix(rep(NA,K*length(Y)),nrow = K)
for(i in 1:K){
cat_prob[i,]=pi_lv[i]*dnorm(Y,mean = mu_lv[i],sd = sqrt(sig2_lv[i]))
}
cat_prob=cat_prob/matrix(rep(colSums(cat_prob),each=K),nrow = K) #normalize each column to sum to 1
for(i in 1:length(Y)){
#i=11
z_trace[i,t] = sample(1:K, size=1,prob=cat_prob[,i],replace=TRUE)
#z_trace[i,t] = which.max(rmultinom(1, 1, cat_prob[,i]))
}
z_lv=z_trace[,t]
nk=rep(NA,K)
for(i in 1:K){
nk[i]=sum(z_lv==i)
}
#sig2 step----------
#sampling sig2, and save (accept prob for Gibbs is 1)
#under the IG(3,100) prior, the full conditional is IG(shape = 3 + n_k/2, scale = 100 + SS_k/2)
for(i in 1:K){
if(no_library==TRUE){
#an IG(shape, scale) draw is 1/Gamma(shape, rate = scale)
sig2_trace[i,t]=1/rgamma(1, shape = 3+nk[i]/2, rate = 100+1/2*sum((Y[z_lv==i]-mu_lv[i])^2))
}
else{
sig2_trace[i,t]=rinvgamma(1, shape = 3+nk[i]/2, scale = 100+1/2*sum((Y[z_lv==i]-mu_lv[i])^2))
}
}
sig2_lv = sig2_trace[,t]
eta=c(table(z_lv),round(pi_lv,2),round(mu_lv,2),sig2_lv)
#track the (unnormalized) log joint posterior: complete-data likelihood plus priors
llikelihood=0
for(k in 1:K){
llikelihood=llikelihood+(
sum((z_lv==k)*(log(pi_lv[k])+dnorm(Y,mean = mu_lv[k],sd = sqrt(sig2_lv[k]),log = TRUE)))+ #likelihood
(1/K-1)*log(pi_lv[k])+ #Dirichlet(1/K) prior on pi
(-1/2e4*mu_lv[k]^2)+ #N(0,100^2) prior on mu, up to a constant
(-4*log(sig2_lv[k])-100/sig2_lv[k]) #IG(3,100) prior on sig2, up to a constant
)
}
llikelihood_trace[,t]=llikelihood
#print data -----------------------------------
if((t<=10)|(t%%500==0)){#check first 10 steps or every 500 steps
cat('for',t,', z_lv,pi_lv,mu_lv,sig2_lv is : ',eta,'\n')
}
}
res=list(z_trace,pi_trace,mu_trace,sig2_trace,llikelihood_trace)
names(res)=c('z_trace','pi_trace','mu_trace','sig2_trace','llikelihood_trace')
return(res)
}
Simulate & Convergence check
inp_niter=2e3
inp_K=2
set.seed(1014)
#run multiple chains to compare results
tmp1=gibbs_sampler(K=inp_K, n_iter=inp_niter, Y=data$Gene1)
tmp2=gibbs_sampler(K=inp_K, n_iter=inp_niter, Y=data$Gene1)
tmp3=gibbs_sampler(K=inp_K, n_iter=inp_niter, Y=data$Gene1)
z_lv,pi_lv,mu_lv,sig2_lv is : 40 32 0.5 0.5 -191.0087 -27.92372 0.3 0.3
for 2 , z_lv,pi_lv,mu_lv,sig2_lv is : 38 34 0.56 0.44 0.22 0.59 0.2165637 0.1509482
for 3 , z_lv,pi_lv,mu_lv,sig2_lv is : 43 29 0.44 0.56 0.02 0.83 0.2510533 0.1417167
for 4 , z_lv,pi_lv,mu_lv,sig2_lv is : 57 15 0.6 0.4 0.05 1.01 0.3747806 0.1155301
for 5 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.81 0.19 0.15 1.89 0.3290877 0.06525928
for 6 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.83 0.17 -0.04 2.48 0.345712 0.07573328
for 7 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.88 0.12 0.07 2.4 0.2992068 0.07914449
for 8 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.8 0.2 0.2 2.41 0.3677189 0.09115946
for 9 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.86 0.14 -0.06 2.33 0.3403325 0.09380707
for 10 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.85 0.15 0.05 2.52 0.4521353 0.0782955
for 500 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.92 0.08 0.19 2.5 0.3426829 0.07064485
for 1000 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.84 0.16 0.08 2.45 0.3653833 0.07525163
for 1500 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.81 0.19 -0.01 2.51 0.255525 0.09068538
for 2000 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.88 0.12 -0.01 2.51 0.300147 0.07983018
z_lv,pi_lv,mu_lv,sig2_lv is : 35 37 0.5 0.5 182.5481 115.7144 0.3 0.3
for 2 , z_lv,pi_lv,mu_lv,sig2_lv is : 30 42 0.47 0.53 0.39 0.33 0.1380176 0.2655488
for 3 , z_lv,pi_lv,mu_lv,sig2_lv is : 22 50 0.49 0.51 0.53 0.29 0.1321068 0.2340113
for 4 , z_lv,pi_lv,mu_lv,sig2_lv is : 27 45 0.25 0.75 0.06 0.56 0.1543063 0.2010244
for 5 , z_lv,pi_lv,mu_lv,sig2_lv is : 28 44 0.35 0.65 -0.03 0.69 0.1593382 0.205812
for 6 , z_lv,pi_lv,mu_lv,sig2_lv is : 41 31 0.46 0.54 0.02 0.55 0.2156582 0.1554942
for 7 , z_lv,pi_lv,mu_lv,sig2_lv is : 44 28 0.48 0.52 -0.13 0.93 0.2514632 0.143165
for 8 , z_lv,pi_lv,mu_lv,sig2_lv is : 60 12 0.68 0.32 -0.03 1.18 0.261554 0.08524478
for 9 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.79 0.21 -0.01 2.24 0.3305866 0.07985622
for 10 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.84 0.16 0.01 2.31 0.3925597 0.07986116
for 500 , z_lv,pi_lv,mu_lv,sig2_lv is : 63 9 0.88 0.12 0.04 2.63 0.3045323 0.07649907
for 1000 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.88 0.12 0.06 2.42 0.3116635 0.069078
for 1500 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.89 0.11 0.11 2.51 0.3665746 0.09550593
for 2000 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.84 0.16 0.12 2.39 0.3540361 0.07251946
z_lv,pi_lv,mu_lv,sig2_lv is : 36 36 0.5 0.5 -5.090603 -59.35514 0.3 0.3
for 2 , z_lv,pi_lv,mu_lv,sig2_lv is : 38 34 0.48 0.52 0.31 0.52 0.1765433 0.1678907
for 3 , z_lv,pi_lv,mu_lv,sig2_lv is : 53 19 0.52 0.48 0.03 0.76 0.3007423 0.09123187
for 4 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.8 0.2 0.03 1.5 0.336586 0.06957606
for 5 , z_lv,pi_lv,mu_lv,sig2_lv is : 63 9 0.83 0.17 0.17 2.42 0.3289166 0.08441534
for 6 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.87 0.13 0.03 2.62 0.3206282 0.06735946
for 7 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.77 0.23 0.12 2.53 0.3357813 0.07837687
for 8 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.88 0.12 0.13 2.22 0.3373437 0.084334
for 9 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.82 0.18 0.04 2.6 0.2830813 0.08200646
for 10 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.78 0.22 0.03 2.67 0.3110761 0.07652402
for 500 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.9 0.1 0.04 2.45 0.3481991 0.07615492
for 1000 , z_lv,pi_lv,mu_lv,sig2_lv is : 62 10 0.92 0.08 0.07 2.56 0.3003156 0.07968798
for 1500 , z_lv,pi_lv,mu_lv,sig2_lv is : 61 11 0.85 0.15 0.07 2.48 0.3360475 0.101472
for 2000 , z_lv,pi_lv,mu_lv,sig2_lv is : 63 9 0.88 0.12 0.13 2.54 0.3583726 0.05720203
target=c('z_trace','pi_trace','mu_trace','sig2_trace','llikelihood_trace')[3] #select mu_trace
trace_mcmc1=mcmc(t(tmp1[[target]])[500:inp_niter,]) #discard first 500 draws as burn-in
#summary(trace_mcmc1)
plot(trace_mcmc1)
trace_mcmc2=mcmc(t(tmp2[[target]])[500:inp_niter,])
plot(trace_mcmc2)
trace_mcmc3=mcmc(t(tmp3[[target]])[500:inp_niter,])
plot(trace_mcmc3)
To check convergence, we run several chains. The group indicators (here 1 and 2, since two groups were used) are arbitrary labels, so their order can switch between chains (label switching). Among the various ways to check convergence, we use the Gelman-Rubin statistic, which compares the within-chain variance and the between-chain variance across multiple chains; values below 1.1 are usually taken to indicate convergence.
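For reference, with $m$ chains of length $n$, chain means $\bar{\theta}_j$, overall mean $\bar{\theta}$, and within-chain variances $s_j^2$, the potential scale reduction factor reported by gelman.diag is (up to a small degrees-of-freedom correction applied by coda):

$$W = \frac{1}{m}\sum_{j=1}^{m} s_j^2, \qquad B = \frac{n}{m-1}\sum_{j=1}^{m}(\bar{\theta}_j - \bar{\theta})^2, \qquad \hat{R} = \sqrt{\frac{\frac{n-1}{n}W + \frac{1}{n}B}{W}}$$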
# convergence check with Gelman Rubin statistics
conv_m1=mcmc(t(tmp1[[target]])[500:inp_niter,2]) #mu chain of component 2 only
conv_m2=mcmc(t(tmp2[[target]])[500:inp_niter,2])
conv_m3=mcmc(t(tmp3[[target]])[500:inp_niter,2])
#combinedchains = mcmc.list(trace_mcmc1, trace_mcmc2,trace_mcmc3)
combinedchains = mcmc.list(conv_m1,conv_m2,conv_m3)
plot(combinedchains)
gelman.diag(combinedchains)
gelman.plot(combinedchains)
Potential scale reduction factors:
Point est. Upper C.I.
[1,] 1 1
### generate samples using posterior inference
one_tmp=tmp3
#plug-in predictive: use the posterior mean of each parameter (rowMeans over the MCMC draws)
posterior_mean_pi=rowMeans(one_tmp[['pi_trace']])
posterior_mean_mu=rowMeans(one_tmp[['mu_trace']])
posterior_mean_sig2=rowMeans(one_tmp[['sig2_trace']])
y_sim=rep(0,length(Y))
for(i in 1:length(Y)){
#draw a component label, then draw y from that component
z_tmp = sample(seq(1,inp_K), size=1, replace=T, prob=posterior_mean_pi)
y_sim[i]=rnorm(1,mean = posterior_mean_mu[z_tmp],sd = sqrt(posterior_mean_sig2[z_tmp]))
}
par(mfrow=c(1,2))
hist(Y,nclass=50)
hist(y_sim,nclass=50)
In addition, the same data were fitted with the EM algorithm for comparison.
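The function below implements the standard EM updates for a $K$-component normal mixture with a common variance. In the code's notation (responsibilities $z_{ik}$, weights $\pi_k$, means $\mu_k$, pooled variance $\sigma^2$), the steps are:

$$\text{E-step:}\quad z_{ik} = \frac{\pi_k\, N(y_i;\, \mu_k, \sigma^2)}{\sum_{l=1}^{K} \pi_l\, N(y_i;\, \mu_l, \sigma^2)}$$
$$\text{M-step:}\quad \pi_k = \frac{1}{n}\sum_{i=1}^{n} z_{ik}, \qquad \mu_k = \frac{\sum_i z_{ik}\, y_i}{\sum_i z_{ik}}, \qquad \sigma^2 = \frac{1}{n}\sum_{i=1}^{n}\sum_{k=1}^{K} z_{ik}(y_i - \mu_k)^2$$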
###data check with traditional K-mixture-----------------------------
K_mixture=function(y=data$Gene1,inp_K=3){
n=length(y)
#init value
K=inp_K
pi=rep(1/K,K)
mu=rnorm(K,sd = 30)#c(mu1,mu2,mu3)
sig2=2 # common variance
eta=c(pi,mu,sig2)
print(eta)
repeat{
eta0=eta
term=matrix(rep(NA,length(pi)*n), ncol = length(pi)) # saving value of each distn in each column
## E-step
for(i in 1:length(pi)){
term[,i]=pi[i]*dnorm(y,mu[i],sd=sqrt(sig2))# common variance
}
z=term/rowSums(term)
## M-step
for(i in 1:length(pi)){ #update the parameters which maximize the Q-function
pi[i]=sum(z[,i])/n
mu[i]=sum(z[,i]*y)/sum(z[,i])
}
mu_mat=matrix(rep(mu,n),ncol = length(pi),byrow = T)
sig2=sum(z*(y-mu_mat)^2)/n # sum for all i,k (common variance)
eta=c(pi,mu,sig2)
## Convergence criteria
diff=(eta0-eta)^2
print(eta)
if(sum(diff)<1e-7) {cat('converge!\n');break}
}
#reuse the simulation pattern from above (these are ML estimates, not posterior means)
posterior_mean_pi=pi
posterior_mean_mu=mu
posterior_mean_sig2=sig2
Y=y
y_sim=rep(0,length(Y))
for(i in 1:length(Y)){
z_tmp = sample(seq(1,inp_K), size=1, replace=T, prob=posterior_mean_pi)
y_sim[i]=rnorm(1,mean = posterior_mean_mu[z_tmp],sd = sqrt(posterior_mean_sig2)) #common variance
}
par(mfrow=c(2,1))
hist(Y,nclass=50)
hist(y_sim,nclass=50)
return(eta)
}
mixture_res=K_mixture(y = data$Gene1,inp_K=2)
round(mixture_res,3)#c(pi,mu,sig2)
[1] 0.5000000 0.5000000 -0.1572504 4.3961548 2.0000000
[1] 0.8904799 0.1095201 0.1807714 2.3333463 0.3589446
[1] 0.85349044 0.14650956 0.06586353 2.45927732 0.09452766
[1] 0.84722265 0.15277735 0.05209006 2.43746643 0.07434040
[1] 0.84722223 0.15277777 0.05208921 2.43746456 0.07433937
converge!
Comments