![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_LOOGPREGR_H_ 00044 #define _GURLS_LOOGPREGR_H_ 00045 00046 #include <cmath> 00047 00048 #include "gurls++/options.h" 00049 #include "gurls++/optlist.h" 00050 #include "gurls++/gmat2d.h" 00051 #include "gurls++/gvec.h" 00052 #include "gurls++/gmath.h" 00053 00054 #include "gurls++/paramsel.h" 00055 #include "gurls++/perf.h" 00056 00057 #include "gurls++/rlsgp.h" 00058 #include "gurls++/predgp.h" 00059 00060 namespace gurls { 00061 00067 template <typename T> 00068 class ParamSelLooGPRegr: public ParamSelection<T>{ 00069 00070 public: 00087 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00088 }; 00089 00090 template <typename T> 00091 GurlsOptionsList* ParamSelLooGPRegr<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) 00092 { 00093 // [n,T] = size(y); 00094 const unsigned long n = Y.rows(); 00095 const unsigned long t = Y.cols(); 00096 00097 const unsigned long d = X.cols(); 00098 00099 // tot = opt.nlambda; 00100 int tot = static_cast<int>(opt.getOptAsNumber("nlambda")); 00101 00102 // K = opt.kernel.K; 00103 const gMat2D<T> &K = opt.getOptValue<OptMatrix<gMat2D<T> > >("kernel.K"); 00104 00105 00106 // lmax = mean(std(y)); 00107 00108 // T* work = new T[t+n]; 00109 // T* stdY = new T[t]; 00110 00111 // stdDev(Y.getData(), n, t, stdY, work); 00112 00113 // const T lmax = sumv(stdY, t)/((T)t); 00114 00115 // delete[] work; 00116 // delete[] stdY; 00117 00119 // const T lmin = lmax * (T)1.0e-5; 00120 00121 T lmin; 00122 T lmax; 00123 00124 if(opt.hasOpt("lambdamin")) 00125 lmin = opt.getOptAsNumber("lambdamin"); 00126 else 00127 lmin = 0.001; 00128 00129 if(opt.hasOpt("lambdamax")) 00130 lmax = opt.getOptAsNumber("lambdamax"); 00131 else 00132 lmax = 10; 00133 00134 // guesses = lmin.*(lmax/lmin).^linspace(0,1,tot); 00135 gMat2D<T> *guesses_mat = new gMat2D<T>(tot, 1); 00136 T* guesses = guesses_mat->getData(); 00137 00138 T* linspc = new T[tot]; 00139 linspace((T)0.0, (T)1.0, tot, linspc); 00140 const T coeff = lmax/lmin; 00141 00142 for(int i=0; i< tot; ++i) 00143 guesses[i] = lmin* std::pow(coeff, linspc[i]); 00144 00145 delete[] linspc; 00146 00147 00148 // perf = zeros(tot,T); 00149 gMat2D<T> *perf_mat = new gMat2D<T>(tot, t); 00150 T* perf = perf_mat->getData(); 00151 set(perf, (T)0.0, tot*t); 00152 00153 const int tr_size = n-1; 00154 00155 unsigned long* tr = new unsigned long[tr_size+1]; // + 1 cell for convenience 00156 unsigned long* tr_it = tr; 00157 for(unsigned long i=1; i< n; ++i, ++tr_it) 00158 *tr_it = i; 00159 00160 GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested"); 00161 nestedOpt->copyOpt("singlelambda", opt); 00162 00163 00164 gMat2D<T>* tmpK = new gMat2D<T>(tr_size, tr_size); 00165 gMat2D<T>* tmpPredK = new gMat2D<T>(1, tr_size); 00166 gMat2D<T>* tmpPredKTest = new gMat2D<T>(1, 1); 00167 00168 GurlsOptionsList* tmpPredKernel = new GurlsOptionsList("predkernel"); 00169 GurlsOptionsList* tmpKernel = new GurlsOptionsList("kernel"); 00170 GurlsOptionsList* tmpParamSel = new GurlsOptionsList("paramsel"); 00171 00172 nestedOpt->addOpt("kernel", tmpKernel); 00173 nestedOpt->addOpt("predkernel", tmpPredKernel); 00174 nestedOpt->addOpt("paramsel", tmpParamSel); 00175 00176 tmpKernel->addOpt("K", new OptMatrix<gMat2D<T> > (*tmpK)); 00177 tmpPredKernel->addOpt("K", new OptMatrix<gMat2D<T> > (*tmpPredK)); 00178 tmpPredKernel->addOpt("Ktest", new OptMatrix<gMat2D<T> > (*tmpPredKTest)); 00179 00180 gMat2D<T> rlsX(tr_size, d); 00181 gMat2D<T> rlsY(tr_size, t); 00182 00183 // T* tmpMat = new T[ tr_size * std::max(d, t)]; 00184 00185 gMat2D<T> predX(1, d); 00186 gMat2D<T> predY(1, t); 00187 00188 RLSGPRegr<T> rlsgp; 00189 PredGPRegr<T> predgp; 00190 Performance<T>* perfClass = Performance<T>::factory(opt.getOptAsString("hoperf")); 00191 00192 gMat2D<T> *lambda = new gMat2D<T>(1,1); 00193 tmpParamSel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*lambda)); 00194 00195 // for k = 1:n; 00196 for(unsigned long k = 0; k<n; ++k) 00197 { 00198 // tr = setdiff(1:n,k); 00199 00200 // opt.kernel.K = K(tr,tr); 00201 copy_submatrix(tmpK->getData(), K.getData(), K.rows(), tr_size, tr_size, tr, tr); 00202 00203 // opt.predkernel.K = K(k,tr); 00204 copy_submatrix(tmpPredK->getData(), K.getData(), K.rows(), 1, tr_size, &k , tr); 00205 00206 // opt.predkernel.Ktest = K(k,k); 00207 tmpPredKTest->getData()[0] = K.getData()[(k*K.rows()) + k]; 00208 00209 // for i = 1:tot 00210 for(int i=0; i< tot; ++i) 00211 { 00212 // opt.paramsel.noises = guesses(i); 00213 lambda->getData()[0] = guesses[i]; 00214 00215 // opt.rls = rls_gpregr(X(tr,:),y(tr,:),opt); 00216 subMatrixFromRows(X.getData(), n, d, tr, tr_size, rlsX.getData()); 00217 00218 subMatrixFromRows(Y.getData(), n, t, tr, tr_size, rlsY.getData()); 00219 00220 GurlsOptionsList* ret_rlsgp = rlsgp.execute(rlsX, rlsY, *nestedOpt); 00221 00222 nestedOpt->removeOpt("optimizer"); 00223 nestedOpt->addOpt("optimizer", ret_rlsgp); 00224 00225 // tmp = pred_gpregr(X(k,:),y(k,:),opt); 00226 getRow(X.getData(), n, d, k, predX.getData()); 00227 getRow(Y.getData(), n, t, k, predY.getData()); 00228 00229 GurlsOptionsList * pred_list = predgp.execute(predX, predY, *nestedOpt); 00230 00231 // opt.pred = tmp.means; 00232 nestedOpt->removeOpt("pred"); 00233 nestedOpt->addOpt("pred", pred_list->getOpt("means")); 00234 00235 pred_list->removeOpt("means", false); 00236 00237 delete pred_list; 00238 00239 // opt.perf = opt.hoperf([],y(k,:),opt); 00240 GurlsOptionsList * perf_list = perfClass->execute(predX, predY, *nestedOpt); 00241 00242 gMat2D<T>& forho = perf_list->getOptValue<OptMatrix<gMat2D<T> > >("forho"); 00243 00244 // for t = 1:T 00245 for(unsigned long j = 0; j<t; ++j) 00246 // perf(i,t) = opt.perf.forho(t)+perf(i,t)./n; 00247 // perf(i,t) = opt.perf.forho(t)./n+perf(i,t); 00248 perf[i+(tot*j)] += forho.getData()[j]/n; 00249 //perf[i+(tot*j)] = forho.getData()[j]/n+perf[i+(tot*j)]; 00250 00251 00252 delete perf_list; 00253 } 00254 00255 tr[k] = k; 00256 00257 } 00258 00259 delete perfClass; 00260 00261 delete nestedOpt; 00262 00263 00264 GurlsOptionsList* paramsel; 00265 00266 if(opt.hasOpt("paramsel")) 00267 { 00268 GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp"); 00269 tmp_opt->copyOpt("paramsel", opt); 00270 00271 paramsel = tmp_opt->getOptAs<GurlsOptionsList>("paramsel"); 00272 tmp_opt->removeOpt("paramsel", false); 00273 delete tmp_opt; 00274 00275 paramsel->removeOpt("lambdas"); 00276 paramsel->removeOpt("perf"); 00277 paramsel->removeOpt("guesses"); 00278 } 00279 else 00280 paramsel = new GurlsOptionsList("paramsel"); 00281 00282 00283 // [dummy,idx] = max(perf,[],1); 00284 unsigned long* idx = new unsigned long[t]; 00285 T* work = NULL; 00286 indicesOfMax(perf, tot, t, idx, work, 1); 00287 00288 00289 // vout.noises = guesses(idx); 00290 gMat2D<T> *lambdas = new gMat2D<T>(1, t); 00291 copyLocations(idx, guesses, t, tot, lambdas->getData()); 00292 00293 delete[] idx; 00294 00295 paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*lambdas)); 00296 00297 // vout.perf = perf; 00298 paramsel->addOpt("perf", new OptMatrix<gMat2D<T> >(*perf_mat)); 00299 00300 // vout.guesses = guesses; 00301 paramsel->addOpt("guesses", new OptMatrix<gMat2D<T> >(*guesses_mat)); 00302 00303 return paramsel; 00304 } 00305 00306 00307 } 00308 00309 #endif // _GURLS_LOOGPREGR_H_