![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_RLSPRIMALRECINIT_H_ 00044 #define _GURLS_RLSPRIMALRECINIT_H_ 00045 00046 #include "gurls++/optimization.h" 00047 00048 #include "gurls++/optmatrix.h" 00049 #include "gurls++/optfunction.h" 00050 00051 #include "gurls++/utils.h" 00052 00053 namespace gurls 00054 { 00055 00060 template <typename T> 00061 class RLSPrimalRecInit: public Optimizer<T> 00062 { 00063 00064 public: 00083 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00084 }; 00085 00086 00087 template <typename T> 00088 GurlsOptionsList* RLSPrimalRecInit<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt) 00089 { 00090 // lambda = opt.singlelambda(opt.paramsel.lambdas); 00091 const gMat2D<T> &ll = opt.getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas"); 00092 T lambda = opt.getOptAs<OptFunction>("singlelambda")->getValue(ll.getData(), ll.getSize()); 00093 00094 00095 // [n,d] = size(X); 00096 const unsigned long n = opt.hasOpt("nTot")? static_cast<unsigned long>(opt.getOptAsNumber("nTot")) : X.rows(); 00097 unsigned long d; 00098 unsigned long t; 00099 00100 00101 // XtX = X'*X; 00102 T* XtX; 00103 if(!opt.hasOpt("kernel.XtX")) 00104 { 00105 d = X.cols(); 00106 XtX = new T[d*d]; 00107 dot(X.getData(), X.getData(), XtX, n, d, n, d, d, d, CblasTrans, CblasNoTrans, CblasColMajor); 00108 } 00109 else 00110 { 00111 const gMat2D<T>& XtX_mat = opt.getOptValue<OptMatrix<gMat2D<T> > >("kernel.XtX"); 00112 d = XtX_mat.cols(); 00113 XtX = new T[d*d]; 00114 copy(XtX, XtX_mat.getData(), d*d); 00115 } 00116 00117 // Xty = X'*y; 00118 T* Xty; 00119 if(!opt.hasOpt("kernel.Xty")) 00120 { 00121 t = Y.cols(); 00122 Xty = new T[d*t]; 00123 dot(X.getData(), Y.getData(), Xty, n, d, n, t, d, t, CblasTrans, CblasNoTrans, CblasColMajor); 00124 } 00125 else 00126 { 00127 const gMat2D<T>& Xty_mat = opt.getOptValue<OptMatrix<gMat2D<T> > >("kernel.Xty"); 00128 t = Xty_mat.cols(); 00129 Xty = new T[d*t]; 00130 copy(Xty, Xty_mat.getData(), d*t); 00131 } 00132 00133 00134 // Cinv = pinv(XtX + (n*lambda)*eye(d)); 00135 T coeff = n*lambda; 00136 axpy(d, (T)1.0, &coeff, 0, XtX, d+1); 00137 00138 int cinv_rows, cinv_cols; 00139 T* cinv = pinv(XtX, d, d, cinv_rows, cinv_cols); 00140 00141 delete[] XtX; 00142 00143 gMat2D<T>*W = new gMat2D<T>(cinv_rows, t); 00144 dot(cinv, Xty, W->getData(), cinv_rows, d, d, t, cinv_rows, t, CblasNoTrans, CblasNoTrans, CblasColMajor); 00145 00146 gMat2D<T> *Cinv = new gMat2D<T>(cinv_rows, cinv_cols); 00147 copy(Cinv->getData(), cinv, Cinv->getSize()); 00148 00149 delete[] cinv; 00150 delete[] Xty; 00151 00152 GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer"); 00153 00154 // cfr.W = Cinv*Xty; 00155 optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W)); 00156 00157 // cfr.C = []; 00158 gMat2D<T>* emptyC = new gMat2D<T>(); 00159 optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*emptyC)); 00160 00161 // cfr.X = []; 00162 gMat2D<T>* emptyX = new gMat2D<T>(); 00163 optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*emptyX)); 00164 00165 // cfr.Cinv = Cinv; 00166 optimizer->addOpt("Cinv", new OptMatrix<gMat2D<T> >(*Cinv)); 00167 00168 return optimizer; 00169 } 00170 00171 00172 } 00173 #endif // _GURLS_RLSPRIMALRECINIT_H_