// GURLS++ 2.0.00
// C++ implementation of the GURLS Matlab Toolbox
00001 #include "gurls++/kernelrlswrapper.h" 00002 00003 #include "gurls++/gurls.h" 00004 #include "gurls++/predkerneltraintest.h" 00005 #include "gurls++/primal.h" 00006 #include "gurls++/dual.h" 00007 #include "gurls++/exceptions.h" 00008 00009 namespace gurls 00010 { 00011 00012 template <typename T> 00013 KernelRLSWrapper<T>::KernelRLSWrapper(const std::string &name): KernelWrapper<T>(name) { } 00014 00015 template <typename T> 00016 void KernelRLSWrapper<T>::train(const gMat2D<T> &X, const gMat2D<T> &y) 00017 { 00018 this->opt->removeOpt("split"); 00019 this->opt->removeOpt("optimizer"); 00020 00021 00022 const unsigned long nlambda = static_cast<unsigned long>(this->opt->getOptAsNumber("nlambda")); 00023 const unsigned long nsigma = static_cast<unsigned long>(this->opt->getOptAsNumber("nsigma")); 00024 00025 OptTaskSequence *seq = new OptTaskSequence(); 00026 GurlsOptionsList * process = new GurlsOptionsList("processes", false); 00027 OptProcess* process1 = new OptProcess(); 00028 process->addOpt("one", process1); 00029 this->opt->addOpt("seq", seq); 00030 this->opt->addOpt("processes", process); 00031 00032 if(this->kType == KernelWrapper<T>::LINEAR) 00033 { 00034 if(nlambda > 1ul) 00035 { 00036 *seq << "split:ho" << "kernel:linear" << "paramsel:hodual"; 00037 *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave; 00038 } 00039 else if(nlambda == 1ul) 00040 { 00041 if(this->opt->hasOpt("paramsel.lambdas")) 00042 { 00043 *seq << "kernel:linear"; 00044 *process1 << GURLS::computeNsave; 00045 } 00046 else 00047 throw gException("Please set a valid value for the regularization parameter, calling setParam(value)"); 00048 } 00049 else 00050 throw gException("Please set a valid value for NParam, calling setNParam(value)"); 00051 00052 } 00053 else if(this->kType == KernelWrapper<T>::RBF) 00054 { 00055 if(nlambda > 1ul) 00056 { 00057 if(nsigma > 1ul) 00058 { 00059 *seq << "split:ho" << "paramsel:siglamho" << "kernel:rbf"; 00060 *process1 << 
GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave; 00061 } 00062 else if(nsigma == 1ul) 00063 { 00064 if(this->opt->hasOpt("paramsel.sigma")) 00065 { 00066 *seq << "split:ho" << "kernel:rbf" << "paramsel:hodual"; 00067 *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave; 00068 } 00069 else 00070 throw gException("Please set a valid value for the kernel parameter, calling setSigma(value)"); 00071 } 00072 else 00073 throw gException("Please set a valid value for NSigma, calling setNSigma(value)"); 00074 } 00075 else if(nlambda == 1ul) 00076 { 00077 if(nsigma == 1ul) 00078 { 00079 if(this->opt->hasOpt("paramsel.sigma") && this->opt->hasOpt("paramsel.lambdas")) 00080 { 00081 *seq << "split:ho" << "kernel:rbf"; 00082 *process1 << GURLS::computeNsave << GURLS::computeNsave; 00083 } 00084 else 00085 throw gException("Please set a valid value for kernel and regularization parameters, calling setParam(value) and setSigma(value)"); 00086 } 00087 else 00088 throw gException("Please set a valid value for NSigma, calling setNSigma(value)"); 00089 } 00090 else 00091 throw gException("Please set a valid value for NParam, calling setNParam(value)"); 00092 } 00093 00094 *seq << "optimizer:rlsdual"; 00095 *process1 << GURLS::computeNsave; 00096 00097 GURLS G; 00098 G.run(X, y, *(this->opt), "one"); 00099 00100 } 00101 00102 template <typename T> 00103 gMat2D<T>* KernelRLSWrapper<T>::eval(const gMat2D<T> &X) 00104 { 00105 Prediction<T> *pred; 00106 PredKernelTrainTest<T> predkTrainTest; 00107 00108 gMat2D<T> empty; 00109 00110 switch (this->kType) 00111 { 00112 case KernelWrapper<T>::LINEAR: 00113 pred = new PredPrimal<T>(); 00114 break; 00115 case KernelWrapper<T>::RBF: 00116 pred = new PredDual<T>(); 00117 this->opt->removeOpt("predkernel"); 00118 this->opt->addOpt("predkernel", predkTrainTest.execute(X, empty, *(this->opt))); 00119 } 00120 00121 OptMatrix<gMat2D<T> >* result = OptMatrix<gMat2D<T> >::dynacast(pred->execute(X, empty, 
*(this->opt))); 00122 result->detachValue(); 00123 delete pred; 00124 00125 gMat2D<T>* ret = &(result->getValue()); 00126 00127 delete result; 00128 return ret; 00129 } 00130 00131 }