![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_SIGLAM_H_ 00044 #define _GURLS_SIGLAM_H_ 00045 00046 #include <cstdio> 00047 #include <cstring> 00048 #include <iostream> 00049 #include <cmath> 00050 #include <algorithm> 00051 #include <set> 00052 00053 #include "gurls++/options.h" 00054 #include "gurls++/optlist.h" 00055 #include "gurls++/gmat2d.h" 00056 #include "gurls++/gvec.h" 00057 #include "gurls++/gmath.h" 00058 00059 #include "gurls++/paramsel.h" 00060 #include "gurls++/perf.h" 00061 #include "gurls++/rbfkernel.h" 00062 #include "gurls++/loocvdual.h" 00063 00064 namespace gurls { 00065 00071 template <typename T> 00072 class ParamSelSiglam: public ParamSelection<T>{ 00073 00074 public: 00092 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00093 }; 00094 00095 template <typename T> 00096 GurlsOptionsList* ParamSelSiglam<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt) 00097 { 00098 // [n,T] = size(y); 00099 const unsigned long t = Y.cols(); 00100 00101 // GurlsOptionsList* kernel_old = NULL; 00102 // if (opt.hasOpt("kernel")) 00103 // { 00104 // kernel_old = GurlsOptionsList::dynacast(opt.getOpt("kernel")); 00105 // opt.removeOpt("kernel", false); 00106 // } 00107 00108 GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested"); 00109 nestedOpt->copyOpt("nlambda", opt); 00110 nestedOpt->copyOpt("hoperf", opt); 00111 nestedOpt->copyOpt("smallnumber", opt); 00112 00113 GurlsOptionsList* kernel = new GurlsOptionsList("kernel"); 00114 kernel->addOpt("type", "rbf"); 00115 nestedOpt->addOpt("kernel", kernel); 00116 00117 00118 GurlsOptionsList* paramsel; 00119 00120 if(opt.hasOpt("paramsel")) 00121 { 00122 GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp"); 00123 tmp_opt->copyOpt("paramsel", opt); 00124 00125 paramsel = GurlsOptionsList::dynacast(tmp_opt->getOpt("paramsel")); 00126 tmp_opt->removeOpt("paramsel", false); 00127 delete tmp_opt; 00128 00129 paramsel->removeOpt("lambdas"); 00130 paramsel->removeOpt("sigma"); 00131 } 00132 else 00133 paramsel = new GurlsOptionsList("paramsel"); 00134 00135 00136 gMat2D<T>* dist = new gMat2D<T>(X.rows(), X.rows()); 00137 00138 // if ~isfield(opt.kernel,'distance') 00139 if(!kernel->hasOpt("distance")) 00140 // opt.kernel.distance = squareform(pdist(X)); 00141 { 00142 squareform<T>(X.getData(), X.rows(), X.cols(), dist->getData(), X.rows()); 00143 00144 T *distSquared = new T[X.rows()*X.rows()]; 00145 copy(distSquared , dist->getData(), X.rows()*X.rows()); 00146 00147 mult<T>(distSquared, distSquared, dist->getData(), X.rows()*X.rows()); 00148 00149 kernel->addOpt("distance", new OptMatrix<gMat2D<T> >(*dist)); 00150 00151 delete [] distSquared; 00152 } 00153 else 00154 dist = &(kernel->getOptValue<OptMatrix<gMat2D<T> > >("distance")); 00155 00156 00157 // if ~isfield(opt,'sigmamin') 00158 if(!opt.hasOpt("sigmamin")) 00159 { 00160 // %D = sort(opt.kernel.distance); 00161 // %opt.sigmamin = median(D(2,:)); 00162 // D = sort(squareform(opt.kernel.distance)); 00163 int d_len = X.rows()*(X.rows()-1)/2; 00164 T* distY = new T[d_len]; 00165 // squareform<T>(dist->getData(), dist->rows(), dist->cols(), distY, 1); 00166 00167 const int size = dist->cols(); 00168 T* it = distY; 00169 00170 for(int i=1; i< size; it+=i, ++i) 00171 copy(it , dist->getData()+(i*size), i); 00172 00173 std::sort(distY, distY + d_len); 00174 // firstPercentile = round(0.01*numel(D)+0.5); 00175 00176 int firstPercentile = gurls::round( (T)0.01 * d_len + (T)0.5) -1; 00177 00178 // opt.sigmamin = D(firstPercentile); 00179 nestedOpt->addOpt("sigmamin", new OptNumber(sqrt( distY[firstPercentile]) )); 00180 00181 delete [] distY; 00182 } 00183 else 00184 { 00185 nestedOpt->addOpt("sigmamin", new OptNumber(opt.getOptAsNumber("sigmamin"))); 00186 } 00187 00188 T sigmamin = static_cast<T>(nestedOpt->getOptAsNumber("sigmamin")); 00189 00190 // if ~isfield(opt,'sigmamax') 00191 if(!opt.hasOpt("sigmamax")) 00192 { 00193 // %D = sort(opt.kernel.distance); 00194 // %opt.sigmamax = median(D(n,:)); 00195 T mAx = *(std::max_element(dist->getData(),dist->getData()+ dist->getSize())); 00196 00197 // opt.sigmamax = max(max(opt.kernel.distance)); 00198 nestedOpt->addOpt("sigmamax", new OptNumber( sqrt( mAx ))); 00199 } 00200 else 00201 { 00202 nestedOpt->addOpt("sigmamax", new OptNumber(opt.getOptAsNumber("sigmamax"))); 00203 } 00204 00205 T sigmamax = static_cast<T>(nestedOpt->getOptAsNumber("sigmamax")); 00206 00207 // if opt.sigmamin <= 0 00208 if( le(sigmamin, (T)0.0) ) 00209 { 00210 // opt.sigmamin = eps; 00211 nestedOpt->removeOpt("sigmamin"); 00212 nestedOpt->addOpt("sigmamin", new OptNumber(std::numeric_limits<T>::epsilon())); 00213 sigmamin = std::numeric_limits<T>::epsilon(); 00214 } 00215 00216 // if opt.sigmamin <= 0 00217 if( le(sigmamin, (T)0.0)) 00218 { 00219 // opt.sigmamax = eps; 00220 nestedOpt->removeOpt("sigmamax"); 00221 nestedOpt->addOpt("sigmamax", new OptNumber(std::numeric_limits<T>::epsilon())); 00222 sigmamax = std::numeric_limits<T>::epsilon(); 00223 } 00224 00225 unsigned long nlambda = static_cast<unsigned long>(opt.getOptAsNumber("nlambda")); 00226 unsigned long nsigma = static_cast<unsigned long>( opt.getOptAsNumber("nsigma")); 00227 00228 T q = pow( sigmamax/sigmamin, static_cast<T>(1.0/(nsigma-1.0))); 00229 00230 // LOOSQE = zeros(opt.nsigma,opt.nlambda,T); 00231 //T* LOOSQE = new T[nsigma*nlambda*t]; 00232 T* perf = new T[nlambda]; 00233 00234 // sigmas = zeros(1,opt.nsigma); 00235 // for i = 1:opt.nsigma 00236 00237 KernelRBF<T> rbfkernel; 00238 ParamSelLoocvDual<T> loocvdual; 00239 00240 T* work = new T[std::max(nlambda, t+1)]; 00241 T maxTmp = (T)-1.0; 00242 int m = -1; 00243 T guess = (T)-1.0; 00244 00245 for(unsigned long i=0; i<nsigma; ++i) 00246 { 00247 nestedOpt->addOpt("paramsel", paramsel); 00248 00249 paramsel->removeOpt("sigma"); 00250 paramsel->addOpt("sigma", new OptNumber( sigmamin * pow(q, (T)i))); 00251 00252 // opt.kernel = kernel_rbf(X,y,opt); 00253 GurlsOptionsList* retKernel = rbfkernel.execute(X, Y, *nestedOpt); 00254 00255 nestedOpt->removeOpt("kernel"); 00256 nestedOpt->addOpt("kernel", retKernel); 00257 00258 nestedOpt->removeOpt("paramsel", false); 00259 00260 // paramsel = paramsel_loocvdual(X,y,opt); 00261 GurlsOptionsList* ret_paramsel = loocvdual.execute(X, Y, *nestedOpt); 00262 00263 00264 gMat2D<T> &looe_mat = ret_paramsel->getOptValue<OptMatrix<gMat2D<T> > >("perf"); 00265 00266 // LOOSQE(i,:,:) = paramsel.looe{1}; 00267 // guesses(i,:) = paramsel.guesses; 00268 gMat2D<T> &guesses_mat = ret_paramsel->getOptValue<OptMatrix<gMat2D<T> > >("guesses"); 00269 00270 for(unsigned long j=0; j<nlambda; ++j) 00271 { 00272 perf[j] = 0; 00273 00274 T* end = looe_mat.getData()+looe_mat.getSize(); 00275 for(T* it = looe_mat.getData()+j; it< end ; it+=nlambda) 00276 perf[j] += *it; 00277 } 00278 00279 unsigned long mm = std::max_element(perf, perf + nlambda) - perf; 00280 00281 if( gt(perf[mm], maxTmp)) 00282 { 00283 maxTmp = perf[mm]; 00284 m = i; 00285 guess = guesses_mat.getData()[mm*guesses_mat.rows()]; 00286 } 00287 00288 delete ret_paramsel; 00289 } 00290 00291 delete [] work; 00292 delete [] perf; 00293 delete nestedOpt; 00294 00295 // M = sum(LOOSQE,3); % sum over classes 00296 // 00297 // [dummy,i] = max(M(:)); 00298 // [m,n] = ind2sub(size(M),i); 00299 // 00300 // % opt sigma 00301 // vout.sigma = opt.sigmamin*(q^m); 00302 00303 00304 paramsel->removeOpt("sigma"); 00305 paramsel->addOpt("sigma", new OptNumber( sigmamin * pow(q,m) )); 00306 00307 // % opt lambda 00308 // vout.lambdas = guesses(m,n)*ones(1,T); 00309 00310 gMat2D<T> *LAMBDA = new gMat2D<T>(1, t); 00311 set(LAMBDA->getData(), guess, t); 00312 00313 paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*LAMBDA)); 00314 00315 return paramsel; 00316 00317 } 00318 00319 00320 } 00321 00322 #endif // _GURLS_SIGLAM_H_