GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
rlsprimalrecupdate.h
00001  /*
00002   * The GURLS Package in C++
00003   *
00004   * Copyright (C) 2011-2013, IIT@MIT Lab
00005   * All rights reserved.
00006   *
00007   * authors:  M. Santoro
00008   * email:   msantoro@mit.edu
00009   * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010   *
00011   * Redistribution and use in source and binary forms, with or without
00012   * modification, are permitted provided that the following conditions
00013   * are met:
00014   *
00015   *     * Redistributions of source code must retain the above
00016   *       copyright notice, this list of conditions and the following
00017   *       disclaimer.
00018   *     * Redistributions in binary form must reproduce the above
00019   *       copyright notice, this list of conditions and the following
00020   *       disclaimer in the documentation and/or other materials
00021   *       provided with the distribution.
00022   *     * Neither the name(s) of the copyright holders nor the names
00023   *       of its contributors or of the Massachusetts Institute of
00024   *       Technology or of the Italian Institute of Technology may be
00025   *       used to endorse or promote products derived from this software
00026   *       without specific prior written permission.
00027   *
00028   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029   * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030   * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031   * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032   * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033   * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034   * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036   * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037   * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038   * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039   * POSSIBILITY OF SUCH DAMAGE.
00040   */
00041 
00042 
00043 #ifndef _GURLS_RLSPRIMALRECUPDATE_H_
00044 #define _GURLS_RLSPRIMALRECUPDATE_H_
00045 
00046 #include "gurls++/optimization.h"
00047 
00048 #include "gurls++/optmatrix.h"
00049 #include "gurls++/optfunction.h"
00050 
00051 #include "gurls++/utils.h"
00052 
00053 namespace gurls
00054 {
00055 
00060 template <typename T>
00061 class RLSPrimalRecUpdate: public Optimizer<T>
00062 {
00063 public:
00080     GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt);
00081 };
00082 
00083 
00084 template <typename T>
00085 GurlsOptionsList* RLSPrimalRecUpdate<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt)
00086 {
00087     //  [n,d] = size(X);
00088 
00089     const unsigned long n = X.rows();
00090     const unsigned long d = X.cols();
00091 
00092     const unsigned long t = Y.cols();
00093 
00094     //  W = opt.rls.W;
00095     const gMat2D<T>& prev_W = opt.getOptValue<OptMatrix<gMat2D<T> > >("optimizer.W");
00096     gMat2D<T>* W = new gMat2D<T>(prev_W);
00097 
00098     //  Cinv = opt.rls.Cinv;
00099     const gMat2D<T>& prev_Cinv = opt.getOptValue<OptMatrix<gMat2D<T> > >("optimizer.Cinv");
00100     gMat2D<T>* Cinv = new gMat2D<T>(prev_Cinv);
00101 
00102     T* WData = W->getData();
00103     const unsigned long wn = W->rows();
00104     const unsigned long wd = W->cols();
00105 
00106 
00107     T* CinvData = Cinv->getData();
00108     const unsigned long cn = Cinv->rows();
00109     const unsigned long cd = Cinv->cols();
00110 
00111     T* Cx = new T[cn];
00112     T* x = new T[d];
00113     T* y = new T[t];
00114     T xCx;
00115     T* CxCxt = new T[cn*cn];
00116     T* xW = new T[wd];
00117     T* Cxy = new T[cn*t];
00118 
00119     for(unsigned long i=0; i<n; ++i)
00120     {
00121         getRow(X.getData(), n, d, i, x);
00122         getRow(Y.getData(), n, t, i, y);
00123 
00124         //  Cx = Cinv*X(i,:)';
00125         gemv(CblasNoTrans, cn, cd, (T)1.0, CinvData, cn, x, 1, (T)0.0, Cx, 1);
00126 
00127         //  xCx = X(i,:)*Cx;
00128         gemv(CblasNoTrans, 1, d, (T)1.0, x, 1, Cx, 1, (T)0.0, &xCx, 1);
00129 
00130         //  Cinv = Cinv - Cx*Cx'./(1+xCx);
00131         dot(Cx, Cx, CxCxt, cn, 1, cn, 1, cn, cn, CblasNoTrans, CblasTrans, CblasColMajor);
00132         axpy(cn*cn, (T)(-1.0/(xCx+1)), CxCxt, 1, CinvData, 1);
00133 
00134 
00135         //  W = W +(Cx*(y(i,:)-X(i,:)*W))./(1+xCx);
00136         dot(x, WData, xW, 1, d, wn, wd, 1, wd, CblasNoTrans, CblasNoTrans, CblasColMajor);
00137         axpy(t, (T)-1.0, xW, 1, y, 1);
00138         dot(Cx, y, Cxy, cn, 1, 1, t, cn, t, CblasNoTrans, CblasNoTrans, CblasColMajor);
00139         axpy(cn*t, (T)(1.0/(xCx+1)), Cxy, 1, WData, 1);
00140     }
00141 
00142     delete[] Cx;
00143     delete[] x;
00144     delete[] y;
00145     delete[] CxCxt;
00146     delete[] xW;
00147     delete[] Cxy;
00148 
00149 
00150     GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer");
00151 
00152     //  rls.W = W;
00153     optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W));
00154 
00155     //  rls.C = [];
00156     gMat2D<T>* emptyC = new gMat2D<T>();
00157     optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*emptyC));
00158 
00159     //  cfr.X = [];
00160     gMat2D<T>* emptyX = new gMat2D<T>();
00161     optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*emptyX));
00162 
00163     //  rls.Cinv = Cinv;
00164     optimizer->addOpt("Cinv", new OptMatrix<gMat2D<T> >(*Cinv));
00165 
00166     return optimizer;
00167 }
00168 
00169 
00170 }
00171 #endif // _GURLS_RLSPRIMALRECUPDATE_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends