GURLS++ 2.0.00
C++ Implementation of GURLS Matlab Toolbox
gprwrapper.hpp
#include "gurls++/gprwrapper.h"

#include "gurls++/gurls.h"
#include "gurls++/predkerneltraintest.h"
#include "gurls++/dual.h"
#include "gurls++/exceptions.h"
#include "gurls++/normzscore.h"
#include "gurls++/normtestzscore.h"

namespace gurls
{

template <typename T>
GPRWrapper<T>::GPRWrapper(const std::string &name): KernelWrapper<T>(name), norm(NULL) { }

template <typename T>
GPRWrapper<T>::~GPRWrapper()
{
    if(norm != NULL)
        delete norm;
}

template <typename T>
void GPRWrapper<T>::train(const gMat2D<T> &X, const gMat2D<T> &y)
{
    // Discard options possibly left over from a previous training run
    this->opt->removeOpt("split");
    this->opt->removeOpt("optimizer");

    if(norm != NULL)
    {
        delete norm;
        norm = NULL;
    }

    // Z-score normalization of the training data; the returned options list
    // holds the rescaled X and Y together with meanY and stdY, which are
    // needed again at prediction time
    NormZScore<T> taskNorm;

    norm = taskNorm.execute(X, y, *(this->opt));

    const unsigned long nlambda = static_cast<unsigned long>(this->opt->getOptAsNumber("nlambda"));
    const unsigned long nsigma = static_cast<unsigned long>(this->opt->getOptAsNumber("nsigma"));

    // Build the GURLS pipeline: the task sequence and the corresponding process actions
    OptTaskSequence *seq = new OptTaskSequence();
    GurlsOptionsList *process = new GurlsOptionsList("processes", false);
    OptProcess *process1 = new OptProcess();
    process->addOpt("one", process1);
    this->opt->addOpt("seq", seq);
    this->opt->addOpt("processes", process);
//    this->opt->printAll();

    if(this->kType == KernelWrapper<T>::LINEAR)
    {
        if(nlambda > 1ul)
        {
            // Select the regularization parameter via hold-out validation
            *seq << "split:ho" << "kernel:linear" << "paramsel:hogpregr";
            *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave;
        }
        else if(nlambda == 1ul)
        {
            if(this->opt->hasOpt("paramsel.lambdas"))
            {
                // Regularization parameter supplied by the user
                *seq << "kernel:linear";
                *process1 << GURLS::computeNsave;
            }
            else
                throw gException("Please set a valid value for the regularization parameter, calling setParam(value)");
        }
        else
            throw gException("Please set a valid value for NParam, calling setNParam(value)");
    }
    else if(this->kType == KernelWrapper<T>::RBF)
    {
        if(nlambda > 1ul)
        {
            if(nsigma > 1ul)
            {
                // Jointly select the kernel width and the regularization parameter via hold-out validation
                *seq << "split:ho" << "paramsel:siglamhogpregr" << "kernel:rbf";
                *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave;
            }
            else if(nsigma == 1ul)
            {
                if(this->opt->hasOpt("paramsel.sigma"))
                {
                    // Kernel width fixed by the user, select only the regularization parameter
                    *seq << "split:ho" << "kernel:rbf" << "paramsel:hogpregr";
                    *process1 << GURLS::computeNsave << GURLS::computeNsave << GURLS::computeNsave;
                }
                else
                    throw gException("Please set a valid value for the kernel parameter, calling setSigma(value)");
            }
            else
                throw gException("Please set a valid value for NSigma, calling setNSigma(value)");
        }
        else if(nlambda == 1ul)
        {
            if(nsigma == 1ul)
            {
                if(this->opt->hasOpt("paramsel.sigma") && this->opt->hasOpt("paramsel.lambdas"))
                {
                    // Both parameters fixed by the user: only the kernel needs to be computed
                    *seq << "kernel:rbf";
                    *process1 << GURLS::computeNsave;
                }
                else
                    throw gException("Please set a valid value for kernel and regularization parameters, calling setParam(value) and setSigma(value)");
            }
            else
                throw gException("Please set a valid value for NSigma, calling setNSigma(value)");
        }
        else
            throw gException("Please set a valid value for NParam, calling setNParam(value)");
    }

    // Final task: train the Gaussian process regressor on the normalized data
    *seq << "optimizer:rlsgpregr";
    *process1 << GURLS::computeNsave;

    GURLS G;
    G.run(norm->getOptValue<OptMatrix<gMat2D<T> > >("X"), norm->getOptValue<OptMatrix<gMat2D<T> > >("Y"), *(this->opt), "one");
}

template <typename T>
gMat2D<T>* GPRWrapper<T>::eval(const gMat2D<T> &X)
{
    // Predictive variances are computed but discarded
    gMat2D<T> vars;

    return eval(X, vars);
}

template <typename T>
gMat2D<T>* GPRWrapper<T>::eval(const gMat2D<T> &X, gMat2D<T> &vars)
{
    gMat2D<T> empty;

    NormTestZScore<T> normTask;
    PredGPRegr<T> predTask;

    // Normalize the test data using the statistics stored at training time
    GurlsOptionsList *normX = normTask.execute(X, empty, *(norm));

    gMat2D<T> &Xresc = normX->getOptValue<OptMatrix<gMat2D<T> > >("X");

    // Kernel between test and training points
    PredKernelTrainTest<T> predkTrainTest;
    this->opt->removeOpt("predkernel");
    this->opt->addOpt("predkernel", predkTrainTest.execute(Xresc, empty, *(this->opt)));

    GurlsOptionsList *pred = predTask.execute(Xresc, empty, *(this->opt));

    delete normX;

    // Detach the predicted means so they survive the deletion of pred;
    // ownership of the returned matrix passes to the caller
    OptMatrix<gMat2D<T> >* pmeans = pred->getOptAs<OptMatrix<gMat2D<T> > >("means");
    pmeans->detachValue();

    gMat2D<T> &predMeans = pmeans->getValue();
    gMat2D<T> &predVars = pred->getOptValue<OptMatrix<gMat2D<T> > >("vars");

    const unsigned long n = predMeans.rows();
    const unsigned long t = predMeans.cols();

    T* column = predMeans.getData();
    const T* std_it = norm->getOptValue<OptMatrix<gMat2D<T> > >("stdY").getData();
    const T* mean_it = norm->getOptValue<OptMatrix<gMat2D<T> > >("meanY").getData();
    const T* pvars_it = predVars.getData();

    vars.resize(n, t);

    T* vars_it = vars.getData();

    // Undo the z-score normalization column by column: each column of means is
    // rescaled by the corresponding stdY entry and shifted by meanY, while the
    // predicted variance vector is copied into each output column and rescaled by stdY^2
    for(unsigned long i=0; i<t; ++i, column+=n, ++std_it, ++mean_it, vars_it+=n)
    {
        scal(n, *std_it, column, 1);
        axpy(n, (T)1.0, mean_it, 0, column, 1);

        copy(vars_it, pvars_it, n);
        scal(n, (*std_it)*(*std_it), vars_it, 1);
    }

    delete pred;

    return &predMeans;
}

}
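
A minimal usage sketch of GPRWrapper follows; it is not part of gprwrapper.hpp. The train/eval calls and the delete of the returned means follow the code above. setKernelType is assumed to come from the KernelWrapper base class, and the setNParam/setNSigma/setSigma names are taken from the exception messages thrown in train(), so their exact spelling may differ in a given GURLS++ version; the data-loading step is omitted.

// Hypothetical usage sketch (assumptions flagged in the comments)
#include "gurls++/gprwrapper.h"

using namespace gurls;

int main()
{
    gMat2D<double> Xtr(100, 4), ytr(100, 1);        // training inputs and targets
    gMat2D<double> Xte(20, 4);                      // test inputs
    // ... fill Xtr, ytr and Xte with data ...

    GPRWrapper<double> gpr("gpr");
    gpr.setKernelType(KernelWrapper<double>::RBF);  // assumed setter from the KernelWrapper base
    gpr.setNParam(20);                              // setters named in the exceptions thrown by train()
    gpr.setNSigma(20);

    gpr.train(Xtr, ytr);                            // z-score normalization + GURLS pipeline

    gMat2D<double> vars;
    gMat2D<double> *means = gpr.eval(Xte, vars);    // rescaled predictive means and variances

    delete means;                                   // eval() detaches the means, so the caller owns them
    return 0;
}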