GURLS++ 2.0.00
C++ Implementation of GURLS Matlab Toolbox
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_HOPRIMAL_H_ 00044 #define _GURLS_HOPRIMAL_H_ 00045 00046 #include <cstdio> 00047 #include <cstring> 00048 #include <iostream> 00049 #include <cmath> 00050 #include <algorithm> 00051 #include <set> 00052 00053 #include "gurls++/options.h" 00054 #include "gurls++/optlist.h" 00055 #include "gurls++/gmat2d.h" 00056 #include "gurls++/gvec.h" 00057 #include "gurls++/gmath.h" 00058 00059 #include "gurls++/paramsel.h" 00060 #include "gurls++/perf.h" 00061 #include "gurls++/dual.h" 00062 #include "gurls++/utils.h" 00063 00064 namespace gurls { 00065 00071 template <typename T> 00072 class ParamSelHoPrimal: public ParamSelection<T>{ 00073 00074 public: 00092 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00093 00094 protected: 00098 unsigned long eig_function(T* A, T* L, int A_rows_cols, unsigned long d, const GurlsOptionsList &opt, unsigned long last); 00099 }; 00100 00101 00107 template <typename T> 00108 class ParamSelHoPrimalr: public ParamSelHoPrimal<T>{ 00109 00110 protected: 00114 unsigned long eig_function(T* A, T* L, int A_rows_cols, unsigned long d, const GurlsOptionsList &opt, unsigned long last); 00115 }; 00116 00117 00118 00119 00120 00121 template<typename T> 00122 unsigned long ParamSelHoPrimal<T>::eig_function(T* A, T* L, int A_rows_cols,unsigned long d, const GurlsOptionsList &, unsigned long last) 
00123 { 00124 eig_sm(A, L, A_rows_cols); 00125 00126 return std::min(d,last); 00127 } 00128 00129 template<typename T> 00130 unsigned long ParamSelHoPrimalr<T>::eig_function(T* A, T* L, int A_rows_cols, unsigned long d, const GurlsOptionsList &opt, unsigned long ) 00131 { 00132 T* V = NULL; 00133 T k = gurls::round((opt.getOptAsNumber("eig_percentage")*d)/100.0); 00134 random_svd(A, A_rows_cols, A_rows_cols, A, L, V, A_rows_cols, k); 00135 00136 return static_cast<unsigned long>(k); 00137 } 00138 00139 template <typename T> 00140 GurlsOptionsList *ParamSelHoPrimal<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt) 00141 { 00142 // [n,T] = size(y); 00143 const unsigned long t = Y.cols(); 00144 const unsigned long d = X.cols(); 00145 00146 00147 GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested"); 00148 00149 00150 const GurlsOptionsList* split = opt.getOptAs<GurlsOptionsList>("split"); 00151 const gMat2D< unsigned long > &indices_mat = split->getOptValue<OptMatrix<gMat2D< unsigned long > > >("indices"); 00152 const gMat2D< unsigned long > &lasts_mat = split->getOptValue<OptMatrix<gMat2D< unsigned long > > >("lasts"); 00153 00154 const unsigned long n = indices_mat.cols(); 00155 00156 const unsigned long *lasts = lasts_mat.getData(); 00157 const unsigned long* indices_buffer = indices_mat.getData(); 00158 00159 00160 int tot = static_cast<int>(std::ceil( opt.getOptAsNumber("nlambda"))); 00161 00162 int nholdouts = static_cast<int>(std::ceil( opt.getOptAsNumber("nholdouts"))); 00163 00164 gMat2D<T> *LAMBDA = new gMat2D<T>(1, t); 00165 T* lambdas = LAMBDA->getData(); 00166 set(lambdas, (T)0.0, t); 00167 00168 00169 T* Q = new T[d*d]; 00170 T* QtXty = new T[d*t]; 00171 T *L = new T[d]; 00172 00173 gMat2D<T>* perf_mat = new gMat2D<T>(nholdouts, tot*t); 00174 T* perf = perf_mat->getData(); 00175 00176 gMat2D<T>* guesses_mat = new gMat2D<T>(nholdouts, tot); 00177 T* ret_guesses = guesses_mat->getData(); 00178 00179 gMat2D<T>* 
lambdas_round_mat = new gMat2D<T>(nholdouts, t); 00180 T* lambdas_round = lambdas_round_mat->getData(); 00181 00182 PredPrimal< T > primal; 00183 Performance<T>* perfClass = Performance<T>::factory(opt.getOptAsString("hoperf")); 00184 00185 GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer"); 00186 nestedOpt->addOpt("optimizer",optimizer); 00187 00188 gMat2D<T> *W = new gMat2D<T>(d, t); 00189 optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W)); 00190 00191 00192 bool hasXt = opt.hasOpt("kernel.XtX") && opt.hasOpt("kernel.Xty"); 00193 00194 for(int nh=0; nh<nholdouts; ++nh) 00195 { 00196 unsigned long last = lasts[nh]; 00197 unsigned long* tr = new unsigned long[last]; 00198 unsigned long* va = new unsigned long[n-last]; 00199 00200 //copy int tr indices_ from n*nh to last 00201 copy< unsigned long >(tr,indices_buffer + n*nh,last,1,1); 00202 00203 //copy int va indices_ from n*nh+last to n*nh+n 00204 copy< unsigned long >(va,(indices_buffer+ n*nh+last), n-last,1,1); 00205 00206 00207 gMat2D<T> Xva(n-last, d); 00208 gMat2D<T> yva(n-last, t); 00209 00210 subMatrixFromRows(X.getData(), X.rows(), d, va, n-last, Xva.getData()); 00211 subMatrixFromRows(Y.getData(), Y.rows(), t, va, n-last, yva.getData()); 00212 00213 T* Xtr = NULL; 00214 if(hasXt) 00215 { 00216 T* XvatXva = new T[d*d]; 00217 dot(Xva.getData(), Xva.getData(), XvatXva, n-last, d, n-last, d, d, d, CblasTrans, CblasNoTrans, CblasColMajor); 00218 00219 const gMat2D<T>&XtX = opt.getOptValue<OptMatrix<gMat2D<T> > >("kernel.XtX"); 00220 00221 // Q = XtX - XvatXva 00222 copy(Q, XtX.getData(), XtX.getSize()); 00223 axpy(d*d, (T)-1.0, XvatXva, 1, Q, 1); 00224 00225 delete [] XvatXva; 00226 } 00227 else 00228 { 00229 // K = X(tr,:)'*X(tr,:); 00230 Xtr = new T[last*d]; 00231 subMatrixFromRows(X.getData(), n, d, tr, last, Xtr); 00232 00233 dot(Xtr, Xtr, Q, last, d, last, d, d, d, CblasTrans, CblasNoTrans, CblasColMajor); 00234 } 00235 00236 unsigned long k = eig_function(Q, L, d, d, opt, last); 00237 
00238 T* guesses = lambdaguesses(L, d, k, last, tot, (T)(opt.getOptAsNumber("smallnumber"))); 00239 00240 T* ap = new T[tot*t]; 00241 00242 00243 if(hasXt) 00244 { 00245 T* Xvatyva = new T[d*t]; 00246 dot(Xva.getData(), yva.getData(), Xvatyva, n-last, d, n-last, t, d, t, CblasTrans, CblasNoTrans, CblasColMajor); 00247 00248 const gMat2D<T>&Xty = opt.getOptValue<OptMatrix<gMat2D<T> > >("kernel.Xty"); 00249 00250 00251 // QtXty = Q'*(Xty - XvatXva) 00252 00253 T* Xtrtytr = new T[d*t]; 00254 00255 copy(Xtrtytr, Xty.getData(), Xty.getSize()); 00256 axpy(d*t, (T)-1.0, Xvatyva, 1, Xtrtytr, 1); 00257 00258 dot(Q, Xtrtytr, QtXty, d, d, d, t, d, t, CblasTrans, CblasNoTrans, CblasColMajor); 00259 00260 delete [] Xvatyva; 00261 delete [] Xtrtytr; 00262 } 00263 else 00264 { 00265 T* ytr = new T[last*t]; 00266 subMatrixFromRows(Y.getData(), n, t, tr, last, ytr); 00267 00268 00269 T* Xtrtytr = new T[d*t]; 00270 dot(Xtr, ytr, Xtrtytr, last, d, last, t, d, t, CblasTrans, CblasNoTrans, CblasColMajor); 00271 delete [] Xtr; 00272 00273 dot(Q, Xtrtytr, QtXty, d, d, d, t, d, t, CblasTrans, CblasNoTrans, CblasColMajor); 00274 00275 delete [] ytr; 00276 delete [] Xtrtytr; 00277 } 00278 00279 00280 T* work = new T[d*(d+1)]; 00281 00282 for(int i=0; i<tot; ++i) 00283 { 00284 rls_eigen(Q, L, QtXty, W->getData(), guesses[i], last, d, d, d, d, t, work); 00285 00286 OptMatrix<gMat2D<T> > *ret_pred = primal.execute(Xva, yva, *nestedOpt); 00287 00288 nestedOpt->removeOpt("pred"); 00289 nestedOpt->addOpt("pred", ret_pred); 00290 00291 GurlsOptionsList* ret_perf = perfClass->execute(Xva, yva, *nestedOpt); 00292 00293 gMat2D<T> &forho_vec = ret_perf->getOptValue<OptMatrix<gMat2D<T> > >("forho"); 00294 00295 copy(ap+i, forho_vec.getData(), t, tot, 1); 00296 00297 delete ret_perf; 00298 } 00299 00300 delete [] va; 00301 delete [] tr; 00302 delete [] work; 00303 00304 //[dummy,idx] = max(ap,[],1); 00305 work = NULL; 00306 unsigned long* idx = new unsigned long[t]; 00307 indicesOfMax(ap, tot, t, idx, 
work, 1); 00308 00309 //vout.lambdas_round{nh} = guesses(idx); 00310 T* lambdas_nh = new T[t]; 00311 copyLocations(idx, guesses, t, tot, lambdas_nh); 00312 00313 copy(lambdas_round+nh, lambdas_nh, t, nholdouts, 1); 00314 00315 //add lambdas_nh to lambdas 00316 axpy(t, (T)1, lambdas_nh, 1, lambdas, 1); 00317 00318 delete [] lambdas_nh; 00319 delete [] idx; 00320 00321 // vout.perf{nh} = ap; 00322 copy(perf + nh, ap, tot*t, nholdouts, 1); 00323 00324 // vout.guesses{nh} = guesses; 00325 copy(ret_guesses + nh, guesses, tot, nholdouts, 1); 00326 00327 delete [] guesses; 00328 delete [] ap; 00329 00330 } 00331 00332 delete nestedOpt; 00333 00334 delete perfClass; 00335 delete [] Q; 00336 delete [] QtXty; 00337 delete [] L; 00338 00339 GurlsOptionsList* paramsel; 00340 00341 if(opt.hasOpt("paramsel")) 00342 { 00343 GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp"); 00344 tmp_opt->copyOpt("paramsel", opt); 00345 00346 paramsel = GurlsOptionsList::dynacast(tmp_opt->getOpt("paramsel")); 00347 tmp_opt->removeOpt("paramsel", false); 00348 delete tmp_opt; 00349 00350 paramsel->removeOpt("guesses"); 00351 paramsel->removeOpt("lambdas"); 00352 paramsel->removeOpt("perf"); 00353 paramsel->removeOpt("lambdas_round"); 00354 } 00355 else 00356 paramsel = new GurlsOptionsList("paramsel"); 00357 00358 00359 paramsel->addOpt("guesses", new OptMatrix<gMat2D<T> >(*guesses_mat)); 00360 00361 if(nholdouts>1) 00362 scal(t, (T)1.0/nholdouts, lambdas, 1); 00363 00364 paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*LAMBDA)); 00365 paramsel->addOpt("perf", new OptMatrix<gMat2D<T> >(*perf_mat)); 00366 paramsel->addOpt("lambdas_round", new OptMatrix<gMat2D<T> >(*lambdas_round_mat)); 00367 00368 return paramsel; 00369 } 00370 00371 } 00372 00373 #endif // _GURLS_HOPRIMAL_H_