GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
macroavg.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-1013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * authors:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massacusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
00042 
00043 #ifndef _GURLS_MACROAVG_H_
00044 #define _GURLS_MACROAVG_H_
00045 
#include "gurls++/perf.h"

#include "gurls++/utils.h"
#include "gurls++/gvec.h"
#include "gurls++/optmatrix.h"

#include <algorithm>
#include <limits>
#include <vector>

#include "float.h"
00053 
00054 namespace gurls {
00055 
00061 template <typename T>
00062 class PerfMacroAvg: public Performance<T>{
00063 
00064 public:
00077     GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) throw(gException);
00078 
00079 protected:
00080     void macroavg(const unsigned long* trueY, const unsigned long* predY, const int length,int totClasses, T* &perClass, T &macroAverage, unsigned long &perClass_length);
00081 };
00082 
00083 template<typename T>
00084 GurlsOptionsList* PerfMacroAvg<T>::execute(const gMat2D<T>& /*X*/, const gMat2D<T>& Y, const GurlsOptionsList& opt) throw(gException)
00085 {
00086     const unsigned long rows = Y.rows();
00087     const unsigned long cols = Y.cols();
00088 
00089 
00090     //    if isfield (opt,'perf')
00091     //        p = opt.perf; % lets not overwrite existing performance measures.
00092     //                  % unless they have the same name
00093     //    end
00094 
00095     GurlsOptionsList* perf = NULL;
00096 
00097     if(opt.hasOpt("perf"))
00098     {
00099         GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp");
00100         tmp_opt->copyOpt("perf", opt);
00101 
00102         perf = GurlsOptionsList::dynacast(tmp_opt->getOpt("perf"));
00103         tmp_opt->removeOpt("perf", false);
00104         delete tmp_opt;
00105 
00106         perf->removeOpt("acc");
00107         perf->removeOpt("forho");
00108 //        perf->removeOpt("forplot");
00109     }
00110     else
00111         perf = new GurlsOptionsList("perf");
00112 
00113 
00114     gMat2D<T>* acc_mat = new gMat2D<T>(1, cols);
00115     T* acc = acc_mat->getData();
00116 
00117 //    T = size(y,2);
00118 
00119 //    y_true = y;
00120     const T* y_true = Y.getData();
00121 
00122 
00123 //    y_pred = opt.pred;
00124     const gMat2D<T> &y_pred = opt.getOptValue<OptMatrix<gMat2D<T> > >("pred");
00125 
00126 
00127 //    if size(y,2) == 1
00128     if(cols == 1)
00129     {
00130 //        predlab = sign(y_pred);
00131         T* predLab = sign(y_pred.getData(), y_pred.getSize());
00132 
00133         T* tmp = compare<T>(predLab, Y.getData(), rows, &eq);
00134 
00135 //        p.acc = mean(predlab == y);
00136         mean(tmp, acc, rows, 1, 1);
00137 
00138 //        p.forho = mean(predlab == y);
00139 //        p.forplot = mean(predlab == y);
00140 
00141         delete [] tmp;
00142         delete [] predLab;
00143     }
00144     else
00145     {
00146 //        %% Assumes single label prediction.
00147 //        [dummy, predlab] = max(y_pred,[],2);
00148         T* work = new T[std::max(Y.getSize(), y_pred.getSize() )];
00149 
00150         unsigned long* predLab = new unsigned long[rows];
00151         indicesOfMax(y_pred.getData(), rows, y_pred.cols(), predLab, work, 2);
00152 
00153 //        [dummy, truelab] = max(y_true,[],2);
00154 //        unsigned long* trueLab = indicesOfMax(y_true, rows, cols, 2);
00155         unsigned long* trueLab = new unsigned long[rows];
00156         indicesOfMax(y_true, rows, cols, trueLab, work, 2);
00157 
00158         delete[] work;
00159 
00160 //        [MacroAvg, PerClass] = macroavg(truelab, predlab);
00161 
00162         T macroAverage;
00163         T* perClass;
00164         unsigned long perClass_length;
00165 
00166         macroavg(trueLab, predLab, rows, cols, perClass, macroAverage, perClass_length);
00167 
00168         if(perClass_length > cols)
00169             throw gException(Exception_Inconsistent_Size);
00170 
00171         delete[] predLab;
00172         delete[] trueLab;
00173 
00174 //        for t = 1:length(PerClass),
00175 //            p.acc(t) = PerClass(t);
00176 //            p.forho(t) = p.acc(t);
00177 //            p.forplot(t) = p.acc(t);
00178 //        end
00179 //        for t = (length(PerClass)+1):T
00180 //            p.acc(t) = 0;
00181 //            p.forho(t) = 0;
00182 //            p.forplot(t) = 0;
00183 //        end
00184 
00185         copy(acc, perClass, perClass_length);
00186 
00187         if(perClass_length < cols)
00188             set(acc+perClass_length, (T)0.0, cols-perClass_length);
00189 
00190         delete[] perClass;
00191 
00192     }
00193 
00194     OptMatrix<gMat2D<T> >* acc_opt = new OptMatrix<gMat2D<T> >(*acc_mat);
00195     perf->addOpt("acc", acc_opt);
00196 
00197 
00198     OptMatrix<gMat2D<T> >* forho_opt = new OptMatrix<gMat2D<T> >(*(new gMat2D<T>(*acc_mat)));
00199     perf->addOpt("forho", forho_opt);
00200 
00201 //    OptMatrix<gMat2D<T> >* forplot_opt = new OptMatrix<gMat2D<T> >(*(new gMat2D<T>(*acc_mat)));
00202 //    perf->addOpt("forplot", forplot_opt);
00203 
00204     return perf;
00205 }
00206 
00210 template<typename T>
00211 void PerfMacroAvg<T>::macroavg(const unsigned long* trueY, const unsigned long* predY, const int length, int totClasses, T* &perClass, T &macroAverage, unsigned long &perClass_length)
00212 {
00213 //function [MacroAverage, PerClass] = macroavg(TrueY, PredY)
00214 //% Computes average of performance for each class.
00215 
00216 //% Macro
00217 //nClasses = max(TrueY);
00218     int nClasses = *(std::max_element(trueY, trueY+length));
00219 
00220     if(nClasses < 0)
00221         throw gException(Exception_Inconsistent_Size);
00222 
00223     perClass_length = nClasses+1;
00224 //     perClass = new T[perClass_length];
00225     perClass = new T[totClasses];
00226 
00227     unsigned long* ty_and_py = new unsigned long[length];
00228     unsigned long* num = new unsigned long[1];
00229     unsigned long* den = new unsigned long[1];
00230 
00231 //    for i = 1:nClasses,
00232     for(unsigned long i=0; i<perClass_length; ++i)
00233     {
00234 //    acc(i) = sum((TrueY == i) & (PredY == i))/(sum(TrueY == i) + eps);
00235         unsigned long* tyEqI = compare<unsigned long>(trueY, i, length, &eq);
00236         unsigned long* pyEqI = compare<unsigned long>(predY, i, length, &eq);
00237 
00238         mult(tyEqI, pyEqI, ty_and_py, length);
00239 
00240         sum(ty_and_py, num, length, 1, 1);
00241         sum(tyEqI, den, length, 1, 1);
00242 
00243         perClass[i] = ((T)(*num))/((*den) + std::numeric_limits<T>::epsilon());
00244 
00245         delete [] tyEqI;
00246         delete [] pyEqI;
00247     }
00248 
00249     delete [] ty_and_py;
00250     delete [] num;
00251     delete [] den;
00252 
00253     //set accuracy =1 on classes with no samples
00254 //    for(int i=perClass_length; i<totClasses; ++i)
00255 //      perClass[i] = 1;
00256     set(perClass+perClass_length, (T)1.0, totClasses-perClass_length);
00257 
00258 
00259 //PerClass = acc;
00260 
00261 //MacroAverage = mean(acc);
00262     T* meanValue = new T[1];
00263     mean(perClass, meanValue, nClasses+1, 1, 1);
00264 
00265     macroAverage = *meanValue;
00266 
00267     delete[] meanValue;
00268 }
00269 
00270 }
00271 
00272 #endif //_GURLS_MACROAVG_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends