![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_SIGLAMHO_H_ 00044 #define _GURLS_SIGLAMHO_H_ 00045 00046 #include <cstdio> 00047 #include <cstring> 00048 #include <iostream> 00049 #include <cmath> 00050 #include <algorithm> 00051 #include <set> 00052 00053 #include "gurls++/options.h" 00054 #include "gurls++/optlist.h" 00055 #include "gurls++/gmat2d.h" 00056 #include "gurls++/gvec.h" 00057 #include "gurls++/gmath.h" 00058 00059 #include "gurls++/paramsel.h" 00060 #include "gurls++/perf.h" 00061 #include "gurls++/rbfkernel.h" 00062 #include "gurls++/hodual.h" 00063 00064 namespace gurls { 00065 00071 template <typename T> 00072 class ParamSelSiglamHo: public ParamSelection<T>{ 00073 00074 public: 00075 00095 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00096 }; 00097 00098 template <typename T> 00099 GurlsOptionsList *ParamSelSiglamHo<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt) 00100 { 00101 // [n,T] = size(y); 00102 const unsigned long t = Y.cols(); 00103 00104 00105 GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested"); 00106 nestedOpt->copyOpt("nlambda", opt); 00107 nestedOpt->copyOpt("nholdouts", opt); 00108 nestedOpt->copyOpt("hoperf", opt); 00109 nestedOpt->copyOpt("smallnumber", opt); 00110 nestedOpt->copyOpt("split", opt); 00111 00112 GurlsOptionsList* kernel = new GurlsOptionsList("kernel"); 00113 kernel->addOpt("type", "rbf"); 00114 nestedOpt->addOpt("kernel", kernel); 00115 00116 00117 GurlsOptionsList* paramsel; 00118 00119 if(opt.hasOpt("paramsel")) 00120 { 00121 GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp"); 00122 tmp_opt->copyOpt("paramsel", opt); 00123 00124 paramsel = GurlsOptionsList::dynacast(tmp_opt->getOpt("paramsel")); 00125 tmp_opt->removeOpt("paramsel", false); 00126 delete tmp_opt; 00127 00128 paramsel->removeOpt("lambdas"); 00129 paramsel->removeOpt("sigma"); 00130 } 00131 else 00132 paramsel = new GurlsOptionsList("paramsel"); 00133 00134 00135 gMat2D<T>* dist = NULL; 00136 00137 // if ~isfield(opt.kernel,'distance') 00138 if(!kernel->hasOpt("distance")) 00139 // opt.kernel.distance = squareform(pdist(X)); 00140 { 00141 dist = new gMat2D<T>(X.rows(), X.rows()); 00142 00143 squareform<T>(X.getData(), X.rows(), X.cols(), dist->getData(), X.rows()); 00144 00145 T *distSquared = new T[X.rows()*X.rows()]; 00146 copy(distSquared , dist->getData(), X.rows()*X.rows()); 00147 00148 mult<T>(distSquared, distSquared, dist->getData(), X.rows()*X.rows()); 00149 00150 kernel->addOpt("distance", new OptMatrix<gMat2D<T> >(*dist)); 00151 delete [] distSquared; 00152 } 00153 else 00154 { 00155 GurlsOption *dist_opt = kernel->getOpt("distance"); 00156 dist = &(OptMatrix<gMat2D<T> >::dynacast(dist_opt))->getValue(); 00157 } 00158 00159 00160 // if ~isfield(opt,'sigmamin') 00161 if(!opt.hasOpt("sigmamin")) 00162 { 00163 // %D = sort(opt.kernel.distance); 00164 // %opt.sigmamin = median(D(2,:)); 00165 // D = sort(squareform(opt.kernel.distance)); 00166 int d_len = X.rows()*(X.rows()-1)/2; 00167 T* distY = new T[d_len]; 00168 // squareform<T>(dist->getData(), dist->rows(), dist->cols(), distY, 1); 00169 00170 const int size = dist->cols(); 00171 T* it = distY; 00172 00173 for(int i=1; i< size; it+=i, ++i) 00174 copy(it , dist->getData()+(i*size), i); 00175 00176 std::sort(distY, distY + d_len); 00177 // firstPercentile = round(0.01*numel(D)+0.5); 00178 00179 int firstPercentile = gurls::round((T)0.01 * d_len + (T)0.5)-1; 00180 00181 // opt.sigmamin = D(firstPercentile); 00182 nestedOpt->addOpt("sigmamin", new OptNumber(sqrt( distY[firstPercentile]) )); 00183 00184 delete [] distY; 00185 } 00186 else 00187 { 00188 nestedOpt->addOpt("sigmamin", new OptNumber(opt.getOptAsNumber("sigmamin"))); 00189 } 00190 00191 T sigmamin = static_cast<T>(nestedOpt->getOptAsNumber("sigmamin")); 00192 00193 // if ~isfield(opt,'sigmamax') 00194 if(!opt.hasOpt("sigmamax")) 00195 { 00196 // %D = sort(opt.kernel.distance); 00197 // %opt.sigmamax = median(D(n,:)); 00198 T mAx = *(std::max_element(dist->getData(),dist->getData()+ dist->getSize())); 00199 00200 // opt.sigmamax = max(max(opt.kernel.distance)); 00201 nestedOpt->addOpt("sigmamax", new OptNumber( sqrt( mAx ))); 00202 } 00203 else 00204 { 00205 nestedOpt->addOpt("sigmamax", new OptNumber(opt.getOptAsNumber("sigmamax"))); 00206 } 00207 00208 T sigmamax = static_cast<T>(nestedOpt->getOptAsNumber("sigmamax")); 00209 00210 // if opt.sigmamin <= 0 00211 if( le(sigmamin, (T)0.0) ) 00212 { 00213 // opt.sigmamin = eps; 00214 nestedOpt->removeOpt("sigmamin"); 00215 nestedOpt->addOpt("sigmamin", new OptNumber(std::numeric_limits<T>::epsilon())); 00216 sigmamin = std::numeric_limits<T>::epsilon(); 00217 } 00218 00219 // if opt.sigmamin <= 0 00220 if( le(sigmamin, (T)0.0) ) 00221 { 00222 // opt.sigmamax = eps; 00223 nestedOpt->removeOpt("sigmamax"); 00224 nestedOpt->addOpt("sigmamax", new OptNumber(std::numeric_limits<T>::epsilon())); 00225 sigmamax = std::numeric_limits<T>::epsilon(); 00226 } 00227 00228 00229 int nlambda = static_cast<int>(opt.getOptAsNumber("nlambda")); 00230 int nsigma = static_cast<int>(opt.getOptAsNumber("nsigma")); 00231 00232 T q = pow( sigmamax/sigmamin, static_cast<T>(1.0/(nsigma-1.0))); 00233 00234 // PERF = zeros(opt.nsigma,opt.nlambda,T); 00235 T* perf = new T[nlambda]; 00236 00237 // sigmas = zeros(1,opt.nsigma); 00238 00239 KernelRBF<T> rbfkernel; 00240 ParamSelHoDual<T> hodual; 00241 00242 00243 T maxPerf = (T)-1.0; 00244 int m = -1; 00245 T guess = (T)-1.0; 00246 00247 T* perf_median = new T[nlambda*t]; 00248 T* guesses_median = new T[nlambda]; 00249 T* row = new T[t]; 00250 00251 const unsigned long nholdouts = static_cast<unsigned long>(opt.getOptAsNumber("nholdouts")); 00252 T* work = new T[nholdouts]; 00253 00254 // for i = 1:opt.nsigma 00255 for(int i=0; i<nsigma; ++i) 00256 { 00257 nestedOpt->addOpt("paramsel", paramsel); 00258 00259 paramsel->removeOpt("sigma"); 00260 paramsel->addOpt("sigma", new OptNumber( sigmamin * pow(q, i))); 00261 00262 // opt.kernel = kernel_rbf(X,y,opt); 00263 GurlsOptionsList* retKernel = rbfkernel.execute(X, Y, *nestedOpt); 00264 00265 nestedOpt->removeOpt("kernel"); 00266 nestedOpt->addOpt("kernel", retKernel); 00267 00268 nestedOpt->removeOpt("paramsel", false); 00269 00270 // paramsel = paramsel_hodual(X,y,opt); 00271 GurlsOptionsList* ret_paramsel = hodual.execute(X, Y, *nestedOpt); 00272 00273 00274 // PERF(i,:,:) = reshape(median(reshape(cell2mat(paramsel.perf')',opt.nlambda*T,nh),2),T,opt.nlambda)'; 00275 gMat2D<T> &perf_mat = ret_paramsel->getOptValue<OptMatrix<gMat2D<T> > >("perf"); // nholdouts x nlambda*t 00276 median(perf_mat.getData(), perf_mat.rows(), perf_mat.cols(), 1, perf_median, work); 00277 00278 for(int j=0;j<nlambda;++j) 00279 { 00280 getRow(perf_median, nlambda, t, j, row); 00281 perf[j] = sumv(row, t); 00282 } 00283 00284 // guesses(i,:) = median(cell2mat(paramsel.guesses'),1); 00285 unsigned long mm = std::max_element(perf, perf + nlambda) - perf; 00286 00287 if( gt(perf[mm], maxPerf) ) 00288 { 00289 maxPerf = perf[mm]; 00290 m = i; 00291 00292 gMat2D<T> &guesses_mat = ret_paramsel->getOptValue<OptMatrix<gMat2D<T> > >("guesses"); // nholdouts x nlambda 00293 median(guesses_mat.getData(), guesses_mat.rows(), guesses_mat.cols(), 1, guesses_median, work); 00294 00295 guess = guesses_median[mm]; 00296 } 00297 00298 delete ret_paramsel; 00299 } 00300 00301 delete [] row; 00302 delete [] work; 00303 delete [] perf; 00304 delete [] perf_median; 00305 delete [] guesses_median; 00306 delete nestedOpt; 00307 00308 00309 paramsel->removeOpt("sigma"); 00310 paramsel->addOpt("sigma", new OptNumber( sigmamin * pow(q,m) )); 00311 00312 // % opt lambda 00313 // vout.lambdas = guesses(m,n)*ones(1,T); 00314 00315 gMat2D<T> *LAMBDA = new gMat2D<T>(1, t); 00316 set(LAMBDA->getData(), guess, t); 00317 00318 paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*LAMBDA)); 00319 00320 return paramsel; 00321 00322 } 00323 00324 00325 } 00326 00327 #endif // _GURLS_SIGLAMHO_H_