GURLS++ 2.0.00 — C++ implementation of the GURLS Matlab toolbox
PerfMacroAvg is the subclass of Performance that evaluates prediction accuracy class by class (per-class accuracies, i.e. a macro-average).
#include <macroavg.h>
Public Member Functions | |
GurlsOptionsList * | execute (const gMat2D< T > &X, const gMat2D< T > &Y, const GurlsOptionsList &opt) throw (gException) |
Evaluates the average accuracy per class. | |
Static Public Member Functions | |
static Performance< T > * | factory (const std::string &id) throw (BadPerformanceCreation) |
Factory function returning a pointer to the newly created object. | |
Protected Member Functions | |
void | macroavg (const unsigned long *trueY, const unsigned long *predY, const int length, int totClasses, T *&perClass, T &macroAverage, unsigned long &perClass_length) |
Auxiliary function called by execute method. |
Definition at line 62 of file macroavg.h.
GurlsOptionsList * gurls::PerfMacroAvg< T >::execute(const gMat2D< T > & X, const gMat2D< T > & Y, const GurlsOptionsList & opt) throw (gException) [virtual]
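X | input data matrix (not read by this implementation) |
Y | labels matrix (one column per class; a single column means binary labels compared against sign(pred)) |
opt | options with the following fields:
- pred: matrix of predicted labels (required)
- perf: an existing performance options list (optional); it is copied, and its acc and forho fields are replaced |
Returns a newly allocated "perf" options list whose acc and forho fields hold the per-class accuracies (a 1 x number-of-columns matrix). |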
Implements gurls::Performance< T >.
Definition at line 84 of file macroavg.h.
{
    // Computes per-class prediction accuracy of opt.pred against the true
    // labels Y, and stores it in the "acc" and "forho" fields of a "perf"
    // options list. The Matlab reference implementation is quoted in the
    // comments next to the corresponding C++ statements.
    //
    // X is not read here; it is part of the Performance::execute interface.
    // The returned list is newly allocated and owned by the caller.
    // Throws gException(Exception_Inconsistent_Size) if macroavg reports more
    // classes than Y has columns. Everything allocated up to that point is
    // released first.

    const unsigned long rows = Y.rows();
    const unsigned long cols = Y.cols();

    // if isfield (opt,'perf')
    //     p = opt.perf; % lets not overwrite existing performance measures.
    //                   % unless they have the same name
    // end
    GurlsOptionsList* perf = NULL;
    if (opt.hasOpt("perf"))
    {
        // Take a private copy of opt.perf through a temporary list.
        // removeOpt(..., false) is assumed to detach without deleting
        // (perf is still used after tmp_opt is destroyed) -- TODO confirm.
        GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp");
        tmp_opt->copyOpt("perf", opt);
        perf = GurlsOptionsList::dynacast(tmp_opt->getOpt("perf"));
        tmp_opt->removeOpt("perf", false);
        delete tmp_opt;

        // These fields are recomputed below.
        perf->removeOpt("acc");
        perf->removeOpt("forho");
    }
    else
        perf = new GurlsOptionsList("perf");

    // One accuracy value per column of Y.
    gMat2D<T>* acc_mat = new gMat2D<T>(1, cols);
    T* acc = acc_mat->getData();

    // y_true = y;
    const T* y_true = Y.getData();
    // y_pred = opt.pred;
    const gMat2D<T>& y_pred = opt.getOptValue<OptMatrix<gMat2D<T> > >("pred");

    if (cols == 1)
    {
        // Binary case:
        // predlab = sign(y_pred);
        // p.acc = mean(predlab == y);
        T* predLab = sign(y_pred.getData(), y_pred.getSize());
        T* tmp = compare<T>(predLab, Y.getData(), rows, &eq);
        mean(tmp, acc, rows, 1, 1);

        delete[] tmp;
        delete[] predLab;
    }
    else
    {
        // Multiclass case (assumes single-label prediction):
        // [dummy, predlab] = max(y_pred,[],2);
        // [dummy, truelab] = max(y_true,[],2);
        T* work = new T[std::max(Y.getSize(), y_pred.getSize())];

        unsigned long* predLab = new unsigned long[rows];
        indicesOfMax(y_pred.getData(), rows, y_pred.cols(), predLab, work, 2);

        unsigned long* trueLab = new unsigned long[rows];
        indicesOfMax(y_true, rows, cols, trueLab, work, 2);

        delete[] work;

        // [MacroAvg, PerClass] = macroavg(truelab, predlab);
        T macroAverage;
        T* perClass;
        unsigned long perClass_length;
        macroavg(trueLab, predLab, rows, cols, perClass, macroAverage, perClass_length);

        // The label buffers are no longer needed.
        delete[] predLab;
        delete[] trueLab;

        if (perClass_length > cols)
        {
            // Release everything this function owns before reporting the
            // error. Throwing here directly would leak perClass, acc_mat
            // and perf.
            delete[] perClass;
            delete acc_mat;
            delete perf;
            throw gException(Exception_Inconsistent_Size);
        }

        // p.acc(t) = PerClass(t) for t = 1:length(PerClass)
        // p.acc(t) = 0           for t = length(PerClass)+1:T
        copy(acc, perClass, perClass_length);
        if (perClass_length < cols)
            set(acc + perClass_length, (T)0.0, cols - perClass_length);

        delete[] perClass;
    }

    // "acc" and "forho" each get their own matrix. OptMatrix is assumed to
    // take ownership of the matrix it is constructed from -- TODO confirm.
    OptMatrix<gMat2D<T> >* acc_opt = new OptMatrix<gMat2D<T> >(*acc_mat);
    perf->addOpt("acc", acc_opt);

    OptMatrix<gMat2D<T> >* forho_opt = new OptMatrix<gMat2D<T> >(*(new gMat2D<T>(*acc_mat)));
    perf->addOpt("forho", forho_opt);

    return perf;
}
static Performance<T>* gurls::Performance< T >::factory(const std::string & id) throw (BadPerformanceCreation) [inline, static, inherited]
Definition at line 111 of file perf.h.
{
    // Select the Performance implementation registered under `id`:
    // "precrec", "macroavg" or "rmse". Any other id is rejected with
    // BadPerformanceCreation. The caller owns the returned object.
    Performance<T>* created;

    if (id == "precrec")
        created = new PerfPrecRec<T>;
    else if (id == "macroavg")
        created = new PerfMacroAvg<T>;
    else if (id == "rmse")
        created = new PerfRmse<T>;
    else
        throw BadPerformanceCreation(id);

    return created;
}