![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_SIGLAMLOOGPREGR_H_ 00044 #define _GURLS_SIGLAMLOOGPREGR_H_ 00045 00046 #include <cmath> 00047 00048 #include "gurls++/options.h" 00049 #include "gurls++/optlist.h" 00050 #include "gurls++/gmat2d.h" 00051 #include "gurls++/gvec.h" 00052 #include "gurls++/gmath.h" 00053 00054 #include "gurls++/paramsel.h" 00055 #include "gurls++/loogpregr.h" 00056 #include "gurls++/rbfkernel.h" 00057 00058 namespace gurls { 00059 00065 template <typename T> 00066 class ParamSelSiglamLooGPRegr: public ParamSelection<T>{ 00067 00068 public: 00087 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00088 }; 00089 00090 template <typename T> 00091 GurlsOptionsList *ParamSelSiglamLooGPRegr<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) 00092 { 00093 // [n,T] = size(y); 00094 const unsigned long n = Y.rows(); 00095 const unsigned long t = Y.cols(); 00096 00097 const unsigned long d = X.cols(); 00098 00099 GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested"); 00100 nestedOpt->copyOpt("nlambda", opt); 00101 nestedOpt->copyOpt("hoperf", opt); 00102 nestedOpt->copyOpt("singlelambda", opt); 00103 00104 // if ~isfield(opt,'kernel') 00105 if(!opt.hasOpt("kernel")) 00106 { 00107 // opt.kernel.type = 'rbf'; 00108 GurlsOptionsList* kernel = new GurlsOptionsList("kernel"); 00109 kernel->addOpt("type", "rbf"); 00110 00111 nestedOpt->addOpt("kernel", kernel); 00112 } 00113 else 00114 nestedOpt->copyOpt("kernel", opt); 00115 00116 00117 GurlsOptionsList* kernel = nestedOpt->getOptAs<GurlsOptionsList>("kernel"); 00118 00119 00120 gMat2D<T> *distance; 00121 00122 // if ~isfield(opt.kernel,'distance') 00123 if(!kernel->hasOpt("distance")) 00124 { 00125 distance = new gMat2D<T>(n, n); 00126 00127 // opt.kernel.distance = square_distance(X',X'); 00128 distance_transposed(X.getData(), X.getData(), d, n, n, distance->getData()); 00129 00130 kernel->addOpt("distance", new OptMatrix<gMat2D<T> >(*distance)); 00131 } 00132 else 00133 distance = &(kernel->getOptValue<OptMatrix<gMat2D<T> > >("distance")); 00134 00135 00136 // if ~isfield(opt,'sigmamin') 00137 if(!opt.hasOpt("sigmamin")) 00138 { 00139 int d_len = n*(n-1)/2; 00140 T* distLinearized = new T[d_len]; 00141 00142 // D = sort(opt.kernel.distance(tril(true(n),-1))); 00143 00144 const unsigned long size = distance->cols(); 00145 T* it = distLinearized; 00146 T* d_it = distance->getData(); 00147 00148 for(unsigned long i=1; i< size; ++i) 00149 { 00150 gurls::copy(it , d_it+i, size - i); 00151 00152 it += size - i; 00153 d_it += size; 00154 } 00155 00156 std::sort(distLinearized, distLinearized + d_len); 00157 00158 // firstPercentile = round(0.01*numel(D)+0.5); 00159 int firstPercentile = gurls::round((T)0.01 * d_len + (T)0.5)-1; 00160 00161 // opt.sigmamin = D(firstPercentile); 00162 nestedOpt->addOpt("sigmamin", new OptNumber(sqrt( distLinearized[firstPercentile]) )); 00163 00164 delete [] distLinearized; 00165 } 00166 else 00167 nestedOpt->copyOpt("sigmamin", opt); 00168 00169 00170 // if ~isfield(opt,'sigmamax') 00171 if(!opt.hasOpt("sigmamax")) 00172 { 00173 // opt.sigmamax = sqrt(max(max(opt.kernel.distance))); 00174 double sigmaMax = sqrt(*std::max_element(distance->getData(), distance->getData()+distance->getSize())); 00175 nestedOpt->addOpt("sigmamax", new OptNumber(sigmaMax)); 00176 } 00177 else 00178 nestedOpt->copyOpt("sigmamax", opt); 00179 00180 00181 // if opt.sigmamin <= 0 00182 if( le(nestedOpt->getOptAsNumber("sigmamin"), 0.0) ) 00183 { 00184 // opt.sigmamin = eps; 00185 nestedOpt->removeOpt("sigmamin"); 00186 nestedOpt->addOpt("sigmamin", new OptNumber(std::numeric_limits<T>::epsilon())); 00187 } 00188 // if opt.sigmamax <= 0 00189 if( le(nestedOpt->getOptAsNumber("sigmamax"), 0.0) ) 00190 { 00191 // opt.sigmamax = eps; 00192 nestedOpt->removeOpt("sigmamax"); 00193 nestedOpt->addOpt("sigmamax", new OptNumber(std::numeric_limits<T>::epsilon())); 00194 } 00195 00196 const int nsigma = static_cast<int>(opt.getOptAsNumber("nsigma")); 00197 const int nlambda = static_cast<int>(opt.getOptAsNumber("nlambda")); 00198 const T sigmamin = static_cast<T>(nestedOpt->getOptAsNumber("sigmamin")); 00199 00200 // q = (opt.sigmamax/opt.sigmamin)^(1/(opt.nsigma-1)); 00201 T q = static_cast<T>(std::pow(nestedOpt->getOptAsNumber("sigmamax")/nestedOpt->getOptAsNumber("sigmamin"), 1.0/(nsigma-1.0))); 00202 00203 // perf = zeros(opt.nsigma,opt.nlambda,T); 00204 T* perf = new T[nsigma*nlambda]; 00205 set(perf, (T)0.0, nsigma*nlambda); 00206 00207 // sigmas = zeros(1,opt.nsigma); 00208 // T* sigmas = new T[nsigma]; 00209 00210 T* guesses = new T[nsigma*nlambda]; 00211 00212 KernelRBF<T> rbf; 00213 ParamSelLooGPRegr<T> loogp; 00214 00215 T* work = new T[t]; 00216 00217 GurlsOptionsList* paramsel_rbf = new GurlsOptionsList("paramsel"); 00218 nestedOpt->addOpt("paramsel", paramsel_rbf); 00219 00220 // for i = 1:opt.nsigma 00221 for(int i=0; i<nsigma; ++i) 00222 { 00223 // sigmas(i) = (opt.sigmamin*(q^(i-1))); 00224 const T sigma = sigmamin* std::pow(q, i); 00225 // sigmas[i] = sigma; 00226 00227 // opt.paramsel.sigma = sigmas(i); 00228 paramsel_rbf->removeOpt("sigma"); 00229 paramsel_rbf->addOpt("sigma", new OptNumber(sigma)); 00230 00231 00232 // opt.kernel = kernel_rbf(X,y,opt); 00233 GurlsOptionsList* rbf_kernel = rbf.execute(X, Y, *nestedOpt); 00234 00235 nestedOpt->removeOpt("kernel"); 00236 nestedOpt->addOpt("kernel", rbf_kernel); 00237 00238 // paramsel = paramsel_loogpregr(X,y,opt); 00239 GurlsOptionsList* paramsel_loogp = loogp.execute(X, Y, *nestedOpt); 00240 00241 // perf(i,:,:) = paramsel.perf; 00242 gMat2D<T> &perf_mat = paramsel_loogp->getOptValue<OptMatrix<gMat2D<T> > >("perf"); // nlambda x t 00243 00244 T* perf_it = perf + i; 00245 for(int j=0; j<nlambda; ++j, perf_it += nsigma) 00246 { 00247 getRow(perf_mat.getData(), perf_mat.rows(), perf_mat.cols(), j, work); 00248 00249 *perf_it = sumv(work, t); 00250 } 00251 00252 // guesses(i,:) = paramsel.guesses; 00253 gMat2D<T> &guesses_mat = paramsel_loogp->getOptValue<OptMatrix<gMat2D<T> > >("guesses"); // nlambda x 1 00254 copy(guesses+i, guesses_mat.getData(), nlambda, nsigma, 1); 00255 00256 delete paramsel_loogp; 00257 } 00258 00259 delete [] work; 00260 delete nestedOpt; 00261 00262 // M = sum(perf,3); % sum over classes 00263 // [dummy,i] = max(M(:)); 00264 int i = std::max_element(perf, perf +(nsigma*nlambda)) - perf; 00265 00266 // [m,n] = ind2sub(size(M),i); 00267 int im = i%nsigma; 00268 00269 delete [] perf; 00270 00271 GurlsOptionsList* paramsel = new GurlsOptionsList("paramsel"); 00272 00273 // % opt sigma 00274 // vout.sigma = opt.sigmamin*(q^(m-1)); 00275 paramsel->addOpt("sigma", new OptNumber( sigmamin*(std::pow(q, im))) ); 00276 00277 // % opt lambda 00278 // vout.noises = guesses(m,n)*ones(1,T); 00279 00280 gMat2D<T> *lambdas = new gMat2D<T>(1, t); 00281 set(lambdas->getData(), guesses[i], t); 00282 00283 paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*lambdas)); 00284 00285 delete [] guesses; 00286 00287 return paramsel; 00288 } 00289 00290 00291 } 00292 00293 #endif // _GURLS_SIGLAMLOOGPREGR_H_