GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
rlsprimalrecinit.h
00001  /*
00002   * The GURLS Package in C++
00003   *
00004   * Copyright (C) 2011-2013, IIT@MIT Lab
00005   * All rights reserved.
00006   *
00007   * authors:  M. Santoro
00008   * email:   msantoro@mit.edu
00009   * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010   *
00011   * Redistribution and use in source and binary forms, with or without
00012   * modification, are permitted provided that the following conditions
00013   * are met:
00014   *
00015   *     * Redistributions of source code must retain the above
00016   *       copyright notice, this list of conditions and the following
00017   *       disclaimer.
00018   *     * Redistributions in binary form must reproduce the above
00019   *       copyright notice, this list of conditions and the following
00020   *       disclaimer in the documentation and/or other materials
00021   *       provided with the distribution.
00022   *     * Neither the name(s) of the copyright holders nor the names
00023   *       of its contributors or of the Massachusetts Institute of
00024   *       Technology or of the Italian Institute of Technology may be
00025   *       used to endorse or promote products derived from this software
00026   *       without specific prior written permission.
00027   *
00028   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029   * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030   * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031   * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032   * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033   * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034   * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036   * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037   * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038   * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039   * POSSIBILITY OF SUCH DAMAGE.
00040   */
00041 
00042 
00043 #ifndef _GURLS_RLSPRIMALRECINIT_H_
00044 #define _GURLS_RLSPRIMALRECINIT_H_
00045 
00046 #include "gurls++/optimization.h"
00047 
00048 #include "gurls++/optmatrix.h"
00049 #include "gurls++/optfunction.h"
00050 
00051 #include "gurls++/utils.h"
00052 
00053 namespace gurls
00054 {
00055 
00060 template <typename T>
00061 class RLSPrimalRecInit: public Optimizer<T>
00062 {
00063 
00064 public:
00083     GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt);
00084 };
00085 
00086 
00087 template <typename T>
00088 GurlsOptionsList* RLSPrimalRecInit<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt)
00089 {
00090     //  lambda = opt.singlelambda(opt.paramsel.lambdas);
00091     const gMat2D<T> &ll = opt.getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas");
00092     T lambda = opt.getOptAs<OptFunction>("singlelambda")->getValue(ll.getData(), ll.getSize());
00093 
00094 
00095     //  [n,d] = size(X);
00096     const unsigned long n = opt.hasOpt("nTot")? static_cast<unsigned long>(opt.getOptAsNumber("nTot")) : X.rows();
00097     unsigned long d;
00098     unsigned long t;
00099 
00100 
00101     //  XtX = X'*X;
00102     T* XtX;
00103     if(!opt.hasOpt("kernel.XtX"))
00104     {
00105         d = X.cols();
00106         XtX = new T[d*d];
00107         dot(X.getData(), X.getData(), XtX, n, d, n, d, d, d, CblasTrans, CblasNoTrans, CblasColMajor);
00108     }
00109     else
00110     {
00111         const gMat2D<T>& XtX_mat = opt.getOptValue<OptMatrix<gMat2D<T> > >("kernel.XtX");
00112         d = XtX_mat.cols();
00113         XtX = new T[d*d];
00114         copy(XtX, XtX_mat.getData(), d*d);
00115     }
00116 
00117     //  Xty = X'*y;
00118     T* Xty;
00119     if(!opt.hasOpt("kernel.Xty"))
00120     {
00121         t = Y.cols();
00122         Xty = new T[d*t];
00123         dot(X.getData(), Y.getData(), Xty, n, d, n, t, d, t, CblasTrans, CblasNoTrans, CblasColMajor);
00124     }
00125     else
00126     {
00127         const gMat2D<T>& Xty_mat = opt.getOptValue<OptMatrix<gMat2D<T> > >("kernel.Xty");
00128         t = Xty_mat.cols();
00129         Xty = new T[d*t];
00130         copy(Xty, Xty_mat.getData(), d*t);
00131     }
00132 
00133 
00134     //  Cinv = pinv(XtX + (n*lambda)*eye(d));
00135     T coeff = n*lambda;
00136     axpy(d, (T)1.0, &coeff, 0, XtX, d+1);
00137 
00138     int cinv_rows, cinv_cols;
00139     T* cinv = pinv(XtX, d, d, cinv_rows, cinv_cols);
00140 
00141     delete[] XtX;
00142 
00143     gMat2D<T>*W = new gMat2D<T>(cinv_rows, t);
00144     dot(cinv, Xty, W->getData(), cinv_rows, d, d, t, cinv_rows, t, CblasNoTrans, CblasNoTrans, CblasColMajor);
00145 
00146     gMat2D<T> *Cinv = new gMat2D<T>(cinv_rows, cinv_cols);
00147     copy(Cinv->getData(), cinv, Cinv->getSize());
00148 
00149     delete[] cinv;
00150     delete[] Xty;
00151 
00152     GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer");
00153 
00154     //  cfr.W = Cinv*Xty;
00155     optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W));
00156 
00157     //  cfr.C = [];
00158     gMat2D<T>* emptyC = new gMat2D<T>();
00159     optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*emptyC));
00160 
00161     //  cfr.X = [];
00162     gMat2D<T>* emptyX = new gMat2D<T>();
00163     optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*emptyX));
00164 
00165     //  cfr.Cinv = Cinv;
00166     optimizer->addOpt("Cinv", new OptMatrix<gMat2D<T> >(*Cinv));
00167 
00168     return optimizer;
00169 }
00170 
00171 
00172 }
00173 #endif // _GURLS_RLSPRIMALRECINIT_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends