GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
rlsprimalr.h
00001  /*
00002   * The GURLS Package in C++
00003   *
00004   * Copyright (C) 2011-2013, IIT@MIT Lab
00005   * All rights reserved.
00006   *
00007   * authors:  M. Santoro
00008   * email:   msantoro@mit.edu
00009   * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010   *
00011   * Redistribution and use in source and binary forms, with or without
00012   * modification, are permitted provided that the following conditions
00013   * are met:
00014   *
00015   *     * Redistributions of source code must retain the above
00016   *       copyright notice, this list of conditions and the following
00017   *       disclaimer.
00018   *     * Redistributions in binary form must reproduce the above
00019   *       copyright notice, this list of conditions and the following
00020   *       disclaimer in the documentation and/or other materials
00021   *       provided with the distribution.
00022   *     * Neither the name(s) of the copyright holders nor the names
00023   *       of its contributors or of the Massachusetts Institute of
00024   *       Technology or of the Italian Institute of Technology may be
00025   *       used to endorse or promote products derived from this software
00026   *       without specific prior written permission.
00027   *
00028   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029   * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030   * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031   * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032   * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033   * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034   * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036   * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037   * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038   * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039   * POSSIBILITY OF SUCH DAMAGE.
00040   */
00041 
00042 
00043 #ifndef _GURLS_RLSPRIMALR_H_
00044 #define _GURLS_RLSPRIMALR_H_
00045 
00046 #include "gurls++/optimization.h"
00047 #include "gurls++/gmath.h"
00048 #include "gurls++/utils.h"
00049 
00050 namespace gurls {
00051 
00056 template <typename T>
00057 class RLSPrimalr: public Optimizer<T>{
00058 
00059 public:
00076     GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt);
00077 };
00078 
00079 
00080 template <typename T>
00081 GurlsOptionsList* RLSPrimalr<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt)
00082 {
00083     //  lambda = opt.singlelambda(opt.paramsel.lambdas);
00084     const gMat2D<T> &ll = opt.getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas");
00085     T lambda = opt.getOptAs<OptFunction>("singlelambda")->getValue(ll.getData(), ll.getSize());
00086 
00087 
00088 //    std::cout << "Solving primal RLS using Randomized SVD..." << std::endl;
00089 
00090     //  [n,d] = size(X);
00091 
00092     const unsigned long n = X.rows();
00093     const unsigned long d = X.cols();
00094 
00095     const unsigned long Yn = Y.rows();
00096     const unsigned long Yd = Y.cols();
00097 
00098     //  ===================================== Primal K
00099 
00100 //  XtX = X'*X;
00101     T* XtX = new T[d*d];
00102     dot(X.getData(), X.getData(), XtX, n, d, n, d, d, d, CblasTrans, CblasNoTrans, CblasColMajor);
00103 
00104 
00105 //    [Q,L,U] = tygert_svd(XtX,d);
00106 //    Q = double(Q);
00107 //    L = double(diag(L));
00108     T *Q = new T[d*d];
00109     T *L = new T[d];
00110     T *V = NULL;
00111 
00112     unsigned long k = static_cast<unsigned long>(gurls::round((opt.getOptAsNumber("eig_percentage")*d)/100.0));
00113     random_svd(XtX, d, d, Q, L, V, k);
00114 
00115     delete[] XtX;
00116 
00117 //  Xty = X'*y;
00118     T* Xty = new T[d*Yd];
00119     dot(X.getData(), Y.getData(), Xty, n, d, Yn, Yd, d, Yd, CblasTrans, CblasNoTrans, CblasColMajor);
00120 
00121 //    if isfield(opt,'W0')
00122     if(opt.hasOpt("W0"))
00123     {
00124 //        Xty = Xty + opt.W0;
00125         const gMat2D<T>& W0 = OptMatrix< gMat2D<T> >::dynacast(opt.getOpt("W0"))->getValue();
00126 
00127         if(W0.rows() == d && W0.cols() == Yd)
00128             axpy(d*Yd, (T)1.0, W0.getData(), 1, Xty, 1);
00129     }
00130 
00131 
00132     T* QtXty = new T[d*Yd];
00133     dot(Q, Xty, QtXty, d, d, d, Yd, d, Yd, CblasTrans, CblasNoTrans, CblasColMajor);
00134 
00135 //    cfr.W = rls_eigen(Q, L, Q'*Xty, lambda,d);
00136     gMat2D<T>* W = new gMat2D<T>(d, Yd);
00137     T* work = new T[d*(d+1)];
00138     rls_eigen(Q, L, QtXty, W->getData(), lambda, d, d, d, d, d, Yd, work);
00139 
00140 
00141     delete [] QtXty;
00142     delete [] work;
00143     delete [] Xty;
00144     delete [] Q;
00145     delete [] L;
00146 
00147     GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer");
00148 
00149     optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W));
00150 
00151 //    cfr.C = [];
00152     gMat2D<T>* emptyC = new gMat2D<T>();
00153     optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*emptyC));
00154 
00155     //  cfr.X = [];
00156     gMat2D<T>* emptyX = new gMat2D<T>();
00157     optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*emptyX));
00158 
00159     return optimizer;
00160 }
00161 
00162 
00163 }
00164 #endif // _GURLS_RLSPRIMALR_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends