GURLS++
2.0.00
C++ implementation of the GURLS Matlab toolbox
00001 00002 #include "gurls++/gurls.h" 00003 #include "gurls++/optlist.h" 00004 #include "gurls++/wrapper.h" 00005 00006 namespace gurls 00007 { 00008 00009 template <typename T> 00010 GurlsWrapper<T>::GurlsWrapper(const std::string& name):opt(NULL), name(name) 00011 { 00012 opt = new GurlsOptionsList(name, true); 00013 00014 setSplitProportion(0.2); 00015 setNparams(20); 00016 setProblemType(CLASSIFICATION); 00017 } 00018 00019 template <typename T> 00020 GurlsWrapper<T>::~GurlsWrapper() 00021 { 00022 delete opt; 00023 } 00024 00025 template <typename T> 00026 T GurlsWrapper<T>::eval(const gVec<T> &X, unsigned long *index) 00027 { 00028 if(!trainedModel()) 00029 throw gException("Error, Train Model First"); 00030 00031 gMat2D<T>X_mat(1, X.getSize()); 00032 copy(X_mat.getData(), X.getData(), X.getSize()); 00033 00034 gMat2D<T>* pred_mat = eval(X_mat); 00035 00036 const T* pred = pred_mat->getData(); 00037 const unsigned long size = pred_mat->getSize(); 00038 00039 const T* max = std::max_element(pred, pred+size); 00040 T ret = *max; 00041 if(index != NULL) 00042 *index = max-pred; 00043 00044 delete pred_mat; 00045 return ret; 00046 } 00047 00048 template <typename T> 00049 const GurlsOptionsList &GurlsWrapper<T>::getOpt() const 00050 { 00051 return *opt; 00052 } 00053 00054 template <typename T> 00055 void GurlsWrapper<T>::saveModel(const std::string &fileName) 00056 { 00057 opt->save(fileName); 00058 } 00059 00060 template <typename T> 00061 void GurlsWrapper<T>::loadModel(const std::string &fileName) 00062 { 00063 opt->load(fileName); 00064 } 00065 00066 template <typename T> 00067 void GurlsWrapper<T>::setNparams(unsigned long value) 00068 { 00069 opt->getOptValue<OptNumber>("nlambda") = value; 00070 00071 if(opt->hasOpt("paramsel.lambdas") && value > 1.0) 00072 { 00073 std::cout << "Warning: ignoring previous values of the regularization parameter" << std::endl; 00074 opt->getOptAs<GurlsOptionsList>("paramsel")->removeOpt("lambdas"); 00075 } 00076 } 00077 
00078 template <typename T> 00079 void GurlsWrapper<T>::setParam(double value) 00080 { 00081 if(!opt->hasOpt("paramsel")) 00082 opt->addOpt("paramsel", new GurlsOptionsList("paramsel")); 00083 00084 if(opt->hasOpt("paramsel.lambdas")) 00085 opt->getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas").getData()[0] = (T)value; 00086 else 00087 { 00088 gMat2D<T> * lambdas = new gMat2D<T>(1,1); 00089 lambdas->getData()[0] = (T)value; 00090 opt->getOptAs<GurlsOptionsList>("paramsel")->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*lambdas)); 00091 } 00092 00093 setNparams(1); 00094 } 00095 00096 template <typename T> 00097 void GurlsWrapper<T>::setSplitProportion(double value) 00098 { 00099 opt->getOptValue<OptNumber>("hoproportion") = value; 00100 } 00101 00102 template <typename T> 00103 void GurlsWrapper<T>::setProblemType(typename GurlsWrapper::ProblemType value) 00104 { 00105 probType = value; 00106 00107 opt->getOptValue<OptString>("hoperf") = (value == CLASSIFICATION)? "macroavg": "rmse"; 00108 } 00109 00110 template <typename T> 00111 bool GurlsWrapper<T>::trainedModel() 00112 { 00113 return opt->hasOpt("optimizer"); 00114 } 00115 00116 00117 template <typename T> 00118 KernelWrapper<T>::KernelWrapper(const std::string &name): GurlsWrapper<T>(name), kType(RBF) 00119 { 00120 // GurlsOptionsList *kernel = new GurlsOptionsList("kernel"); 00121 // this->opt->addOpt("kernel", kernel); 00122 // kernel->addOpt("type", "rbf"); 00123 00124 GurlsOptionsList *paramsel = new GurlsOptionsList("paramsel"); 00125 this->opt->addOpt("paramsel", paramsel); 00126 } 00127 00128 template <typename T> 00129 void KernelWrapper<T>::setKernelType(typename KernelWrapper::KernelType value) 00130 { 00131 kType = value; 00132 00133 // std::string &type = this->opt->template getOptValue<OptString>("kernel.type"); 00134 00135 // switch(value) 00136 // { 00137 // case RBF: 00138 // type = std::string("rbf"); 00139 // case LINEAR: 00140 // type = std::string("linear"); 00141 // case CHISQUARED: 
00142 // type = std::string("chisquared"); 00143 // } 00144 } 00145 00146 template <typename T> 00147 void KernelWrapper<T>::setSigma(double value) 00148 { 00149 if(this->opt->hasOpt("paramsel.sigma")) 00150 this->opt->template getOptValue<OptNumber>("paramsel.sigma") = value; 00151 else 00152 { 00153 GurlsOptionsList* paramsel = this->opt->template getOptAs<GurlsOptionsList>("paramsel"); 00154 paramsel->addOpt("sigma", new OptNumber(value)); 00155 } 00156 00157 setNSigma(1); 00158 } 00159 00160 template <typename T> 00161 void KernelWrapper<T>::setNSigma(unsigned long value) 00162 { 00163 this->opt->template getOptValue<OptNumber>("nsigma") = value; 00164 00165 if(this->opt->hasOpt("paramsel.sigma") && value > 1.0) 00166 { 00167 std::cout << "Warning: ignoring previous values of the kernel parameter" << std::endl; 00168 this->opt->template getOptAs<GurlsOptionsList>("paramsel")->removeOpt("sigma"); 00169 } 00170 } 00171 00172 }