// GURLS++ 2.0.00
// C++ Implementation of GURLS Matlab Toolbox
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_LOOCVDUAL_H_ 00044 #define _GURLS_LOOCVDUAL_H_ 00045 00046 #include <cstdio> 00047 #include <cstring> 00048 #include <iostream> 00049 #include <cmath> 00050 #include <algorithm> 00051 #include <set> 00052 00053 #include "gurls++/options.h" 00054 #include "gurls++/optlist.h" 00055 #include "gurls++/gmat2d.h" 00056 #include "gurls++/gvec.h" 00057 #include "gurls++/gmath.h" 00058 #include "gurls++/utils.h" 00059 00060 #include "gurls++/paramsel.h" 00061 #include "gurls++/perf.h" 00062 00063 namespace gurls { 00064 00070 template <typename T> 00071 class ParamSelLoocvDual: public ParamSelection<T>{ 00072 00073 public: 00090 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00091 }; 00092 00093 template <typename T> 00094 GurlsOptionsList* ParamSelLoocvDual<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt) 00095 { 00096 // [n,T] = size(y); 00097 const unsigned long n = Y.rows(); 00098 const unsigned long t = Y.cols(); 00099 00100 const unsigned long d = X.cols(); 00101 00102 00103 // tot = opt.nlambda; 00104 int tot = static_cast<int>(std::ceil( opt.getOptAsNumber("nlambda"))); 00105 00106 00107 // [Q,L] = eig(opt.kernel.K); 00108 // Q = double(Q); 00109 // L = double(diag(L)); 00110 const GurlsOptionsList* kernel = opt.getOptAs<GurlsOptionsList>("kernel"); 00111 00112 const gMat2D<T> 
&K_mat = kernel->getOptValue<OptMatrix<gMat2D<T> > >("K"); 00113 00114 gMat2D<T> K(K_mat.rows(), K_mat.cols()); 00115 copy(K.getData(), K_mat.getData(), K_mat.getSize()); 00116 00117 const unsigned long qrows = K.rows(); 00118 const unsigned long qcols = K.cols(); 00119 const unsigned long l_length = qrows; 00120 00121 T *Q = K.getData(); 00122 T *L = new T[l_length]; 00123 00124 eig_sm(Q, L, qrows); // qrows == qcols 00125 00126 int r = n; 00127 if(kernel->getOptAsString("type") == "linear") 00128 { 00129 set(L, (T) 1.0e-12, l_length-d); 00130 r = std::min(n,d); 00131 } 00132 00133 00134 // Qty = Q'*y; 00135 // T* Qty = new T[qrows*qcols]; 00136 T* Qty = new T[qcols*t]; 00137 dot(Q, Y.getData(), Qty, qrows, qcols, n, t, qcols, t, CblasTrans, CblasNoTrans, CblasColMajor); 00138 00139 T* guesses = lambdaguesses(L, n, r, n, tot, (T)(opt.getOptAsNumber("smallnumber"))); 00140 00141 00142 00143 GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested"); 00144 00145 gMat2D<T>* pred = new gMat2D<T>(n, t); 00146 OptMatrix<gMat2D<T> >* pred_opt = new OptMatrix<gMat2D<T> >(*pred); 00147 nestedOpt->addOpt("pred", pred_opt); 00148 00149 00150 Performance<T>* perfClass = Performance<T>::factory(opt.getOptAsString("hoperf")); 00151 00152 gMat2D<T>* perf = new gMat2D<T>(tot, t); 00153 T* ap = perf->getData(); 00154 00155 T* C_div_Z = new T[qrows]; 00156 T* C = new T[qrows*qcols]; 00157 T* Z = new T[qrows]; 00158 T* work = new T[std::max((qrows+1)*l_length, (qrows*qcols)+l_length)]; 00159 00160 for(int i = 0; i < tot; ++i) 00161 { 00162 rls_eigen(Q, L, Qty, C, guesses[i], n, qrows, qcols, l_length, qcols, t, work); 00163 GInverseDiagonal(Q, L, guesses+i, Z, qrows, qcols, l_length, 1, work); 00164 00165 for(unsigned long j = 0; j< t; ++j) 00166 { 00167 rdivide(C + (qrows*j), Z, C_div_Z, qrows); 00168 00169 // opt.pred(:,t) = y(:,t) - (C(:,t)./Z); 00170 copy(pred->getData()+(n*j), Y.getData() + (n*j), n); 00171 axpy(n, (T)-1.0, C_div_Z, 1, pred->getData() + (n*j), 1); 00172 } 
00173 00174 // opt.perf = opt.hoperf([],y,opt); 00175 const gMat2D<T> dummy; 00176 GurlsOptionsList* perf_opt = perfClass->execute(dummy, Y, *nestedOpt); 00177 00178 gMat2D<T> &forho_vec = perf_opt->getOptValue<OptMatrix<gMat2D<T> > >("forho"); 00179 00180 copy(ap+i, forho_vec.getData(), t, tot, 1); 00181 00182 delete perf_opt; 00183 } 00184 00185 delete nestedOpt; 00186 delete[] work; 00187 delete [] C; 00188 delete [] Z; 00189 delete [] C_div_Z; 00190 delete perfClass; 00191 delete [] Qty; 00192 00193 delete[] L; 00194 //delete[] Q; 00195 00196 unsigned long* idx = new unsigned long[t]; 00197 work = NULL; 00198 indicesOfMax(ap, tot, t, idx, work, 1); 00199 00200 00201 gMat2D<T> *LAMBDA = new gMat2D<T>(1, t); 00202 copyLocations(idx, guesses, t, tot, LAMBDA->getData()); 00203 00204 delete[] idx; 00205 00206 00207 GurlsOptionsList* paramsel; 00208 00209 if(opt.hasOpt("paramsel")) 00210 { 00211 GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp"); 00212 tmp_opt->copyOpt("paramsel", opt); 00213 00214 paramsel = GurlsOptionsList::dynacast(tmp_opt->getOpt("paramsel")); 00215 tmp_opt->removeOpt("paramsel", false); 00216 delete tmp_opt; 00217 00218 paramsel->removeOpt("guesses"); 00219 paramsel->removeOpt("perf"); 00220 paramsel->removeOpt("lambdas"); 00221 } 00222 else 00223 paramsel = new GurlsOptionsList("paramsel"); 00224 00225 00226 // opt.addOpt("lambdas", LAMBDA); 00227 paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*LAMBDA)); 00228 00229 //vout.perf = ap; 00230 paramsel->addOpt("perf", new OptMatrix<gMat2D<T> >(*perf)); 00231 00232 //vout.guesses = guesses; 00233 gMat2D<T> *guesses_mat = new gMat2D<T>(guesses, 1, tot, true); 00234 paramsel->addOpt("guesses", new OptMatrix<gMat2D<T> >(*guesses_mat)); 00235 00236 delete[] guesses; 00237 00238 return paramsel; 00239 00240 } 00241 00242 00243 } 00244 00245 #endif // _GURLS_LOOCVDUAL_H_