GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
kernelrlswrapper.hpp
00001 #include "gurls++/kernelrlswrapper.h"
00002 
00003 #include "gurls++/gurls.h"
00004 #include "gurls++/predkerneltraintest.h"
00005 #include "gurls++/primal.h"
00006 #include "gurls++/dual.h"
00007 #include "gurls++/exceptions.h"
00008 
00009 namespace gurls
00010 {
00011 
00012 template <typename T>
00013 KernelRLSWrapper<T>::KernelRLSWrapper(const std::string &name): KernelWrapper<T>(name) { }
00014 
00015 template <typename T>
00016 void KernelRLSWrapper<T>::train(const gMat2D<T> &X, const gMat2D<T> &y)
00017 {
00018     this->opt->removeOpt("split");
00019     this->opt->removeOpt("optimizer");
00020 
00021 
00022     const unsigned long nlambda = static_cast<unsigned long>(this->opt->getOptAsNumber("nlambda"));
00023     const unsigned long nsigma = static_cast<unsigned long>(this->opt->getOptAsNumber("nsigma"));
00024 
00025     OptTaskSequence *seq = new OptTaskSequence();
00026     GurlsOptionsList * process = new GurlsOptionsList("processes", false);
00027     OptProcess* process1 = new OptProcess();
00028     process->addOpt("one", process1);
00029     this->opt->addOpt("seq", seq);
00030     this->opt->addOpt("processes", process);
00031 
00032     if(this->kType == KernelWrapper<T>::LINEAR)
00033     {
00034         if(nlambda > 1ul)
00035         {
00036             *seq << "split:ho" << "kernel:linear" << "paramsel:hodual";
00037             *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave;
00038         }
00039         else if(nlambda == 1ul)
00040         {
00041             if(this->opt->hasOpt("paramsel.lambdas"))
00042             {
00043                 *seq << "kernel:linear";
00044                 *process1 << GURLS::computeNsave;
00045              }
00046             else
00047                 throw gException("Please set a valid value for the regularization parameter, calling setParam(value)");
00048         }
00049         else
00050             throw gException("Please set a valid value for NParam, calling setNParam(value)");
00051 
00052     }
00053     else if(this->kType == KernelWrapper<T>::RBF)
00054     {
00055         if(nlambda > 1ul)
00056         {
00057             if(nsigma > 1ul)
00058             {
00059                 *seq << "split:ho" << "paramsel:siglamho" << "kernel:rbf";
00060                 *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave;
00061             }
00062             else if(nsigma == 1ul)
00063             {
00064                 if(this->opt->hasOpt("paramsel.sigma"))
00065                 {
00066                     *seq << "split:ho" << "kernel:rbf" << "paramsel:hodual";
00067                     *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave;
00068                 }
00069                 else
00070                     throw gException("Please set a valid value for the kernel parameter, calling setSigma(value)");
00071             }
00072             else
00073                 throw gException("Please set a valid value for NSigma, calling setNSigma(value)");
00074         }
00075         else if(nlambda == 1ul)
00076         {
00077             if(nsigma == 1ul)
00078             {
00079                 if(this->opt->hasOpt("paramsel.sigma") && this->opt->hasOpt("paramsel.lambdas"))
00080                 {
00081                     *seq << "split:ho" << "kernel:rbf";
00082                     *process1 << GURLS::computeNsave << GURLS::computeNsave;
00083                 }
00084                 else
00085                     throw gException("Please set a valid value for kernel and regularization parameters, calling setParam(value) and setSigma(value)");
00086             }
00087             else
00088                 throw gException("Please set a valid value for NSigma, calling setNSigma(value)");
00089         }
00090         else
00091             throw gException("Please set a valid value for NParam, calling setNParam(value)");
00092     }
00093 
00094     *seq << "optimizer:rlsdual";
00095     *process1 << GURLS::computeNsave;
00096 
00097     GURLS G;
00098     G.run(X, y, *(this->opt), "one");
00099 
00100 }
00101 
00102 template <typename T>
00103 gMat2D<T>* KernelRLSWrapper<T>::eval(const gMat2D<T> &X)
00104 {
00105     Prediction<T> *pred;
00106     PredKernelTrainTest<T> predkTrainTest;
00107 
00108     gMat2D<T> empty;
00109 
00110     switch (this->kType)
00111     {
00112     case KernelWrapper<T>::LINEAR:
00113         pred = new PredPrimal<T>();
00114         break;
00115     case KernelWrapper<T>::RBF:
00116         pred = new PredDual<T>();
00117         this->opt->removeOpt("predkernel");
00118         this->opt->addOpt("predkernel", predkTrainTest.execute(X, empty, *(this->opt)));
00119     }
00120 
00121     OptMatrix<gMat2D<T> >* result = OptMatrix<gMat2D<T> >::dynacast(pred->execute(X, empty, *(this->opt)));
00122     result->detachValue();
00123     delete pred;
00124 
00125     gMat2D<T>* ret = &(result->getValue());
00126 
00127     delete result;
00128     return ret;
00129 }
00130 
00131 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends