![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_PREDKERNELTRAINTEST_H_ 00044 #define _GURLS_PREDKERNELTRAINTEST_H_ 00045 00046 00047 #include "gurls++/predkernel.h" 00048 #include "gurls++/gmath.h" 00049 00050 #include <string> 00051 00052 namespace gurls { 00053 00054 00060 template <typename T> 00061 class PredKernelTrainTest: public PredKernel<T> 00062 { 00063 public: 00078 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) throw(gException); 00079 }; 00080 00081 template<typename T> 00082 GurlsOptionsList *PredKernelTrainTest<T>::execute(const gMat2D<T>& X, const gMat2D<T>& /*Y*/, const GurlsOptionsList &opt) throw(gException) 00083 { 00084 const GurlsOptionsList* optimizer = opt.getOptAs<GurlsOptionsList>("optimizer"); 00085 00086 std::string kernelType = opt.getOptValue<OptString>("kernel.type"); 00087 00088 const gMat2D<T>& rls_X = optimizer->getOptValue<OptMatrix<gMat2D<T> > >("X"); 00089 00090 00091 const unsigned long xr = X.rows(); 00092 const unsigned long xc = X.cols(); 00093 const unsigned long rls_xr = rls_X.rows(); 00094 00095 if(xc != rls_X.cols()) 00096 throw gException(Exception_Inconsistent_Size); 00097 00098 00099 GurlsOptionsList* predkernel = new GurlsOptionsList("predkernel"); 00100 predkernel->addOpt("type", kernelType); 00101 00102 gMat2D<T>* K; 00103 00104 if(kernelType == "rbf") 00105 { 00106 double sigma = opt.getOptValue<OptNumber>("paramsel.sigma"); 00107 00108 // opt.predkernel.distance = distance(X',opt.rls.X'); 00109 gMat2D<T> *dist = new gMat2D<T>(xr, rls_xr); 00110 00111 distance_transposed(X.getData(), rls_X.getData(), xc, xr, rls_xr, dist->getData()); 00112 00113 00114 // fk.distance = opt.predkernel.distance; 00115 predkernel->addOpt("distance", new OptMatrix<gMat2D<T> > (*dist)); 00116 00117 00118 K = new gMat2D<T>(xr, rls_xr); 00119 copy(K->getData(), dist->getData(), dist->getSize()); 00120 00121 // fk.K = exp(-(opt.predkernel.distance)/(opt.paramsel.sigma^2)); 00122 scal(K->getSize(), (T)(-1.0/pow(sigma, 2)), K->getData(), 1); 00123 exp(K->getData(), K->getSize()); 00124 00125 if(optimizer->hasOpt("L")) 00126 { 00127 gMat2D<T> *Ktest = new gMat2D<T>(xr, 1); 00128 set(Ktest->getData(), (T)1.0, xr); 00129 00130 predkernel->addOpt("Ktest", new OptMatrix<gMat2D<T> >(*Ktest)); 00131 } 00132 00133 } 00134 00135 else if(kernelType == "load") 00136 { 00137 // load(opt.testkernel); 00138 std::string testKernel = opt.getOptAsString("testKernel"); 00139 00140 // fk.K = K_tetr; 00141 K = new gMat2D<T>(); 00142 K->load(testKernel); 00143 00144 } 00145 00146 else if(kernelType == "chisquared") 00147 { 00148 const T epsilon = std::numeric_limits<T>::epsilon(); 00149 00150 K = new gMat2D<T>(xr, rls_xr); 00151 T* Kbuf = K->getData(); 00152 set(Kbuf, (T)0.0, K->getSize()); 00153 00154 // for i = 1:size(X,1) 00155 for(unsigned long i=0; i<xr; ++i) 00156 { 00157 // for j = 1:size(opt.rls.X,1) 00158 for(unsigned long j=0; j<rls_xr; ++j) 00159 { 00160 00161 // fk.K(i,j) = sum(... 00162 // ( (X(i,:) - opt.rls.X(j,:)).^2 ) ./ ... 00163 // ( 0.5*(X(i,:) + opt.rls.X(j,:)) + eps)); 00164 T sum = 0; 00165 for(unsigned long k=0; k< xc; ++k) 00166 { 00167 const T X_ik = X.getData()[i+(xr*k)]; 00168 const T rlsX_jk = rls_X.getData()[j+(rls_xr*k)]; 00169 00170 sum += pow(X_ik - rlsX_jk, 2) / static_cast<T>(((0.5*(X_ik + rlsX_jk)) + epsilon)); 00171 } 00172 00173 Kbuf[i+(xr*j)] = sum; 00174 } 00175 } 00176 00177 } 00178 00179 else 00180 throw gException(Exception_Required_Parameter_Missing); 00181 00182 predkernel->addOpt("K", new OptMatrix<gMat2D<T> > (*K)); 00183 00184 00185 return predkernel; 00186 } 00187 00188 } 00189 00190 #endif //_GURLS_PREDKERNELTRAINTEST_H_