![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
/*
 * The GURLS Package in C++
 *
 * Copyright (C) 2011-2013, Matteo Santoro
 * All rights reserved.
 *
 * author:  M. Santoro
 * email:   matteo.santoro@gmail.com
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *     * Redistributions of source code must retain the above
 *       copyright notice, this list of conditions and the following
 *       disclaimer.
 *     * Redistributions in binary form must reproduce the above
 *       copyright notice, this list of conditions and the following
 *       disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name(s) of the copyright holders nor the names
 *       of its contributors or of the Massachusetts Institute of
 *       Technology or of the Italian Institute of Technology may be
 *       used to endorse or promote products derived from this software
 *       without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
00039 */ 00040 00041 00042 #ifndef _GURLS_GURLS_H_ 00043 #define _GURLS_GURLS_H_ 00044 00045 #include <iostream> 00046 #include <string> 00047 #include <map> 00048 #include <vector> 00049 #include <exception> 00050 #include <ctime> 00051 00052 #include <boost/date_time/posix_time/posix_time_types.hpp> 00053 00054 #include "gurls++/exports.h" 00055 #include "gurls++/exceptions.h" 00056 #include "gurls++/gmat2d.h" 00057 #include "gurls++/optlist.h" 00058 #include "gurls++/options.h" 00059 #include "gurls++/optarray.h" 00060 #include "gurls++/optfunction.h" 00061 #include "gurls++/optmatrix.h" 00062 00063 #include "gurls++/linearkernel.h" 00064 #include "gurls++/rbfkernel.h" 00065 #include "gurls++/chisquaredkernel.h" 00066 00067 #include "gurls++/predkerneltraintest.h" 00068 00069 #include "gurls++/precisionrecall.h" 00070 #include "gurls++/macroavg.h" 00071 #include "gurls++/rmse.h" 00072 00073 #include "gurls++/rlsauto.h" 00074 #include "gurls++/rlsprimal.h" 00075 #include "gurls++/rlsprimalr.h" 00076 #include "gurls++/rlsdual.h" 00077 #include "gurls++/rlsdualr.h" 00078 #include "gurls++/rlspegasos.h" 00079 #include "gurls++/rlsgp.h" 00080 #include "gurls++/rlsprimalrecinit.h" 00081 #include "gurls++/rlsprimalrecupdate.h" 00082 #include "gurls++/rlsrandfeats.h" 00083 00084 #include "gurls++/loocvprimal.h" 00085 #include "gurls++/loocvdual.h" 00086 #include "gurls++/fixlambda.h" 00087 #include "gurls++/fixsiglam.h" 00088 #include "gurls++/siglam.h" 00089 #include "gurls++/siglamho.h" 00090 #include "gurls++/hoprimal.h" 00091 #include "gurls++/hodual.h" 00092 00093 #include "gurls++/hogpregr.h" 00094 #include "gurls++/loogpregr.h" 00095 #include "gurls++/siglamhogpregr.h" 00096 #include "gurls++/siglamloogpregr.h" 00097 00098 #include "gurls++/pred.h" 00099 #include "gurls++/primal.h" 00100 #include "gurls++/dual.h" 00101 #include "gurls++/predgp.h" 00102 #include "gurls++/predrandfeats.h" 00103 00104 #include "gurls++/norml2.h" 00105 #include 
"gurls++/normtestzscore.h" 00106 #include "gurls++/normzscore.h" 00107 00108 #include "gurls++/splitho.h" 00109 00110 #include "gurls++/boltzman.h" 00111 #include "gurls++/boltzmangap.h" 00112 #include "gurls++/gap.h" 00113 #include "gurls++/maxscore.h" 00114 00115 namespace gurls { 00116 00122 class GURLS_EXPORT GURLS 00123 { 00124 00125 public: 00126 00130 // enum Action {ignore, compute, computeNsave, load, remove}; 00131 00132 typedef OptProcess::Action Action; 00133 00134 static const Action ignore = OptProcess::ignore; 00135 static const Action compute = OptProcess::compute; 00136 static const Action computeNsave = OptProcess::computeNsave; 00137 static const Action load = OptProcess::load; 00138 static const Action remove = OptProcess::remove; 00139 00149 template <typename T> 00150 void run(const gMat2D<T>& X, const gMat2D<T>& y, 00151 GurlsOptionsList& opt, std::string processid); 00152 00153 }; 00154 00155 00156 template <typename T> 00157 void GURLS::run(const gMat2D<T>& X, const gMat2D<T>& y, 00158 GurlsOptionsList& opt, std::string processid) 00159 { 00160 00161 boost::posix_time::ptime begin, end; 00162 boost::posix_time::time_duration diff; 00163 00164 Optimizer<T> *taskOpt; 00165 ParamSelection<T> *taskParSel; 00166 Prediction<T> *taskPrediction; 00167 Performance<T> *taskPerformance; 00168 Kernel<T> *taskKernel; 00169 Norm<T> *taskNorm; 00170 Split<T> *taskSplit; 00171 PredKernel<T> *taskPredKernel; 00172 Confidence<T> *taskConfidence; 00173 00174 // try{ 00175 00176 OptTaskSequence* seq = OptTaskSequence::dynacast(opt.getOpt("seq")); 00177 GurlsOptionsList& processes = *GurlsOptionsList::dynacast(opt.getOpt("processes")); 00178 00179 if (!processes.hasOpt(processid)) 00180 throw gException(Exception_Gurls_Invalid_ProcessID); 00181 00182 // std::vector<double> process = OptNumberList::dynacast( processes.getOpt(processid) )->getValue(); 00183 OptProcess* process = processes.getOptAs<OptProcess>(processid); 00184 00185 // if ((long)process.size() != 
seq->size()) 00186 if ( process->size() != seq->size()) 00187 throw gException(gurls::Exception_Gurls_Inconsistent_Processes_Number); 00188 00189 const std::string saveFile = opt.getOptAsString("savefile"); 00190 00191 GurlsOptionsList* loadOpt = new GurlsOptionsList("load"); 00192 try 00193 { 00194 loadOpt->load(saveFile); 00195 } 00196 catch(gException & /*ex*/) 00197 { 00198 delete loadOpt; 00199 loadOpt = NULL; 00200 } 00201 00202 GurlsOption* tmpOpt; 00203 00204 //% Load and copy 00205 00206 //if exist(opt.savefile) == 2 00207 // t = load(opt.savefile); 00208 // if isfield(t.opt,'time'); 00209 // opt.time = t.opt.time; 00210 // end 00211 //else 00212 // fprintf('Could not load %s. Starting from scratch.\n', opt.savefile); 00213 //end 00214 //%try 00215 //% t = load(opt.savefile); 00216 //% if isfield(t.opt,'time') 00217 //% opt.time = t.opt.time; 00218 //% end 00219 //%catch 00220 //% fprintf('Could not load %s. Starting from scratch.\n', opt.savefile); 00221 //%end 00222 00223 GurlsOptionsList *timelist; 00224 00225 if (opt.hasOpt("time")) 00226 timelist = GurlsOptionsList::dynacast(opt.getOpt("time")); 00227 else 00228 { 00229 timelist = new GurlsOptionsList("elapsedtime"); 00230 opt.addOpt("time", timelist); 00231 } 00232 00233 00234 // std::vector <double> process_time(seq->size(), 0.0); 00235 gMat2D<T>* process_time_vector = new gMat2D<T>(1, seq->size()); 00236 T *process_time = process_time_vector->getData(); 00237 set(process_time, (T)0.0, seq->size()); 00238 00239 //%for i = 1:numel(opt.process) % Go by the length of process. 00240 //opt.time{jobid} = struct; 00241 //%end 00242 00243 std::string reg1; 00244 std::string reg2; 00245 // std::string fun(""); 00246 std::cout << std::endl 00247 <<"####### New task sequence... " 00248 << std::endl; 00249 00250 for (unsigned long i = 0; i < seq->size(); ++i) 00251 { 00252 seq->getTaskAt(i, reg1, reg2); 00253 00254 std::cout << "\t" << "[Task " << i << ": " 00255 << reg1 << "]: " << reg2 << "... 
"; 00256 std::cout.flush(); 00257 00258 00259 // switch ( static_cast<int>(process[i]) ) 00260 switch( (*process)[i] ) 00261 { 00262 case GURLS::ignore: 00263 std::cout << " ignored." << std::endl; 00264 break; 00265 00266 case GURLS::compute: 00267 case GURLS::computeNsave: 00268 // WARNING: we should consider the case in which 00269 // the following statements holds true because the 00270 // field reg{1} already exists in opt. 00271 // case {CPT, CSV, ~isfield(opt,reg{1})} 00272 00273 begin = boost::posix_time::microsec_clock::local_time(); 00274 00275 if (!reg1.compare("optimizer")) 00276 { 00277 taskOpt = Optimizer<T>::factory(reg2); 00278 GurlsOption* ret = taskOpt->execute(X, y, opt); 00279 opt.removeOpt("optimizer"); 00280 opt.addOpt("optimizer", ret); 00281 00282 delete taskOpt; 00283 } 00284 else if (!reg1.compare("paramsel")) 00285 { 00286 taskParSel = ParamSelection<T>::factory(reg2); 00287 GurlsOption* ret = taskParSel->execute(X, y, opt); 00288 opt.removeOpt("paramsel"); 00289 opt.addOpt("paramsel", ret); 00290 00291 delete taskParSel; 00292 } 00293 else if (!reg1.compare("pred")) 00294 { 00295 taskPrediction = Prediction<T>::factory(reg2); 00296 GurlsOption* ret = taskPrediction->execute(X, y, opt); 00297 opt.removeOpt("pred"); 00298 opt.addOpt("pred", ret); 00299 00300 delete taskPrediction; 00301 } 00302 else if (!reg1.compare("perf")) 00303 { 00304 taskPerformance = Performance<T>::factory(reg2); 00305 GurlsOption* ret = taskPerformance->execute(X, y, opt); 00306 opt.removeOpt("perf"); 00307 opt.addOpt("perf", ret); 00308 00309 delete taskPerformance; 00310 } 00311 else if (!reg1.compare("kernel")) 00312 { 00313 taskKernel = Kernel<T>::factory(reg2); 00314 GurlsOption* ret = taskKernel->execute(X, y, opt); 00315 opt.removeOpt("kernel"); 00316 opt.addOpt("kernel", ret); 00317 00318 delete taskKernel; 00319 } 00320 else if (!reg1.compare("norm")) 00321 { 00322 taskNorm = Norm<T>::factory(reg2); 00323 GurlsOption* ret = taskNorm->execute(X, y, opt); 
00324 opt.removeOpt("norm"); 00325 opt.addOpt("norm", ret); 00326 00327 delete taskNorm; 00328 } 00329 else if (!reg1.compare("split")) 00330 { 00331 taskSplit = Split<T>::factory(reg2); 00332 GurlsOption* ret = taskSplit->execute(X, y, opt); 00333 opt.removeOpt("split"); 00334 opt.addOpt("split", ret); 00335 00336 delete taskSplit; 00337 } 00338 else if (!reg1.compare("predkernel")) 00339 { 00340 taskPredKernel = PredKernel<T>::factory(reg2); 00341 GurlsOption* ret = taskPredKernel->execute(X, y, opt); 00342 opt.removeOpt("predkernel"); 00343 opt.addOpt("predkernel", ret); 00344 00345 delete taskPredKernel; 00346 } 00347 else if (!reg1.compare("conf")) 00348 { 00349 taskConfidence = Confidence<T>::factory(reg2); 00350 GurlsOption* ret = taskConfidence->execute(X, y, opt); 00351 opt.removeOpt("conf"); 00352 opt.addOpt("conf", ret); 00353 00354 delete taskConfidence; 00355 } 00356 00357 // fun = reg1; 00358 // fun+="_"; 00359 // fun+=reg2; 00360 // opt.addOpt(reg1, new OptString(fun)); 00361 00362 end = boost::posix_time::microsec_clock::local_time(); 00363 diff = end-begin; 00364 00365 process_time[i] = ((T)diff.total_milliseconds())/1000.0; 00366 00367 // fName = [reg{1} '_' reg{2}]; 00368 // fun = str2func(fName); 00369 // tic; 00370 // opt = setfield(opt, reg{1}, fun(X, y, opt)); 00371 // opt.time{jobid} = setfield(opt.time{jobid},reg{1}, toc); 00372 // fprintf('\tdone\n'); 00373 std::cout << " done." << std::endl; 00374 break; 00375 00376 case GURLS::load: 00377 // case LDF, 00378 // if exist('t','var') && isfield (t.opt, reg{1}) 00379 // opt = setfield(opt, reg{1}, getfield(t.opt, reg{1})); 00380 // fprintf('\tcopied\n'); 00381 // else 00382 // fprintf('\tcopy failed\n'); 00383 // end 00384 // std::cout << " skipped." 
<< std::endl; 00385 00386 if(loadOpt == NULL) 00387 throw gException("Opt savefile not found"); 00388 if(!loadOpt->hasOpt(reg1)) 00389 { 00390 std::string s = "Task " + reg1 + " not found in opt savefile"; 00391 gException e(s); 00392 throw e; 00393 } 00394 00395 opt.removeOpt(reg1); 00396 tmpOpt = loadOpt->getOpt(reg1); 00397 loadOpt->removeOpt(reg1, false); 00398 opt.addOpt(reg1, tmpOpt); 00399 std::cout << " copied" << std::endl; 00400 00401 break; 00402 default: 00403 throw gException("Unknown task assignment"); 00404 } 00405 00406 } 00407 00408 // timelist->addOpt(processid, new OptNumberList(process_time)); 00409 timelist->addOpt(processid, new OptMatrix<gMat2D<T> >(*process_time_vector)); 00410 00411 //fprintf('\nSave cycle...\n'); 00412 //% Delete whats not necessary 00413 //for i = 1:numel(process) 00414 // fprintf('[Job %d: %15s] %15s: ',jobid, reg{1}, reg{2}); 00415 // reg = regexp(seq{i},':','split'); 00416 // switch process(i) 00417 // case {CSV, LDF} 00418 // fprintf('\tsaving..\n'); 00419 // otherwise 00420 // if isfield (opt, reg{1}) 00421 // opt = rmfield(opt, reg{1}); 00422 // fprintf('\tremoving..\n'); 00423 // else 00424 // fprintf('\tnot found..\n'); 00425 // end 00426 // end 00427 //end 00428 //save(opt.savefile, 'opt', '-v7.3'); 00429 //fprintf('Saving opt in %s\n', opt.savefile); 00430 00431 bool save = false; 00432 00433 std::cout << std::endl << "Save cycle..." << std::endl; 00434 for (unsigned long i = 0; i < seq->size(); ++i) 00435 { 00436 seq->getTaskAt(i, reg1, reg2); 00437 std::cout << "\t" << "[Task " << i << ": " << reg1 << "]: " << reg2 << "... 
"; 00438 std::cout.flush(); 00439 00440 switch ( (*process)[i] ) 00441 { 00442 case GURLS::ignore: 00443 case GURLS::compute: 00444 case GURLS::remove: 00445 std::cout << "not saved" << std::endl; 00446 opt.removeOpt(reg1); 00447 break; 00448 case GURLS::load: 00449 case GURLS::computeNsave: 00450 std::cout << " saving" << std::endl; 00451 save = true; 00452 break; 00453 } 00454 } 00455 00456 if(save) 00457 { 00458 std::cout << std::endl << "Saving opt in " << saveFile << std::endl; 00459 opt.save(saveFile); 00460 } 00461 00462 delete loadOpt; 00463 00464 // } 00465 // catch (gException& gex) 00466 // { 00467 // throw gex; 00468 // } 00469 00470 } 00471 00472 00473 } 00474 00475 #include "gurls.hpp" 00476 #include "calibratesgd.h" 00477 00478 #endif // _GURLS_GURLS_H_