![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_MACROAVG_H_ 00044 #define _GURLS_MACROAVG_H_ 00045 00046 #include "gurls++/perf.h" 00047 00048 #include "gurls++/utils.h" 00049 #include "gurls++/gvec.h" 00050 #include "gurls++/optmatrix.h" 00051 00052 #include "float.h" 00053 00054 namespace gurls { 00055 00061 template <typename T> 00062 class PerfMacroAvg: public Performance<T>{ 00063 00064 public: 00077 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) throw(gException); 00078 00079 protected: 00080 void macroavg(const unsigned long* trueY, const unsigned long* predY, const int length,int totClasses, T* &perClass, T ¯oAverage, unsigned long &perClass_length); 00081 }; 00082 00083 template<typename T> 00084 GurlsOptionsList* PerfMacroAvg<T>::execute(const gMat2D<T>& /*X*/, const gMat2D<T>& Y, const GurlsOptionsList& opt) throw(gException) 00085 { 00086 const unsigned long rows = Y.rows(); 00087 const unsigned long cols = Y.cols(); 00088 00089 00090 // if isfield (opt,'perf') 00091 // p = opt.perf; % lets not overwrite existing performance measures. 00092 // % unless they have the same name 00093 // end 00094 00095 GurlsOptionsList* perf = NULL; 00096 00097 if(opt.hasOpt("perf")) 00098 { 00099 GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp"); 00100 tmp_opt->copyOpt("perf", opt); 00101 00102 perf = GurlsOptionsList::dynacast(tmp_opt->getOpt("perf")); 00103 tmp_opt->removeOpt("perf", false); 00104 delete tmp_opt; 00105 00106 perf->removeOpt("acc"); 00107 perf->removeOpt("forho"); 00108 // perf->removeOpt("forplot"); 00109 } 00110 else 00111 perf = new GurlsOptionsList("perf"); 00112 00113 00114 gMat2D<T>* acc_mat = new gMat2D<T>(1, cols); 00115 T* acc = acc_mat->getData(); 00116 00117 // T = size(y,2); 00118 00119 // y_true = y; 00120 const T* y_true = Y.getData(); 00121 00122 00123 // y_pred = opt.pred; 00124 const gMat2D<T> &y_pred = opt.getOptValue<OptMatrix<gMat2D<T> > >("pred"); 00125 00126 00127 // if size(y,2) == 1 00128 if(cols == 1) 00129 { 00130 // predlab = sign(y_pred); 00131 T* predLab = sign(y_pred.getData(), y_pred.getSize()); 00132 00133 T* tmp = compare<T>(predLab, Y.getData(), rows, &eq); 00134 00135 // p.acc = mean(predlab == y); 00136 mean(tmp, acc, rows, 1, 1); 00137 00138 // p.forho = mean(predlab == y); 00139 // p.forplot = mean(predlab == y); 00140 00141 delete [] tmp; 00142 delete [] predLab; 00143 } 00144 else 00145 { 00146 // %% Assumes single label prediction. 00147 // [dummy, predlab] = max(y_pred,[],2); 00148 T* work = new T[std::max(Y.getSize(), y_pred.getSize() )]; 00149 00150 unsigned long* predLab = new unsigned long[rows]; 00151 indicesOfMax(y_pred.getData(), rows, y_pred.cols(), predLab, work, 2); 00152 00153 // [dummy, truelab] = max(y_true,[],2); 00154 // unsigned long* trueLab = indicesOfMax(y_true, rows, cols, 2); 00155 unsigned long* trueLab = new unsigned long[rows]; 00156 indicesOfMax(y_true, rows, cols, trueLab, work, 2); 00157 00158 delete[] work; 00159 00160 // [MacroAvg, PerClass] = macroavg(truelab, predlab); 00161 00162 T macroAverage; 00163 T* perClass; 00164 unsigned long perClass_length; 00165 00166 macroavg(trueLab, predLab, rows, cols, perClass, macroAverage, perClass_length); 00167 00168 if(perClass_length > cols) 00169 throw gException(Exception_Inconsistent_Size); 00170 00171 delete[] predLab; 00172 delete[] trueLab; 00173 00174 // for t = 1:length(PerClass), 00175 // p.acc(t) = PerClass(t); 00176 // p.forho(t) = p.acc(t); 00177 // p.forplot(t) = p.acc(t); 00178 // end 00179 // for t = (length(PerClass)+1):T 00180 // p.acc(t) = 0; 00181 // p.forho(t) = 0; 00182 // p.forplot(t) = 0; 00183 // end 00184 00185 copy(acc, perClass, perClass_length); 00186 00187 if(perClass_length < cols) 00188 set(acc+perClass_length, (T)0.0, cols-perClass_length); 00189 00190 delete[] perClass; 00191 00192 } 00193 00194 OptMatrix<gMat2D<T> >* acc_opt = new OptMatrix<gMat2D<T> >(*acc_mat); 00195 perf->addOpt("acc", acc_opt); 00196 00197 00198 OptMatrix<gMat2D<T> >* forho_opt = new OptMatrix<gMat2D<T> >(*(new gMat2D<T>(*acc_mat))); 00199 perf->addOpt("forho", forho_opt); 00200 00201 // OptMatrix<gMat2D<T> >* forplot_opt = new OptMatrix<gMat2D<T> >(*(new gMat2D<T>(*acc_mat))); 00202 // perf->addOpt("forplot", forplot_opt); 00203 00204 return perf; 00205 } 00206 00210 template<typename T> 00211 void PerfMacroAvg<T>::macroavg(const unsigned long* trueY, const unsigned long* predY, const int length, int totClasses, T* &perClass, T ¯oAverage, unsigned long &perClass_length) 00212 { 00213 //function [MacroAverage, PerClass] = macroavg(TrueY, PredY) 00214 //% Computes average of performance for each class. 00215 00216 //% Macro 00217 //nClasses = max(TrueY); 00218 int nClasses = *(std::max_element(trueY, trueY+length)); 00219 00220 if(nClasses < 0) 00221 throw gException(Exception_Inconsistent_Size); 00222 00223 perClass_length = nClasses+1; 00224 // perClass = new T[perClass_length]; 00225 perClass = new T[totClasses]; 00226 00227 unsigned long* ty_and_py = new unsigned long[length]; 00228 unsigned long* num = new unsigned long[1]; 00229 unsigned long* den = new unsigned long[1]; 00230 00231 // for i = 1:nClasses, 00232 for(unsigned long i=0; i<perClass_length; ++i) 00233 { 00234 // acc(i) = sum((TrueY == i) & (PredY == i))/(sum(TrueY == i) + eps); 00235 unsigned long* tyEqI = compare<unsigned long>(trueY, i, length, &eq); 00236 unsigned long* pyEqI = compare<unsigned long>(predY, i, length, &eq); 00237 00238 mult(tyEqI, pyEqI, ty_and_py, length); 00239 00240 sum(ty_and_py, num, length, 1, 1); 00241 sum(tyEqI, den, length, 1, 1); 00242 00243 perClass[i] = ((T)(*num))/((*den) + std::numeric_limits<T>::epsilon()); 00244 00245 delete [] tyEqI; 00246 delete [] pyEqI; 00247 } 00248 00249 delete [] ty_and_py; 00250 delete [] num; 00251 delete [] den; 00252 00253 //set accuracy =1 on classes with no samples 00254 // for(int i=perClass_length; i<totClasses; ++i) 00255 // perClass[i] = 1; 00256 set(perClass+perClass_length, (T)1.0, totClasses-perClass_length); 00257 00258 00259 //PerClass = acc; 00260 00261 //MacroAverage = mean(acc); 00262 T* meanValue = new T[1]; 00263 mean(perClass, meanValue, nClasses+1, 1, 1); 00264 00265 macroAverage = *meanValue; 00266 00267 delete[] meanValue; 00268 } 00269 00270 } 00271 00272 #endif //_GURLS_MACROAVG_H_