GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
rlspegasos.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-2013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * authors:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massachusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
00042 
00043 #ifndef _GURLS_RLSPEGASOS_H_
00044 #define _GURLS_RLSPEGASOS_H_
00045 
00046 #include "gurls++/optimization.h"
00047 #include "gurls++/utils.h"
00048 
00049 #include <set>
00050 
00051 namespace gurls {
00052 
00058 template <typename T>
00059 class RLSPegasos: public Optimizer<T>{
00060 
00061 public:
00086     GurlsOptionsList *execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt);
00087 };
00088 
00089 
00090 template <typename T>
00091 GurlsOptionsList* RLSPegasos<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt)
00092 {
00093     //  lambda = opt.singlelambda(opt.paramsel.lambdas);
00094     const gMat2D<T> &ll = opt.getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas");
00095     T lambda = opt.getOptAs<OptFunction>("singlelambda")->getValue(ll.getData(), ll.getSize());
00096 
00097 
00098     //   [n,d] = size(X);
00099     const unsigned long n = X.rows();
00100     const unsigned long d = X.cols();
00101 
00102     //   T = size(bY,2);
00103     const unsigned long t = Y.cols();
00104 
00105 
00106     GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer");
00107 
00108     //   opt.cfr.W = zeros(d,T);
00109     gMat2D<T>* W = new gMat2D<T>(d,t);
00110     set(W->getData(), (T)0.0, d*t);
00111     optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W));
00112 
00113     //   opt.cfr.W_sum = zeros(d,T);
00114     gMat2D<T>* W_sum = new gMat2D<T>(d,t);
00115     copy(W_sum->getData(), W->getData(), d*t);
00116     optimizer->addOpt("W_sum", new OptMatrix<gMat2D<T> >(*W_sum));
00117 
00118     optimizer->addOpt("count", new OptNumber(0.0));
00119 
00120 
00121     //   opt.cfr.acc_last = [];
00122     //   opt.cfr.acc_avg = [];
00123 
00124     //           opt.cfr.t0 = ceil(norm(X(1,:))/sqrt(opt.singlelambda(opt.paramsel.lambdas)));
00125     T* row = new T[d];
00126     getRow(X.getData(), n, d, 0, row);
00127     optimizer->addOpt("t0", new OptNumber( ceil( nrm2(d, row, 1)/sqrt(lambda))));
00128 
00129     delete[] row;
00130 
00131 
00132     //   % Run mulitple epochs
00133     //   for i = 1:opt.epochs,
00134     int epochs = static_cast<int>(opt.getOptAsNumber("epochs"));
00135 
00136     GurlsOptionsList* tmp_opt = new GurlsOptionsList("opt");
00137 
00138     GurlsOptionsList* tmp_paramsel = new GurlsOptionsList("paramsel");
00139     tmp_opt->addOpt("paramsel", tmp_paramsel);
00140 
00141     gMat2D<T>* ret_lambdas = new gMat2D<T>(ll);
00142     tmp_paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*ret_lambdas));
00143 
00144 
00145     OptFunction* tmp_singlelambda = new OptFunction(opt.getOptAs<OptFunction>("singlelambda")->getName());
00146     tmp_opt->addOpt("singlelambda", tmp_singlelambda);
00147 
00148     tmp_opt->addOpt("optimizer", optimizer);
00149 
00150     for(int i=0; i<epochs; ++i)
00151     {
00152         //       if opt.cfr.count == 0
00153         //           opt.cfr.t0 = ceil(norm(X(1,:))/sqrt(opt.singlelambda(opt.paramsel.lambdas)));
00154         //           fprintf('\n\tt0 is set to : %f\n', opt.cfr.t0);
00155         //       end
00156 
00157         //       opt.cfr = rls_pegasos_singlepass(X, bY, opt);
00158         GurlsOptionsList* result = rls_pegasos_driver(X.getData(), Y.getData(), *tmp_opt, n, d, Y.rows(), t);
00159 
00160         tmp_opt->removeOpt("optimizer");
00161         tmp_opt->addOpt("optimizer", result);
00162     }
00163 
00164     optimizer = tmp_opt->getOptAs<GurlsOptionsList>("optimizer");
00165     tmp_opt->removeOpt("optimizer", false);
00166     delete tmp_opt;
00167 
00168     //   cfr = opt.cfr;
00169 
00170     //   cfr.W = opt.cfr.W_sum/opt.cfr.count;
00171 
00172     T count = static_cast<T>(optimizer->getOptAsNumber("count"));
00173     if(eq(count, (T)0.0))
00174         throw gException(Exception_Illegal_Argument_Value);
00175 
00176     W = &(optimizer->getOptValue<OptMatrix<gMat2D<T> > >("W"));
00177     W_sum = &(optimizer->getOptValue<OptMatrix<gMat2D<T> > >("W_sum"));
00178 
00179     set(W->getData(), (T)0.0, W->getSize());
00180     axpy(W->getSize(), (T)(1.0/count), W_sum->getData(), 1, W->getData(), 1);
00181 
00182     return optimizer;
00183 }
00184 
00185 }
00186 #endif // _GURLS_RLSPEGASOS_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends