![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-2013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 00044 #ifndef _GURLS_LOOCVPRIMAL_H_ 00045 #define _GURLS_LOOCVPRIMAL_H_ 00046 00047 #include <cstdio> 00048 #include <cstring> 00049 #include <iostream> 00050 #include <cmath> 00051 #include <algorithm> 00052 #include <set> 00053 00054 #include "gurls++/options.h" 00055 #include "gurls++/optlist.h" 00056 #include "gurls++/optmatrix.h" 00057 00058 #include "gurls++/gmat2d.h" 00059 #include "gurls++/gvec.h" 00060 #include "gurls++/gmath.h" 00061 00062 #include "gurls++/paramsel.h" 00063 00064 #include "gurls++/precisionrecall.h" 00065 #include "gurls++/macroavg.h" 00066 #include "gurls++/rmse.h" 00067 00068 namespace gurls { 00069 00075 template <typename T> 00076 class ParamSelLoocvPrimal: public ParamSelection<T> 00077 { 00078 00079 public: 00093 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00094 }; 00095 00096 template <typename T> 00097 GurlsOptionsList *ParamSelLoocvPrimal<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt) 00098 { 00099 typename std::set<T*> garbage; 00100 00101 try 00102 { 00103 //[n,T] = size(y); 00104 const unsigned long n = Y.rows(); 00105 const unsigned long t = Y.cols(); 00106 00107 00108 // K = X'*X; 00109 const unsigned long xc = X.cols(); 00110 const unsigned long xr = X.rows(); 00111 00112 if(xr != n) 00113 throw gException(Exception_Inconsistent_Size); 00114 00115 T* K = new T[xc*xc]; 00116 garbage.insert(K); 00117 dot(X.getData(), X.getData(), K, xr, xc, xr, xc, xc, xc, CblasTrans, CblasNoTrans, CblasColMajor); 00118 00119 00120 // tot = opt.nlambda; 00121 int tot = static_cast<int>(std::ceil( opt.getOptAsNumber("nlambda"))); 00122 00123 // [Q,L] = eig(K); 00124 // L = diag(L); 00125 00126 // T* Q, *L; 00127 00128 // eig(K, Q, L, xc, xc); 00129 // garbage.insert(Q); 00130 // garbage.insert(L); 00131 00132 T* Q = K; 00133 T* L = new T[xc]; 00134 garbage.insert(L); 00135 00136 eig_sm(Q, L, xc); 00137 00138 00139 T* filtered = L; 00140 T* guesses = lambdaguesses(filtered, xc, std::min(xc,xr), xr, tot, (T)(opt.getOptAsNumber("smallnumber"))); 00141 garbage.insert(guesses); 00142 // set(guesses, 0.f, tot); 00143 00144 T* LOOSQE = new T[tot*t]; 00145 garbage.insert(LOOSQE); 00146 set(LOOSQE, (T)0.0, tot*t); 00147 00148 // LEFT = X*Q; 00149 // RIGHT = Q'*X'*y; 00150 T* LEFT = new T[xr*xc]; 00151 garbage.insert(LEFT); 00152 dot(X.getData(), Q, LEFT, xr, xc, xc, xc, xr, xc, CblasNoTrans, CblasNoTrans, CblasColMajor); 00153 00154 T* tmp = new T[xc*t]; 00155 garbage.insert(tmp); 00156 dot(X.getData(), Y.getData(), tmp, xr, xc, n, t, xc, t, CblasTrans, CblasNoTrans, CblasColMajor); 00157 00158 T* RIGHT = new T[xc*t]; 00159 garbage.insert(RIGHT); 00160 dot(Q, tmp, RIGHT, xc, xc, xc, t, xc, t, CblasTrans, CblasNoTrans, CblasColMajor); 00161 00162 delete[] tmp; 00163 garbage.erase(tmp); 00164 00165 // right = Q'*X'; 00166 T* right = new T[xc*xr]; 00167 garbage.insert(right); 00168 dot(Q, X.getData(), right, xc, xc, xr, xc, xc, xr, CblasTrans, CblasTrans, CblasColMajor); 00169 00170 delete[] Q; 00171 garbage.erase(Q); 00172 00173 T* den = new T[n]; 00174 // T* Le = new T[n*t]; 00175 garbage.insert(den); 00176 // garbage.insert(Le); 00177 00178 T* tmpvec = new T[xc]; 00179 garbage.insert(tmpvec); 00180 T* tmp1 = new T[xr*t]; 00181 garbage.insert(tmp1); 00182 T* LL = new T[xc*xc]; 00183 garbage.insert(LL); 00184 tmp = new T[xc*t]; 00185 garbage.insert(tmp); 00186 T* num = new T[xr*t]; 00187 garbage.insert(num); 00188 T* row = new T[xc]; 00189 garbage.insert(row); 00190 T* num_div_den = new T[n]; 00191 garbage.insert(num_div_den); 00192 00193 GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested"); 00194 00195 gMat2D<T>* pred = new gMat2D<T>(n, t); 00196 OptMatrix<gMat2D<T> >* pred_opt = new OptMatrix<gMat2D<T> >(*pred); 00197 nestedOpt->addOpt("pred", pred_opt); 00198 00199 // const int pred_size = pred->getSize(); 00200 00201 Performance<T>* perfClass = Performance<T>::factory(opt.getOptAsString("hoperf")); 00202 00203 gMat2D<T>* perf = new gMat2D<T>(tot, t); 00204 T* ap = perf->getData(); 00205 00206 00207 // for i = 1:tot 00208 for(int s = 0; s < tot; ++s) 00209 { 00210 00211 // LL = L + (n*guesses(i)); 00212 set(tmpvec, n*guesses[s] , xc); 00213 axpy(xc, (T)1.0, L, 1, tmpvec, 1); 00214 00215 // LL = LL.^(-1) 00216 setReciprocal(tmpvec, xc); 00217 // LL = diag(LL); 00218 diag(tmpvec, xc, LL); 00219 00220 00221 // num = y - LEFT*LL*RIGHT; 00222 00223 dot(LL, RIGHT, tmp, xc, xc, xc, t, xc, t, CblasNoTrans, CblasNoTrans, CblasColMajor); 00224 dot(LEFT, tmp, tmp1, xr, xc, xc, t, xr, t, CblasNoTrans, CblasNoTrans, CblasColMajor); 00225 00226 00227 copy(num, Y.getData(), xr*t); 00228 axpy(xr*t, (T)-1.0, tmp1, 1, num, 1); 00229 00230 00231 // den = zeros(n,1); 00232 set(den, (T)0.0, n); 00233 00234 // for j = 1:n 00235 // den(j) = 1-LEFT(j,:)*LL*right(:,j); 00236 // end 00237 00238 00239 for (unsigned long j = 0; j < n; ++j) 00240 { 00241 dot(LL, right +(xc*j), tmp, xc, xc, xc, 1, xc, 1, CblasNoTrans, CblasNoTrans, CblasColMajor); 00242 00243 //extract j-th row from LEFT 00244 copy(row, LEFT + j, xc, 1, xr); 00245 00246 den[j] = ((T) 1.0) - dot (xc, row, 1, tmp, 1); 00247 } 00248 00249 00250 // opt.pred = zeros(n,T); 00251 // set(pred->getData(), (T)0.0, pred_size); 00252 00253 00254 00255 // for t = 1:T 00256 for(unsigned long j = 0; j< t; ++j) 00257 { 00258 rdivide(num + (n*j), den, num_div_den, n); 00259 00260 // opt.pred(:,t) = y(:,t) - (num(:,t)./den); 00261 copy(pred->getData()+(n*j), Y.getData() + (n*j), n); 00262 axpy(n, (T)-1.0, num_div_den, 1, pred->getData()+(n*j), 1); 00263 } 00264 00265 // opt.perf = opt.hoperf([],y,opt); 00266 const gMat2D<T> dummy; 00267 GurlsOptionsList* perf = perfClass->execute(dummy, Y, *nestedOpt); 00268 00269 gMat2D<T> &forho_vec = perf->getOptValue<OptMatrix<gMat2D<T> > >("forho"); 00270 00271 // for t = 1:T 00272 copy(ap+s, forho_vec.getData(), forho_vec.getSize(), tot, 1); 00273 00274 delete perf; 00275 } 00276 00277 00278 delete perfClass; 00279 00280 delete[] row; 00281 garbage.erase(row); 00282 delete[] num; 00283 garbage.erase(num); 00284 delete[] tmp; 00285 garbage.erase(tmp); 00286 delete[] tmp1; 00287 garbage.erase(tmp1); 00288 delete[] tmpvec; 00289 garbage.erase(tmpvec); 00290 delete[] LL; 00291 garbage.erase(LL); 00292 delete [] num_div_den; 00293 garbage.erase(num_div_den); 00294 00295 delete nestedOpt; 00296 delete[] L; 00297 garbage.erase(L); 00298 delete[] LEFT; 00299 garbage.erase(LEFT); 00300 delete[] RIGHT; 00301 garbage.erase(RIGHT); 00302 delete[] right; 00303 garbage.erase(right); 00304 delete[] den; 00305 garbage.erase(den); 00306 delete [] LOOSQE; 00307 garbage.erase(LOOSQE); 00308 // delete[] Le; 00309 // garbage.erase(Le); 00310 00311 //[dummy,idx] = max(ap,[],1); 00312 unsigned long* idx = new unsigned long[t]; 00313 T* work = NULL; 00314 indicesOfMax(ap, tot, t, idx, work, 1); 00315 00316 //vout.lambdas = guesses(idx); 00317 gMat2D<T> *LAMBDA = new gMat2D<T>(1, t); 00318 copyLocations(idx, guesses, t, tot, LAMBDA->getData()); 00319 00320 delete[] idx; 00321 00322 GurlsOptionsList* paramsel; 00323 00324 if(opt.hasOpt("paramsel")) 00325 { 00326 GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp"); 00327 tmp_opt->copyOpt("paramsel", opt); 00328 00329 paramsel = GurlsOptionsList::dynacast(tmp_opt->getOpt("paramsel")); 00330 tmp_opt->removeOpt("paramsel", false); 00331 delete tmp_opt; 00332 00333 paramsel->removeOpt("guesses"); 00334 paramsel->removeOpt("perf"); 00335 paramsel->removeOpt("lambdas"); 00336 } 00337 else 00338 paramsel = new GurlsOptionsList("paramsel"); 00339 00340 00341 00342 paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*LAMBDA)); 00343 paramsel->addOpt("perf", new OptMatrix<gMat2D<T> >(*perf)); 00344 00345 //vout.guesses = guesses; 00346 gMat2D<T> *guesses_mat = new gMat2D<T>(guesses, 1, tot, true); 00347 paramsel->addOpt("guesses", new OptMatrix<gMat2D<T> >(*guesses_mat)); 00348 00349 delete[] guesses; 00350 00351 return paramsel; 00352 } 00353 catch( gException& e) 00354 { 00355 00356 for(typename std::set<T*>::iterator it = garbage.begin(); it != garbage.end(); ++it) 00357 delete[] (*it); 00358 00359 throw e; 00360 } 00361 } 00362 00363 } 00364 00365 #endif // _GURLS_LOOCVPRIMAL_H_