GURLS++ 2.0.00
C++ Implementation of GURLS Matlab Toolbox
00001 #include "gurls++/gprwrapper.h" 00002 00003 #include "gurls++/gurls.h" 00004 #include "gurls++/predkerneltraintest.h" 00005 #include "gurls++/dual.h" 00006 #include "gurls++/exceptions.h" 00007 #include "gurls++/normzscore.h" 00008 #include "gurls++/normtestzscore.h" 00009 00010 namespace gurls 00011 { 00012 00013 template <typename T> 00014 GPRWrapper<T>::GPRWrapper(const std::string &name): KernelWrapper<T>(name), norm(NULL) { } 00015 00016 template <typename T> 00017 GPRWrapper<T>::~GPRWrapper() 00018 { 00019 if(norm != NULL) 00020 delete norm; 00021 } 00022 00023 template <typename T> 00024 void GPRWrapper<T>::train(const gMat2D<T> &X, const gMat2D<T> &y) 00025 { 00026 this->opt->removeOpt("split"); 00027 this->opt->removeOpt("optimizer"); 00028 00029 if(norm != NULL) 00030 { 00031 delete norm; 00032 norm = NULL; 00033 } 00034 00035 NormZScore<T> taskNorm; 00036 00037 norm = taskNorm.execute(X, y, *(this->opt)); 00038 00039 00040 00041 const unsigned long nlambda = static_cast<unsigned long>(this->opt->getOptAsNumber("nlambda")); 00042 const unsigned long nsigma = static_cast<unsigned long>(this->opt->getOptAsNumber("nsigma")); 00043 00044 00045 OptTaskSequence *seq = new OptTaskSequence(); 00046 GurlsOptionsList * process = new GurlsOptionsList("processes", false); 00047 OptProcess* process1 = new OptProcess(); 00048 process->addOpt("one", process1); 00049 this->opt->addOpt("seq", seq); 00050 this->opt->addOpt("processes", process); 00051 // this->opt->printAll(); 00052 00053 if(this->kType == KernelWrapper<T>::LINEAR) 00054 { 00055 if(nlambda > 1ul) 00056 { 00057 *seq << "split:ho" << "kernel:linear" << "paramsel:hogpregr"; 00058 *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave; 00059 } 00060 else if(nlambda == 1ul) 00061 { 00062 if(this->opt->hasOpt("paramsel.lambdas")) 00063 { 00064 *seq << "kernel:linear"; 00065 *process1 << GURLS::computeNsave; 00066 } 00067 else 00068 throw gException("Please set a valid value for the 
regularization parameter, calling setParam(value)"); 00069 } 00070 else 00071 throw gException("Please set a valid value for NParam, calling setNParam(value)"); 00072 } 00073 else if(this->kType == KernelWrapper<T>::RBF) 00074 { 00075 if(nlambda > 1ul) 00076 { 00077 if(nsigma > 1ul) 00078 { 00079 *seq << "split:ho" << "paramsel:siglamhogpregr" << "kernel:rbf"; 00080 *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave; 00081 } 00082 else if(nsigma == 1ul) 00083 { 00084 if(this->opt->hasOpt("paramsel.sigma")) 00085 { 00086 *seq << "split:ho" << "kernel:rbf" << "paramsel:hogpregr"; 00087 *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave; 00088 } 00089 else 00090 throw gException("Please set a valid value for the kernel parameter, calling setSigma(value)"); 00091 } 00092 else 00093 throw gException("Please set a valid value for NSigma, calling setNSigma(value)"); 00094 } 00095 else if(nlambda == 1ul) 00096 { 00097 if(nsigma == 1ul) 00098 { 00099 if(this->opt->hasOpt("paramsel.sigma") && this->opt->hasOpt("paramsel.lambdas")) 00100 { 00101 *seq << "kernel:rbf"; 00102 *process1 << GURLS::computeNsave; 00103 } 00104 else 00105 throw gException("Please set a valid value for kernel and regularization parameters, calling setParam(value) and setSigma(value)"); 00106 } 00107 else 00108 throw gException("Please set a valid value for NSigma, calling setNSigma(value)"); 00109 } 00110 else 00111 throw gException("Please set a valid value for NParam, calling setNParam(value)"); 00112 } 00113 00114 *seq << "optimizer:rlsgpregr"; 00115 *process1 << GURLS::computeNsave; 00116 00117 GURLS G; 00118 G.run(norm->getOptValue<OptMatrix<gMat2D<T> > >("X"), norm->getOptValue<OptMatrix<gMat2D<T> > >("Y"), *(this->opt), "one"); 00119 00120 } 00121 00122 template <typename T> 00123 gMat2D<T>* GPRWrapper<T>::eval(const gMat2D<T> &X) 00124 { 00125 gMat2D<T> vars; 00126 00127 return eval(X, vars); 00128 } 00129 00130 template <typename T> 
00131 gMat2D<T>* GPRWrapper<T>::eval(const gMat2D<T> &X, gMat2D<T> &vars) 00132 { 00133 gMat2D<T> empty; 00134 00135 NormTestZScore<T> normTask; 00136 PredGPRegr<T> predTask; 00137 00138 00139 GurlsOptionsList *normX = normTask.execute(X, empty, *(norm)); 00140 00141 gMat2D<T> &Xresc = normX->getOptValue<OptMatrix<gMat2D<T> > >("X"); 00142 00143 PredKernelTrainTest<T> predkTrainTest; 00144 this->opt->removeOpt("predkernel"); 00145 this->opt->addOpt("predkernel", predkTrainTest.execute(Xresc, empty, *(this->opt))); 00146 00147 GurlsOptionsList *pred = predTask.execute(Xresc, empty, *(this->opt)); 00148 00149 delete normX; 00150 00151 OptMatrix<gMat2D<T> >* pmeans = pred->getOptAs<OptMatrix<gMat2D<T> > >("means"); 00152 pmeans->detachValue(); 00153 00154 gMat2D<T> &predMeans = pmeans->getValue(); 00155 gMat2D<T> &predVars = pred->getOptValue<OptMatrix<gMat2D<T> > >("vars"); 00156 00157 const unsigned long n = predMeans.rows(); 00158 const unsigned long t = predMeans.cols(); 00159 00160 T* column = predMeans.getData(); 00161 const T* std_it = norm->getOptValue<OptMatrix<gMat2D<T> > >("stdY").getData(); 00162 const T* mean_it = norm->getOptValue<OptMatrix<gMat2D<T> > >("meanY").getData(); 00163 const T* pvars_it = predVars.getData(); 00164 00165 vars.resize(n, t); 00166 00167 T* vars_it = vars.getData(); 00168 00169 for(unsigned long i=0; i<t; ++i, column+=n, ++std_it, ++mean_it, vars_it+=n) 00170 { 00171 scal(n, *std_it, column, 1); 00172 axpy(n, (T)1.0, mean_it, 0, column, 1); 00173 00174 copy(vars_it, pvars_it, n); 00175 scal(n, (*std_it)*(*std_it), vars_it, 1); 00176 } 00177 00178 delete pred; 00179 00180 return &predMeans; 00181 } 00182 00183 }