IT++ Logo Newcom Logo

gmm.cpp

Go to the documentation of this file.
00001 
00033 #include <itpp/srccode/gmm.h>
00034 #include <itpp/srccode/vqtrain.h>
00035 #include <itpp/base/matfunc.h>
00036 #include <itpp/base/specmat.h>
00037 #include <itpp/base/random.h>
00038 #include <itpp/base/timing.h>
00039 #include <itpp/base/elmatfunc.h>
00040 #include <iostream>
00041 #include <fstream>
00042 
00043 
00044 namespace itpp {
00045 
00046   GMM::GMM()
00047   {
00048     d=0;
00049     M=0;
00050   }
00051 
00052   GMM::GMM(std::string filename)
00053   {
00054     load(filename);
00055   }
00056 
00057   GMM::GMM(int M_in, int d_in)
00058   {
00059     M=M_in;
00060     d=d_in;
00061     m=zeros(M*d);
00062     sigma=zeros(M*d);
00063     w=1./M*ones(M);
00064 
00065     for (int i=0;i<M;i++) {
00066       w(i)=1.0/M;
00067     }
00068     compute_internals();
00069   }
00070 
00071   void GMM::init_from_vq(const vec &codebook, int dim)
00072   {
00073 
00074     mat         C(dim,dim);
00075     int         i;
00076     vec         v;
00077 
00078     d=dim;
00079     M=codebook.length()/dim;
00080 
00081     m=codebook;
00082     w=ones(M)/double(M);
00083 
00084     C.clear();
00085     for (i=0;i<M;i++) {
00086       v=codebook.mid(i*d,d);
00087       C=C+outer_product(v,v);
00088     }
00089     C=1./M*C;
00090     sigma.set_length(M*d);
00091     for (i=0;i<M;i++) {
00092       sigma.replace_mid(i*d,diag(C));
00093     }
00094 
00095     compute_internals();
00096   }
00097 
00098   //void GMM::init(const vec &m_in, const vec &sigma_in, const vec &w_in)
00099   //{
00100   //    m=m_in;
00101   //    sigma=sigma_in;
00102   //    w=w_in;
00103   //
00104   //    compute_internals();
00105   //}
00106   void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
00107   {
00108     int         i,j;
00109     d=m_in.rows();
00110     M=m_in.cols();
00111 
00112     m.set_length(M*d);
00113     sigma.set_length(M*d);
00114     for (i=0;i<M;i++) {
00115       for (j=0;j<d;j++) {
00116         m(i*d+j)=m_in(j,i);
00117         sigma(i*d+j)=sigma_in(j,i);
00118       }
00119     }
00120     w=w_in;
00121 
00122     compute_internals();
00123   }
00124 
00125   void GMM::set_mean(const mat &m_in)
00126   {
00127     int         i,j;
00128 
00129     d=m_in.rows();
00130     M=m_in.cols();
00131 
00132     m.set_length(M*d);
00133     for (i=0;i<M;i++) {
00134       for (j=0;j<d;j++) {
00135         m(i*d+j)=m_in(j,i);
00136       }
00137     }
00138     compute_internals();
00139   }
00140 
00141   void GMM::set_mean(int i, const vec &means, bool compflag)
00142   {
00143     m.replace_mid(i*length(means),means); 
00144     if (compflag) compute_internals(); 
00145   }
00146 
00147   void GMM::set_covariance(const mat &sigma_in)
00148   {
00149     int         i,j;
00150 
00151     d=sigma_in.rows();
00152     M=sigma_in.cols();
00153 
00154     sigma.set_length(M*d);
00155     for (i=0;i<M;i++) {
00156       for (j=0;j<d;j++) {
00157         sigma(i*d+j)=sigma_in(j,i);
00158       }
00159     }
00160     compute_internals();
00161   }
00162 
00163   void GMM::set_covariance(int i, const vec &covariances, bool compflag) 
00164   {
00165     sigma.replace_mid(i*length(covariances),covariances); 
00166     if (compflag) compute_internals(); 
00167   }
00168 
00169   void GMM::marginalize(int d_new)
00170   {
00171     it_error_if(d_new>d,"GMM.marginalize: cannot change to a larger dimension");
00172 
00173     vec         mnew(d_new*M),sigmanew(d_new*M);
00174     int         i,j;
00175 
00176     for (i=0;i<M;i++) {
00177       for (j=0;j<d_new;j++) {
00178         mnew(i*d_new+j)=m(i*d+j);
00179         sigmanew(i*d_new+j)=sigma(i*d+j);
00180       }
00181     }
00182     m=mnew;
00183     sigma=sigmanew;
00184     d=d_new;
00185 
00186     compute_internals(); 
00187   }
00188 
00189   void GMM::join(const GMM &newgmm)
00190   {
00191     if (d==0) {
00192       w=newgmm.w;
00193       m=newgmm.m;
00194       sigma=newgmm.sigma;
00195       d=newgmm.d;
00196       M=newgmm.M;
00197     } else {
00198       it_error_if( d!=newgmm.d,"GMM.join: cannot join GMMs of different dimension");
00199 
00200       w=concat(double(M)/(M+newgmm.M)*w,double(newgmm.M)/(M+newgmm.M)*newgmm.w);
00201       w=w/sum(w);
00202       m=concat(m,newgmm.m);
00203       sigma=concat(sigma,newgmm.sigma);
00204 
00205       M=M+newgmm.M;
00206     }
00207     compute_internals(); 
00208   }
00209 
00210   void GMM::clear()
00211   {
00212     w.set_length(0);
00213     m.set_length(0);
00214     sigma.set_length(0);
00215     d=0;
00216     M=0;
00217   }
00218 
00219   void GMM::save(std::string filename)
00220   {
00221     std::ofstream       f(filename.c_str());
00222     int                 i,j;
00223 
00224     f << M << " " << d << std::endl ;
00225     for (i=0;i<w.length();i++) {
00226       f << w(i) << std::endl ;
00227     }
00228     for (i=0;i<M;i++) {
00229       f << m(i*d) ;
00230       for (j=1;j<d;j++) {
00231         f << " " << m(i*d+j) ;
00232       }
00233       f << std::endl ;
00234     }
00235     for (i=0;i<M;i++) {
00236       f << sigma(i*d) ;
00237       for (j=1;j<d;j++) {
00238         f << " " << sigma(i*d+j) ;
00239       }
00240       f << std::endl ;
00241     }
00242   }
00243 
00244   void GMM::load(std::string filename)
00245   {
00246     std::ifstream       GMMFile(filename.c_str());
00247     long                i,j;
00248 
00249     it_error_if(!GMMFile,std::string("GMM::load : cannot open file ")+filename);
00250 
00251     GMMFile >> M >> d ;
00252 
00253 
00254     w.set_length(M);
00255     for (i=0;i<M;i++) {
00256       GMMFile >> w(i) ;
00257     }   
00258     m.set_length(M*d);
00259     for (i=0;i<M;i++) {
00260       for (j=0;j<d;j++) {
00261         GMMFile >> m(i*d+j) ;
00262       }
00263     }   
00264     sigma.set_length(M*d);
00265     for (i=0;i<M;i++) {
00266       for (j=0;j<d;j++) {
00267         GMMFile >> sigma(i*d+j) ;
00268       }
00269     }   
00270     compute_internals();
00271     std::cout << "  mixtures:" << M << "  dim:" << d << std::endl ;
00272   }
00273 
00274   double GMM::likelihood(const vec &x)
00275   {
00276     double      fx=0;
00277     int         i;
00278 
00279     for (i=0;i<M;i++) {
00280       fx+=w(i)*likelihood_aposteriori(x, i);
00281     }
00282     return fx;
00283   }
00284 
00285   vec GMM::likelihood_aposteriori(const vec &x)
00286   {
00287     vec         v(M);
00288     int         i;
00289 
00290     for (i=0;i<M;i++) {
00291       v(i)=w(i)*likelihood_aposteriori(x, i);
00292     }
00293     return v;
00294   }
00295 
00296   double GMM::likelihood_aposteriori(const vec &x, int mixture)
00297   {
00298     int         j;
00299     double      s;
00300 
00301     it_error_if(d!=x.length(),"GMM::likelihood_aposteriori : dimensions does not match");
00302     s=0;
00303     for (j=0;j<d;j++) {
00304       s+=normexp(mixture*d+j)*sqr(x(j)-m(mixture*d+j));
00305     }
00306     return normweight(mixture)*std::exp(s);;
00307   }
00308 
00309   void GMM::compute_internals()
00310   {
00311     int         i,j;
00312     double      s;
00313     double      constant=1.0/std::pow(2*pi,d/2.0);
00314 
00315     normweight.set_length(M);
00316     normexp.set_length(M*d);
00317 
00318     for (i=0;i<M;i++) {
00319       s=1;
00320       for (j=0;j<d;j++) {
00321         normexp(i*d+j)=-0.5/sigma(i*d+j);  // check time
00322         s*=sigma(i*d+j);
00323       }
00324       normweight(i) = constant/std::sqrt(s);
00325     }
00326 
00327   }
00328 
00329   vec GMM::draw_sample()
00330   {
00331     static bool first=true;
00332     static vec  cumweight;
00333     double      u=randu();
00334     int         k;
00335 
00336     if (first) {
00337       first=false;
00338       cumweight=cumsum(w);
00339       it_error_if(std::abs(cumweight(length(cumweight)-1)-1)>1e-6,"weight does not sum to 0");
00340       cumweight(length(cumweight)-1)=1;
00341     }
00342     k=0;
00343     while (u>cumweight(k)) k++;
00344 
00345     return elem_mult(sqrt(sigma.mid(k*d,d)),randn(d))+m.mid(k*d,d);
00346   }
00347 
00348   GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
00349   {
00350     mat                 mean;
00351     int                 i,j,d=TrainingData(0).length();
00352     vec                 sig;
00353     GMM                 gmm(M,d);
00354     vec                 m(d*M);
00355     vec                 sigma(d*M);
00356     vec                 w(M);
00357     vec                 normweight(M);
00358     vec                 normexp(d*M);
00359     double              LL=0,LLold,fx;
00360     double              constant=1.0/std::pow(2*pi,d/2.0);
00361     int                 T=TrainingData.length();
00362     vec                 x1;
00363     int                 t,n;
00364     vec                 msum(d*M);
00365     vec                 sigmasum(d*M);
00366     vec                 wsum(M);
00367     vec                 p_aposteriori(M);
00368     vec                 x2;
00369     double              s;
00370     vec                 temp1,temp2;
00371     //double            MINIMUM_VARIANCE=0.03;
00372 
00373     //-----------initialization-----------------------------------
00374 
00375     mean=vqtrain(TrainingData,M,200000,0.5,VERBOSE);
00376     for (i=0;i<M;i++) gmm.set_mean(i,mean.get_col(i),false);
00377     //  for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
00378     sig=zeros(d);
00379     for (i=0;i<TrainingData.length();i++) sig+=sqr(TrainingData(i));
00380     sig/=TrainingData.length();
00381     for (i=0;i<M;i++) gmm.set_covariance(i,0.5*sig,false);
00382 
00383     gmm.set_weight(1.0/M*ones(M));
00384 
00385     //-----------optimization-----------------------------------
00386 
00387     tic();
00388     for (i=0;i<M;i++) {
00389       temp1=gmm.get_mean(i);
00390       temp2=gmm.get_covariance(i);
00391       for (j=0;j<d;j++) {
00392         m(i*d+j)=temp1(j);
00393         sigma(i*d+j)=temp2(j);
00394       }
00395       w(i)=gmm.get_weight(i);
00396     }
00397     for (n=0;n<NOITER;n++) {
00398       for (i=0;i<M;i++) {
00399         s=1;
00400         for (j=0;j<d;j++) {
00401           normexp(i*d+j)=-0.5/sigma(i*d+j);  // check time
00402           s*=sigma(i*d+j);
00403         }
00404         normweight(i) = constant*w(i)/std::sqrt(s);
00405       }
00406       LLold=LL;
00407       wsum.clear();
00408       msum.clear();
00409       sigmasum.clear();
00410       LL=0;
00411       for (t=0;t<T;t++) {
00412         x1=TrainingData(t);
00413         x2=sqr(x1);
00414         fx=0;
00415         for (i=0;i<M;i++) {
00416           s=0;
00417           for (j=0;j<d;j++) {
00418             s+=normexp(i*d+j)*sqr(x1(j)-m(i*d+j));
00419           }
00420           p_aposteriori(i)=normweight(i)*std::exp(s);
00421           fx+=p_aposteriori(i);
00422         }
00423         p_aposteriori/=fx;
00424         LL=LL+std::log(fx);
00425 
00426         for (i=0;i<M;i++) {
00427           wsum(i)+=p_aposteriori(i);
00428           for (j=0;j<d;j++) {
00429             msum(i*d+j)+=p_aposteriori(i)*x1(j);
00430             sigmasum(i*d+j)+=p_aposteriori(i)*x2(j);
00431           }
00432         }
00433       }
00434       for (i=0;i<M;i++) {
00435         for (j=0;j<d;j++) {
00436           m(i*d+j)=msum(i*d+j)/wsum(i);
00437           sigma(i*d+j)=sigmasum(i*d+j)/wsum(i)-sqr(m(i*d+j));
00438         }
00439         w(i)=wsum(i)/T;
00440       }
00441       LL=LL/T;
00442 
00443       if (std::abs((LL-LLold)/LL) < 1e-6) break;
00444       if (VERBOSE) {
00445         std::cout << n << ":   " << LL << "   " << std::abs((LL-LLold)/LL) << "   " << toc() <<  std::endl ;
00446         std::cout << "---------------------------------------" << std::endl ;
00447         tic();
00448       } else {
00449         std::cout << n << ": LL =  " << LL << "   " << std::abs((LL-LLold)/LL) << "\r" ;std::cout.flush();
00450       }
00451     }
00452     for (i=0;i<M;i++) {
00453       gmm.set_mean(i,m.mid(i*d,d),false);
00454       gmm.set_covariance(i,sigma.mid(i*d,d),false);
00455     }
00456     gmm.set_weight(w);
00457     return gmm;
00458   }
00459 
00460 } // namespace itpp
SourceForge Logo

Generated on Fri Jun 8 00:37:35 2007 for IT++ by Doxygen 1.5.2