![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_RLSPRIMALR_H_ 00044 #define _GURLS_RLSPRIMALR_H_ 00045 00046 #include "gurls++/optimization.h" 00047 #include "gurls++/gmath.h" 00048 #include "gurls++/utils.h" 00049 00050 namespace gurls { 00051 00056 template <typename T> 00057 class RLSPrimalr: public Optimizer<T>{ 00058 00059 public: 00076 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00077 }; 00078 00079 00080 template <typename T> 00081 GurlsOptionsList* RLSPrimalr<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) 00082 { 00083 // lambda = opt.singlelambda(opt.paramsel.lambdas); 00084 const gMat2D<T> &ll = opt.getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas"); 00085 T lambda = opt.getOptAs<OptFunction>("singlelambda")->getValue(ll.getData(), ll.getSize()); 00086 00087 00088 // std::cout << "Solving primal RLS using Randomized SVD..." << std::endl; 00089 00090 // [n,d] = size(X); 00091 00092 const unsigned long n = X.rows(); 00093 const unsigned long d = X.cols(); 00094 00095 const unsigned long Yn = Y.rows(); 00096 const unsigned long Yd = Y.cols(); 00097 00098 // ===================================== Primal K 00099 00100 // XtX = X'*X; 00101 T* XtX = new T[d*d]; 00102 dot(X.getData(), X.getData(), XtX, n, d, n, d, d, d, CblasTrans, CblasNoTrans, CblasColMajor); 00103 00104 00105 // [Q,L,U] = tygert_svd(XtX,d); 00106 // Q = double(Q); 00107 // L = double(diag(L)); 00108 T *Q = new T[d*d]; 00109 T *L = new T[d]; 00110 T *V = NULL; 00111 00112 unsigned long k = static_cast<unsigned long>(gurls::round((opt.getOptAsNumber("eig_percentage")*d)/100.0)); 00113 random_svd(XtX, d, d, Q, L, V, k); 00114 00115 delete[] XtX; 00116 00117 // Xty = X'*y; 00118 T* Xty = new T[d*Yd]; 00119 dot(X.getData(), Y.getData(), Xty, n, d, Yn, Yd, d, Yd, CblasTrans, CblasNoTrans, CblasColMajor); 00120 00121 // if isfield(opt,'W0') 00122 if(opt.hasOpt("W0")) 00123 { 00124 // Xty = Xty + opt.W0; 00125 const gMat2D<T>& W0 = OptMatrix< gMat2D<T> >::dynacast(opt.getOpt("W0"))->getValue(); 00126 00127 if(W0.rows() == d && W0.cols() == Yd) 00128 axpy(d*Yd, (T)1.0, W0.getData(), 1, Xty, 1); 00129 } 00130 00131 00132 T* QtXty = new T[d*Yd]; 00133 dot(Q, Xty, QtXty, d, d, d, Yd, d, Yd, CblasTrans, CblasNoTrans, CblasColMajor); 00134 00135 // cfr.W = rls_eigen(Q, L, Q'*Xty, lambda,d); 00136 gMat2D<T>* W = new gMat2D<T>(d, Yd); 00137 T* work = new T[d*(d+1)]; 00138 rls_eigen(Q, L, QtXty, W->getData(), lambda, d, d, d, d, d, Yd, work); 00139 00140 00141 delete [] QtXty; 00142 delete [] work; 00143 delete [] Xty; 00144 delete [] Q; 00145 delete [] L; 00146 00147 GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer"); 00148 00149 optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W)); 00150 00151 // cfr.C = []; 00152 gMat2D<T>* emptyC = new gMat2D<T>(); 00153 optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*emptyC)); 00154 00155 // cfr.X = []; 00156 gMat2D<T>* emptyX = new gMat2D<T>(); 00157 optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*emptyX)); 00158 00159 return optimizer; 00160 } 00161 00162 00163 } 00164 #endif // _GURLS_RLSPRIMALR_H_