GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
predkerneltraintest.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-2013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * authors:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massachusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
00042 
00043 #ifndef _GURLS_PREDKERNELTRAINTEST_H_
00044 #define _GURLS_PREDKERNELTRAINTEST_H_
00045 
00046 
00047 #include "gurls++/predkernel.h"
00048 #include "gurls++/gmath.h"
00049 
00050 #include <string>
00051 
00052 namespace gurls {
00053 
00054 
00060 template <typename T>
00061 class PredKernelTrainTest: public PredKernel<T>
00062 {
00063 public:
00078     GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) throw(gException);
00079 };
00080 
00081 template<typename T>
00082 GurlsOptionsList *PredKernelTrainTest<T>::execute(const gMat2D<T>& X, const gMat2D<T>& /*Y*/, const GurlsOptionsList &opt) throw(gException)
00083 {
00084     const GurlsOptionsList* optimizer = opt.getOptAs<GurlsOptionsList>("optimizer");
00085 
00086     std::string kernelType = opt.getOptValue<OptString>("kernel.type");
00087 
00088     const gMat2D<T>& rls_X = optimizer->getOptValue<OptMatrix<gMat2D<T> > >("X");
00089 
00090 
00091     const unsigned long xr = X.rows();
00092     const unsigned long xc = X.cols();
00093     const unsigned long rls_xr = rls_X.rows();
00094 
00095     if(xc != rls_X.cols())
00096         throw gException(Exception_Inconsistent_Size);
00097 
00098 
00099     GurlsOptionsList* predkernel = new GurlsOptionsList("predkernel");
00100     predkernel->addOpt("type", kernelType);
00101 
00102     gMat2D<T>* K;
00103 
00104     if(kernelType == "rbf")
00105     {
00106         double sigma = opt.getOptValue<OptNumber>("paramsel.sigma");
00107 
00108 //                opt.predkernel.distance = distance(X',opt.rls.X');
00109         gMat2D<T> *dist = new gMat2D<T>(xr, rls_xr);
00110 
00111         distance_transposed(X.getData(), rls_X.getData(), xc, xr, rls_xr, dist->getData());
00112 
00113 
00114 //                fk.distance = opt.predkernel.distance;
00115         predkernel->addOpt("distance", new OptMatrix<gMat2D<T> > (*dist));
00116 
00117 
00118         K = new gMat2D<T>(xr, rls_xr);
00119         copy(K->getData(), dist->getData(), dist->getSize());
00120 
00121 //            fk.K = exp(-(opt.predkernel.distance)/(opt.paramsel.sigma^2));
00122         scal(K->getSize(), (T)(-1.0/pow(sigma, 2)), K->getData(), 1);
00123         exp(K->getData(), K->getSize());
00124 
00125         if(optimizer->hasOpt("L"))
00126         {
00127             gMat2D<T> *Ktest = new gMat2D<T>(xr, 1);
00128             set(Ktest->getData(), (T)1.0, xr);
00129 
00130             predkernel->addOpt("Ktest", new OptMatrix<gMat2D<T> >(*Ktest));
00131         }
00132 
00133     }
00134 
00135     else if(kernelType == "load")
00136     {
00137 //            load(opt.testkernel);
00138         std::string testKernel = opt.getOptAsString("testKernel");
00139 
00140 //            fk.K = K_tetr;
00141         K = new gMat2D<T>();
00142         K->load(testKernel);
00143 
00144     }
00145 
00146     else if(kernelType == "chisquared")
00147     {
00148         const T epsilon = std::numeric_limits<T>::epsilon();
00149 
00150         K = new gMat2D<T>(xr, rls_xr);
00151         T* Kbuf = K->getData();
00152         set(Kbuf, (T)0.0, K->getSize());
00153 
00154 //            for i = 1:size(X,1)
00155         for(unsigned long i=0; i<xr; ++i)
00156         {
00157 //                for j = 1:size(opt.rls.X,1)
00158             for(unsigned long j=0; j<rls_xr; ++j)
00159             {
00160 
00161 //                    fk.K(i,j) = sum(...
00162 //                                    ( (X(i,:) - opt.rls.X(j,:)).^2 ) ./ ...
00163 //                                    ( 0.5*(X(i,:) + opt.rls.X(j,:)) + eps));
00164                 T sum = 0;
00165                 for(unsigned long k=0; k< xc; ++k)
00166                 {
00167                     const T X_ik = X.getData()[i+(xr*k)];
00168                     const T rlsX_jk = rls_X.getData()[j+(rls_xr*k)];
00169 
00170                     sum += pow(X_ik - rlsX_jk, 2) / static_cast<T>(((0.5*(X_ik + rlsX_jk)) + epsilon));
00171                 }
00172 
00173                 Kbuf[i+(xr*j)] = sum;
00174             }
00175         }
00176 
00177     }
00178 
00179     else
00180         throw gException(Exception_Required_Parameter_Missing);
00181 
00182     predkernel->addOpt("K", new OptMatrix<gMat2D<T> > (*K));
00183 
00184 
00185     return predkernel;
00186 }
00187 
00188 }
00189 
00190 #endif //_GURLS_PREDKERNELTRAINTEST_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends