GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
wrapper.hpp
00001 
00002 #include "gurls++/gurls.h"
00003 #include "gurls++/optlist.h"
00004 #include "gurls++/wrapper.h"
00005 
00006 namespace gurls
00007 {
00008 
00009 template <typename T>
00010 GurlsWrapper<T>::GurlsWrapper(const std::string& name):opt(NULL), name(name)
00011 {
00012     opt = new GurlsOptionsList(name, true);
00013 
00014     setSplitProportion(0.2);
00015     setNparams(20);
00016     setProblemType(CLASSIFICATION);
00017 }
00018 
00019 template <typename T>
00020 GurlsWrapper<T>::~GurlsWrapper()
00021 {
00022     delete opt;
00023 }
00024 
00025 template <typename T>
00026 T GurlsWrapper<T>::eval(const gVec<T> &X, unsigned long *index)
00027 {
00028     if(!trainedModel())
00029         throw gException("Error, Train Model First");
00030 
00031     gMat2D<T>X_mat(1, X.getSize());
00032     copy(X_mat.getData(), X.getData(), X.getSize());
00033 
00034     gMat2D<T>* pred_mat = eval(X_mat);
00035 
00036     const T* pred = pred_mat->getData();
00037     const unsigned long size = pred_mat->getSize();
00038 
00039     const T* max = std::max_element(pred, pred+size);
00040     T ret = *max;
00041     if(index != NULL)
00042         *index = max-pred;
00043 
00044     delete pred_mat;
00045     return ret;
00046 }
00047 
00048 template <typename T>
00049 const GurlsOptionsList &GurlsWrapper<T>::getOpt() const
00050 {
00051     return *opt;
00052 }
00053 
00054 template <typename T>
00055 void GurlsWrapper<T>::saveModel(const std::string &fileName)
00056 {
00057     opt->save(fileName);
00058 }
00059 
00060 template <typename T>
00061 void GurlsWrapper<T>::loadModel(const std::string &fileName)
00062 {
00063     opt->load(fileName);
00064 }
00065 
00066 template <typename T>
00067 void GurlsWrapper<T>::setNparams(unsigned long value)
00068 {
00069     opt->getOptValue<OptNumber>("nlambda") = value;
00070 
00071     if(opt->hasOpt("paramsel.lambdas") && value > 1.0)
00072     {
00073         std::cout << "Warning: ignoring previous values of the regularization parameter" << std::endl;
00074         opt->getOptAs<GurlsOptionsList>("paramsel")->removeOpt("lambdas");
00075     }
00076 }
00077 
00078 template <typename T>
00079 void GurlsWrapper<T>::setParam(double value)
00080 {
00081     if(!opt->hasOpt("paramsel"))
00082         opt->addOpt("paramsel", new GurlsOptionsList("paramsel"));
00083 
00084     if(opt->hasOpt("paramsel.lambdas"))
00085         opt->getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas").getData()[0] = (T)value;
00086     else
00087     {
00088         gMat2D<T> * lambdas = new gMat2D<T>(1,1);
00089         lambdas->getData()[0] = (T)value;
00090         opt->getOptAs<GurlsOptionsList>("paramsel")->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*lambdas));
00091     }
00092 
00093     setNparams(1);
00094 }
00095 
00096 template <typename T>
00097 void GurlsWrapper<T>::setSplitProportion(double value)
00098 {
00099     opt->getOptValue<OptNumber>("hoproportion") = value;
00100 }
00101 
00102 template <typename T>
00103 void GurlsWrapper<T>::setProblemType(typename GurlsWrapper::ProblemType value)
00104 {
00105     probType = value;
00106 
00107     opt->getOptValue<OptString>("hoperf") = (value == CLASSIFICATION)? "macroavg": "rmse";
00108 }
00109 
00110 template <typename T>
00111 bool GurlsWrapper<T>::trainedModel()
00112 {
00113     return opt->hasOpt("optimizer");
00114 }
00115 
00116 
00117 template <typename T>
00118 KernelWrapper<T>::KernelWrapper(const std::string &name): GurlsWrapper<T>(name), kType(RBF)
00119 {
00120 //    GurlsOptionsList *kernel = new GurlsOptionsList("kernel");
00121 //    this->opt->addOpt("kernel", kernel);
00122 //    kernel->addOpt("type", "rbf");
00123 
00124     GurlsOptionsList *paramsel = new GurlsOptionsList("paramsel");
00125     this->opt->addOpt("paramsel", paramsel);
00126 }
00127 
00128 template <typename T>
00129 void KernelWrapper<T>::setKernelType(typename KernelWrapper::KernelType value)
00130 {
00131     kType = value;
00132 
00133 //    std::string &type = this->opt->template getOptValue<OptString>("kernel.type");
00134 
00135 //    switch(value)
00136 //    {
00137 //    case RBF:
00138 //        type = std::string("rbf");
00139 //    case LINEAR:
00140 //        type = std::string("linear");
00141 //    case CHISQUARED:
00142 //        type = std::string("chisquared");
00143 //    }
00144 }
00145 
00146 template <typename T>
00147 void KernelWrapper<T>::setSigma(double value)
00148 {
00149     if(this->opt->hasOpt("paramsel.sigma"))
00150         this->opt->template getOptValue<OptNumber>("paramsel.sigma") = value;
00151     else
00152     {
00153         GurlsOptionsList* paramsel = this->opt->template getOptAs<GurlsOptionsList>("paramsel");
00154         paramsel->addOpt("sigma", new OptNumber(value));
00155     }
00156 
00157     setNSigma(1);
00158 }
00159 
00160 template <typename T>
00161 void KernelWrapper<T>::setNSigma(unsigned long value)
00162 {
00163     this->opt->template getOptValue<OptNumber>("nsigma") = value;
00164 
00165     if(this->opt->hasOpt("paramsel.sigma") && value > 1.0)
00166     {
00167         std::cout << "Warning: ignoring previous values of the kernel parameter" << std::endl;
00168         this->opt->template getOptAs<GurlsOptionsList>("paramsel")->removeOpt("sigma");
00169     }
00170 }
00171 
00172 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends