GURLS++
2.0.00
C++ implementation of the GURLS MATLAB toolbox
00001 #include "gurls++/randfeatswrapper.h" 00002 00003 #include "gurls++/splitho.h" 00004 #include "gurls++/hoprimal.h" 00005 #include "gurls++/rlsprimal.h" 00006 #include "gurls++/primal.h" 00007 #include "gurls++/options.h" 00008 #include "gurls++/utils.h" 00009 00010 00011 namespace gurls 00012 { 00013 00014 template <typename T> 00015 RandomFeaturesWrapper<T>::RandomFeaturesWrapper(const std::string &name): RLSWrapper<T>(name), W(NULL) { } 00016 00017 template <typename T> 00018 RandomFeaturesWrapper<T>::~RandomFeaturesWrapper() 00019 { 00020 if(W != NULL) 00021 delete W; 00022 } 00023 00024 template <typename T> 00025 void RandomFeaturesWrapper<T>::train(const gMat2D<T> &X, const gMat2D<T> &y) 00026 { 00027 // D = opt.n_randfeats; 00028 const unsigned long D = static_cast<unsigned long>(this->opt->getOptAsNumber("randfeats.D")); 00029 00030 // [n,d] = size(Xtr); 00031 const unsigned long d = X.cols(); 00032 00033 00034 // W = sqrt(2)*randn(d,D); 00035 00036 if(W != NULL) 00037 delete W; 00038 00039 W = rp_projections<T>(d, D); 00040 00041 00042 // V = X*W; 00043 // Xtr = [cos(V) sin(V)]; 00044 gMat2D<T> *Xtr = rp_apply_real(X, *W); 00045 00046 RLSWrapper<T>::train(*Xtr, y); 00047 00048 delete Xtr; 00049 } 00050 00051 template<typename T> 00052 gMat2D<T>* RandomFeaturesWrapper<T>::eval(const gMat2D<T> &X) 00053 { 00054 if(W == NULL) 00055 throw gException("Error, Train Model First"); 00056 00057 00058 // V = X*W; 00059 // Xte = [cos(V) sin(V)]; 00060 gMat2D<T> *Xte = rp_apply_real(X, *W); 00061 00062 gMat2D<T>* pred = RLSWrapper<T>::eval(*Xte); 00063 00064 delete Xte; 00065 00066 return pred; 00067 } 00068 00069 template<typename T> 00070 void RandomFeaturesWrapper<T>::setNRandFeats(unsigned long value) 00071 { 00072 this->opt->template getOptValue<OptNumber>("randfeats.D") = value; 00073 } 00074 00075 }