![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_RLSDUAL_H_ 00044 #define _GURLS_RLSDUAL_H_ 00045 00046 #include "gurls++/optimization.h" 00047 00048 #include <set> 00049 00050 namespace gurls { 00051 00057 template <typename T> 00058 class RLSDual: public Optimizer<T>{ 00059 00060 public: 00078 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00079 }; 00080 00081 00082 template <typename T> 00083 GurlsOptionsList* RLSDual<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) 00084 { 00085 // lambda = opt.singlelambda(opt.paramsel.lambdas); 00086 const gMat2D<T> &ll = opt.getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas"); 00087 T lambda = opt.getOptAs<OptFunction>("singlelambda")->getValue(ll.getData(), ll.getSize()); 00088 00089 const GurlsOptionsList* kernel = opt.getOptAs<GurlsOptionsList>("kernel"); 00090 const gMat2D<T>& K_mat = kernel->getOptValue<OptMatrix<gMat2D<T> > >("K"); 00091 00092 T* K = new T[K_mat.getSize()]; 00093 copy(K, K_mat.getData(), K_mat.getSize()); 00094 00095 //n = size(opt.kernel.K,1); 00096 const long n = K_mat.rows(); 00097 00098 //T = size(y,2); 00099 const long t = Y.cols(); 00100 00101 00102 // std::cout << "Solving dual RLS... " << std::endl; 00103 00104 const T coeff = n*static_cast<T>(lambda); 00105 long i=0; 00106 for(T* it = K; i<n; ++i, it+=n+1) 00107 *it += coeff; 00108 00109 00110 std::set<T*> garbage; 00111 00112 gMat2D<T>* retC = NULL; 00113 00114 try // Try solving it with cholesky first. 00115 { 00116 // R = chol(K); 00117 T* R = new T[n*n]; 00118 garbage.insert(R); 00119 cholesky(K, n, n, R); 00120 00121 // cfr.C = R\(R'\y); 00122 retC = new gMat2D<T>(Y.rows(), t); 00123 00124 copy(retC->getData(), Y.getData(), Y.getSize()); 00125 mldivide_squared(R, retC->getData(), n, n, retC->rows(), retC->cols(), CblasTrans); 00126 mldivide_squared(R, retC->getData(), n, n, retC->rows(), retC->cols(), CblasNoTrans); 00127 00128 delete[] R; 00129 garbage.erase(R); 00130 } 00131 catch (gException& /*gex*/) 00132 { 00133 for(typename std::set<T*>::iterator it = garbage.begin(); it != garbage.end(); ++it) 00134 delete[] (*it); 00135 00136 garbage.clear(); 00137 00138 if(retC != NULL) 00139 delete retC; 00140 00141 00142 // [Q,L,V] = svd(K); 00143 // Q = double(Q); 00144 // L = double(diag(L)); 00145 T *Q, *L, *V; 00146 int Q_rows, Q_cols; 00147 int L_len; 00148 int V_rows, V_cols; 00149 00150 svd(K, Q, L, V, n, n, Q_rows, Q_cols, L_len, V_rows, V_cols); 00151 00152 00153 // cfr.C = rls_eigen(Q,L,y,lambda,n); 00154 retC = new gMat2D<T>(Q_rows, Y.cols()); 00155 00156 T* Qty = new T[Q_cols*Y.cols()]; 00157 dot(Q, Y.getData(), Qty, Q_rows, Q_cols, Y.rows(), Y.cols(), Q_cols, Y.cols(), CblasTrans, CblasNoTrans, CblasColMajor); 00158 00159 T* work = new T[L_len*(Q_rows+1)]; 00160 rls_eigen(Q, L, Qty, retC->getData(), lambda, n, Q_rows, Q_cols, L_len, Q_cols, Y.cols(), work); 00161 00162 delete [] work; 00163 delete [] Qty; 00164 delete [] Q; 00165 delete [] L; 00166 delete [] V; 00167 } 00168 00169 delete[] K; 00170 00171 GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer"); 00172 00173 // if strcmp(opt.kernel.type, 'linear') 00174 if(kernel->getOptAsString("type") == "linear") 00175 { 00176 // cfr.W = X'*cfr.C; 00177 gMat2D<T>* W = new gMat2D<T>(X.cols(), retC->cols()); 00178 dot(X.getData(), retC->getData(), W->getData(), X.rows(), X.cols(), retC->rows(), retC->cols(), W->rows(), W->cols(), CblasTrans, CblasNoTrans, CblasColMajor); 00179 optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W)); 00180 00181 // cfr.C = []; 00182 gMat2D<T>* emptyC = new gMat2D<T>(); 00183 optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*emptyC)); 00184 00185 // cfr.X = []; 00186 gMat2D<T>* emptyX = new gMat2D<T>(); 00187 optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*emptyX)); 00188 00189 delete retC; 00190 } 00191 else 00192 { 00193 // cfr.W = []; 00194 gMat2D<T>* emptyW = new gMat2D<T>(); 00195 optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*emptyW)); 00196 00197 // cfr.C = retC; 00198 optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*retC)); 00199 00200 // cfr.X = X; 00201 gMat2D<T>* optX = new gMat2D<T>(X); 00202 optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*optX)); 00203 } 00204 00205 return optimizer; 00206 } 00207 00208 } 00209 #endif // _GURLS_RLSDUAL_H_