GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_RLSPEGASOS_H_ 00044 #define _GURLS_RLSPEGASOS_H_ 00045 00046 #include "gurls++/optimization.h" 00047 #include "gurls++/utils.h" 00048 00049 #include <set> 00050 00051 namespace gurls { 00052 00058 template <typename T> 00059 class RLSPegasos: public Optimizer<T>{ 00060 00061 public: 00086 GurlsOptionsList *execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt); 00087 }; 00088 00089 00090 template <typename T> 00091 GurlsOptionsList* RLSPegasos<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) 00092 { 00093 // lambda = opt.singlelambda(opt.paramsel.lambdas); 00094 const gMat2D<T> &ll = opt.getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas"); 00095 T lambda = opt.getOptAs<OptFunction>("singlelambda")->getValue(ll.getData(), ll.getSize()); 00096 00097 00098 // [n,d] = size(X); 00099 const unsigned long n = X.rows(); 00100 const unsigned long d = X.cols(); 00101 00102 // T = size(bY,2); 00103 const unsigned long t = Y.cols(); 00104 00105 00106 GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer"); 00107 00108 // opt.cfr.W = zeros(d,T); 00109 gMat2D<T>* W = new gMat2D<T>(d,t); 00110 set(W->getData(), (T)0.0, d*t); 00111 optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W)); 00112 00113 // opt.cfr.W_sum = zeros(d,T); 00114 gMat2D<T>* W_sum = new gMat2D<T>(d,t); 00115 copy(W_sum->getData(), 
W->getData(), d*t); 00116 optimizer->addOpt("W_sum", new OptMatrix<gMat2D<T> >(*W_sum)); 00117 00118 optimizer->addOpt("count", new OptNumber(0.0)); 00119 00120 00121 // opt.cfr.acc_last = []; 00122 // opt.cfr.acc_avg = []; 00123 00124 // opt.cfr.t0 = ceil(norm(X(1,:))/sqrt(opt.singlelambda(opt.paramsel.lambdas))); 00125 T* row = new T[d]; 00126 getRow(X.getData(), n, d, 0, row); 00127 optimizer->addOpt("t0", new OptNumber( ceil( nrm2(d, row, 1)/sqrt(lambda)))); 00128 00129 delete[] row; 00130 00131 00132 // % Run mulitple epochs 00133 // for i = 1:opt.epochs, 00134 int epochs = static_cast<int>(opt.getOptAsNumber("epochs")); 00135 00136 GurlsOptionsList* tmp_opt = new GurlsOptionsList("opt"); 00137 00138 GurlsOptionsList* tmp_paramsel = new GurlsOptionsList("paramsel"); 00139 tmp_opt->addOpt("paramsel", tmp_paramsel); 00140 00141 gMat2D<T>* ret_lambdas = new gMat2D<T>(ll); 00142 tmp_paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*ret_lambdas)); 00143 00144 00145 OptFunction* tmp_singlelambda = new OptFunction(opt.getOptAs<OptFunction>("singlelambda")->getName()); 00146 tmp_opt->addOpt("singlelambda", tmp_singlelambda); 00147 00148 tmp_opt->addOpt("optimizer", optimizer); 00149 00150 for(int i=0; i<epochs; ++i) 00151 { 00152 // if opt.cfr.count == 0 00153 // opt.cfr.t0 = ceil(norm(X(1,:))/sqrt(opt.singlelambda(opt.paramsel.lambdas))); 00154 // fprintf('\n\tt0 is set to : %f\n', opt.cfr.t0); 00155 // end 00156 00157 // opt.cfr = rls_pegasos_singlepass(X, bY, opt); 00158 GurlsOptionsList* result = rls_pegasos_driver(X.getData(), Y.getData(), *tmp_opt, n, d, Y.rows(), t); 00159 00160 tmp_opt->removeOpt("optimizer"); 00161 tmp_opt->addOpt("optimizer", result); 00162 } 00163 00164 optimizer = tmp_opt->getOptAs<GurlsOptionsList>("optimizer"); 00165 tmp_opt->removeOpt("optimizer", false); 00166 delete tmp_opt; 00167 00168 // cfr = opt.cfr; 00169 00170 // cfr.W = opt.cfr.W_sum/opt.cfr.count; 00171 00172 T count = static_cast<T>(optimizer->getOptAsNumber("count")); 
00173 if(eq(count, (T)0.0)) 00174 throw gException(Exception_Illegal_Argument_Value); 00175 00176 W = &(optimizer->getOptValue<OptMatrix<gMat2D<T> > >("W")); 00177 W_sum = &(optimizer->getOptValue<OptMatrix<gMat2D<T> > >("W_sum")); 00178 00179 set(W->getData(), (T)0.0, W->getSize()); 00180 axpy(W->getSize(), (T)(1.0/count), W_sum->getData(), 1, W->getData(), 1); 00181 00182 return optimizer; 00183 } 00184 00185 } 00186 #endif // _GURLS_RLSPEGASOS_H_