GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_RLSPRIMALRECUPDATE_H_ 00044 #define _GURLS_RLSPRIMALRECUPDATE_H_ 00045 00046 #include "gurls++/optimization.h" 00047 00048 #include "gurls++/optmatrix.h" 00049 #include "gurls++/optfunction.h" 00050 00051 #include "gurls++/utils.h" 00052 00053 namespace gurls 00054 { 00055 00060 template <typename T> 00061 class RLSPrimalRecUpdate: public Optimizer<T> 00062 { 00063 public: 00080 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00081 }; 00082 00083 00084 template <typename T> 00085 GurlsOptionsList* RLSPrimalRecUpdate<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt) 00086 { 00087 // [n,d] = size(X); 00088 00089 const unsigned long n = X.rows(); 00090 const unsigned long d = X.cols(); 00091 00092 const unsigned long t = Y.cols(); 00093 00094 // W = opt.rls.W; 00095 const gMat2D<T>& prev_W = opt.getOptValue<OptMatrix<gMat2D<T> > >("optimizer.W"); 00096 gMat2D<T>* W = new gMat2D<T>(prev_W); 00097 00098 // Cinv = opt.rls.Cinv; 00099 const gMat2D<T>& prev_Cinv = opt.getOptValue<OptMatrix<gMat2D<T> > >("optimizer.Cinv"); 00100 gMat2D<T>* Cinv = new gMat2D<T>(prev_Cinv); 00101 00102 T* WData = W->getData(); 00103 const unsigned long wn = W->rows(); 00104 const unsigned long wd = W->cols(); 00105 00106 00107 T* CinvData = Cinv->getData(); 00108 const unsigned long cn = Cinv->rows(); 00109 
const unsigned long cd = Cinv->cols(); 00110 00111 T* Cx = new T[cn]; 00112 T* x = new T[d]; 00113 T* y = new T[t]; 00114 T xCx; 00115 T* CxCxt = new T[cn*cn]; 00116 T* xW = new T[wd]; 00117 T* Cxy = new T[cn*t]; 00118 00119 for(unsigned long i=0; i<n; ++i) 00120 { 00121 getRow(X.getData(), n, d, i, x); 00122 getRow(Y.getData(), n, t, i, y); 00123 00124 // Cx = Cinv*X(i,:)'; 00125 gemv(CblasNoTrans, cn, cd, (T)1.0, CinvData, cn, x, 1, (T)0.0, Cx, 1); 00126 00127 // xCx = X(i,:)*Cx; 00128 gemv(CblasNoTrans, 1, d, (T)1.0, x, 1, Cx, 1, (T)0.0, &xCx, 1); 00129 00130 // Cinv = Cinv - Cx*Cx'./(1+xCx); 00131 dot(Cx, Cx, CxCxt, cn, 1, cn, 1, cn, cn, CblasNoTrans, CblasTrans, CblasColMajor); 00132 axpy(cn*cn, (T)(-1.0/(xCx+1)), CxCxt, 1, CinvData, 1); 00133 00134 00135 // W = W +(Cx*(y(i,:)-X(i,:)*W))./(1+xCx); 00136 dot(x, WData, xW, 1, d, wn, wd, 1, wd, CblasNoTrans, CblasNoTrans, CblasColMajor); 00137 axpy(t, (T)-1.0, xW, 1, y, 1); 00138 dot(Cx, y, Cxy, cn, 1, 1, t, cn, t, CblasNoTrans, CblasNoTrans, CblasColMajor); 00139 axpy(cn*t, (T)(1.0/(xCx+1)), Cxy, 1, WData, 1); 00140 } 00141 00142 delete[] Cx; 00143 delete[] x; 00144 delete[] y; 00145 delete[] CxCxt; 00146 delete[] xW; 00147 delete[] Cxy; 00148 00149 00150 GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer"); 00151 00152 // rls.W = W; 00153 optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W)); 00154 00155 // rls.C = []; 00156 gMat2D<T>* emptyC = new gMat2D<T>(); 00157 optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*emptyC)); 00158 00159 // cfr.X = []; 00160 gMat2D<T>* emptyX = new gMat2D<T>(); 00161 optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*emptyX)); 00162 00163 // rls.Cinv = Cinv; 00164 optimizer->addOpt("Cinv", new OptMatrix<gMat2D<T> >(*Cinv)); 00165 00166 return optimizer; 00167 } 00168 00169 00170 } 00171 #endif // _GURLS_RLSPRIMALRECUPDATE_H_