GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
siglam.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-2013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * authors:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massachusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
00042 
00043 #ifndef _GURLS_SIGLAM_H_
00044 #define _GURLS_SIGLAM_H_
00045 
00046 #include <cstdio>
00047 #include <cstring>
00048 #include <iostream>
00049 #include <cmath>
00050 #include <algorithm>
00051 #include <set>
00052 
00053 #include "gurls++/options.h"
00054 #include "gurls++/optlist.h"
00055 #include "gurls++/gmat2d.h"
00056 #include "gurls++/gvec.h"
00057 #include "gurls++/gmath.h"
00058 
00059 #include "gurls++/paramsel.h"
00060 #include "gurls++/perf.h"
00061 #include "gurls++/rbfkernel.h"
00062 #include "gurls++/loocvdual.h"
00063 
00064 namespace gurls {
00065 
00071 template <typename T>
00072 class ParamSelSiglam: public ParamSelection<T>{
00073 
00074 public:
00092     GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt);
00093 };
00094 
00095 template <typename T>
00096 GurlsOptionsList* ParamSelSiglam<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt)
00097 {
00098     //  [n,T]  = size(y);
00099     const unsigned long t = Y.cols();
00100 
00101 //    GurlsOptionsList* kernel_old = NULL;
00102 //    if (opt.hasOpt("kernel"))
00103 //    {
00104 //        kernel_old = GurlsOptionsList::dynacast(opt.getOpt("kernel"));
00105 //        opt.removeOpt("kernel", false);
00106 //    }
00107 
00108     GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested");
00109     nestedOpt->copyOpt("nlambda", opt);
00110     nestedOpt->copyOpt("hoperf", opt);
00111     nestedOpt->copyOpt("smallnumber", opt);
00112 
00113     GurlsOptionsList* kernel = new GurlsOptionsList("kernel");
00114     kernel->addOpt("type", "rbf");
00115     nestedOpt->addOpt("kernel", kernel);
00116 
00117 
00118     GurlsOptionsList* paramsel;
00119 
00120     if(opt.hasOpt("paramsel"))
00121     {
00122         GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp");
00123         tmp_opt->copyOpt("paramsel", opt);
00124 
00125         paramsel = GurlsOptionsList::dynacast(tmp_opt->getOpt("paramsel"));
00126         tmp_opt->removeOpt("paramsel", false);
00127         delete tmp_opt;
00128 
00129         paramsel->removeOpt("lambdas");
00130         paramsel->removeOpt("sigma");
00131     }
00132     else
00133         paramsel = new GurlsOptionsList("paramsel");
00134 
00135 
00136     gMat2D<T>* dist = new gMat2D<T>(X.rows(), X.rows());
00137 
00138     // if ~isfield(opt.kernel,'distance')
00139     if(!kernel->hasOpt("distance"))
00140         //  opt.kernel.distance = squareform(pdist(X));
00141     {
00142         squareform<T>(X.getData(), X.rows(), X.cols(), dist->getData(), X.rows());
00143 
00144         T *distSquared = new T[X.rows()*X.rows()];
00145         copy(distSquared , dist->getData(), X.rows()*X.rows());
00146 
00147         mult<T>(distSquared, distSquared, dist->getData(), X.rows()*X.rows());
00148 
00149         kernel->addOpt("distance", new OptMatrix<gMat2D<T> >(*dist));
00150 
00151         delete [] distSquared;
00152     }
00153     else
00154         dist = &(kernel->getOptValue<OptMatrix<gMat2D<T> > >("distance"));
00155 
00156 
00157     //  if ~isfield(opt,'sigmamin')
00158     if(!opt.hasOpt("sigmamin"))
00159     {
00160         //  %D = sort(opt.kernel.distance);
00161         //  %opt.sigmamin = median(D(2,:));
00162         //  D = sort(squareform(opt.kernel.distance));
00163         int d_len = X.rows()*(X.rows()-1)/2;
00164         T* distY = new T[d_len];
00165         //        squareform<T>(dist->getData(), dist->rows(), dist->cols(), distY, 1);
00166 
00167         const int size = dist->cols();
00168         T* it = distY;
00169 
00170         for(int i=1; i< size; it+=i, ++i)
00171             copy(it , dist->getData()+(i*size), i);
00172 
00173         std::sort(distY, distY + d_len);
00174         //  firstPercentile = round(0.01*numel(D)+0.5);
00175 
00176         int firstPercentile = gurls::round( (T)0.01 * d_len + (T)0.5) -1;
00177 
00178         //  opt.sigmamin = D(firstPercentile);
00179         nestedOpt->addOpt("sigmamin", new OptNumber(sqrt( distY[firstPercentile]) ));
00180 
00181         delete [] distY;
00182     }
00183     else
00184     {
00185         nestedOpt->addOpt("sigmamin", new OptNumber(opt.getOptAsNumber("sigmamin")));
00186     }
00187 
00188     T sigmamin = static_cast<T>(nestedOpt->getOptAsNumber("sigmamin"));
00189 
00190     //  if ~isfield(opt,'sigmamax')
00191     if(!opt.hasOpt("sigmamax"))
00192     {
00193         //  %D = sort(opt.kernel.distance);
00194         //  %opt.sigmamax = median(D(n,:));
00195         T mAx = *(std::max_element(dist->getData(),dist->getData()+ dist->getSize()));
00196 
00197         //  opt.sigmamax = max(max(opt.kernel.distance));
00198         nestedOpt->addOpt("sigmamax", new OptNumber( sqrt( mAx )));
00199     }
00200     else
00201     {
00202         nestedOpt->addOpt("sigmamax", new OptNumber(opt.getOptAsNumber("sigmamax")));
00203     }
00204 
00205     T sigmamax = static_cast<T>(nestedOpt->getOptAsNumber("sigmamax"));
00206 
00207     // if opt.sigmamin <= 0
00208     if( le(sigmamin, (T)0.0) )
00209     {
00210         //  opt.sigmamin = eps;
00211         nestedOpt->removeOpt("sigmamin");
00212         nestedOpt->addOpt("sigmamin", new OptNumber(std::numeric_limits<T>::epsilon()));
00213         sigmamin = std::numeric_limits<T>::epsilon();
00214     }
00215 
00216     // if opt.sigmamin <= 0
00217     if( le(sigmamin, (T)0.0))
00218     {
00219         //  opt.sigmamax = eps;
00220         nestedOpt->removeOpt("sigmamax");
00221         nestedOpt->addOpt("sigmamax", new OptNumber(std::numeric_limits<T>::epsilon()));
00222         sigmamax = std::numeric_limits<T>::epsilon();
00223     }
00224 
00225     unsigned long nlambda = static_cast<unsigned long>(opt.getOptAsNumber("nlambda"));
00226     unsigned long nsigma  = static_cast<unsigned long>( opt.getOptAsNumber("nsigma"));
00227 
00228     T q = pow( sigmamax/sigmamin, static_cast<T>(1.0/(nsigma-1.0)));
00229 
00230     // LOOSQE = zeros(opt.nsigma,opt.nlambda,T);
00231     //T* LOOSQE = new T[nsigma*nlambda*t];
00232     T* perf = new T[nlambda];
00233 
00234     // sigmas = zeros(1,opt.nsigma);
00235     // for i = 1:opt.nsigma
00236 
00237     KernelRBF<T> rbfkernel;
00238     ParamSelLoocvDual<T> loocvdual;
00239 
00240     T* work = new T[std::max(nlambda, t+1)];
00241     T maxTmp = (T)-1.0;
00242     int m = -1;
00243     T guess = (T)-1.0;
00244 
00245     for(unsigned long i=0; i<nsigma; ++i)
00246     {
00247         nestedOpt->addOpt("paramsel", paramsel);
00248 
00249         paramsel->removeOpt("sigma");
00250         paramsel->addOpt("sigma", new OptNumber( sigmamin * pow(q, (T)i)));
00251 
00252         //  opt.kernel = kernel_rbf(X,y,opt);
00253         GurlsOptionsList* retKernel = rbfkernel.execute(X, Y, *nestedOpt);
00254 
00255         nestedOpt->removeOpt("kernel");
00256         nestedOpt->addOpt("kernel", retKernel);
00257 
00258         nestedOpt->removeOpt("paramsel", false);
00259 
00260         //  paramsel = paramsel_loocvdual(X,y,opt);
00261         GurlsOptionsList* ret_paramsel = loocvdual.execute(X, Y, *nestedOpt);
00262 
00263 
00264         gMat2D<T> &looe_mat = ret_paramsel->getOptValue<OptMatrix<gMat2D<T> > >("perf");
00265 
00266         //  LOOSQE(i,:,:) = paramsel.looe{1};
00267         //  guesses(i,:) = paramsel.guesses;
00268         gMat2D<T> &guesses_mat = ret_paramsel->getOptValue<OptMatrix<gMat2D<T> > >("guesses");
00269 
00270         for(unsigned long j=0; j<nlambda; ++j)
00271         {
00272             perf[j] = 0;
00273 
00274             T* end = looe_mat.getData()+looe_mat.getSize();
00275             for(T* it = looe_mat.getData()+j; it< end ; it+=nlambda)
00276                 perf[j] += *it;
00277         }
00278 
00279         unsigned long mm = std::max_element(perf, perf + nlambda) - perf;
00280 
00281         if( gt(perf[mm], maxTmp))
00282         {
00283             maxTmp = perf[mm];
00284             m = i;
00285             guess = guesses_mat.getData()[mm*guesses_mat.rows()];
00286         }
00287 
00288         delete ret_paramsel;
00289     }
00290 
00291     delete [] work;
00292     delete [] perf;
00293     delete nestedOpt;
00294 
00295     // M = sum(LOOSQE,3); % sum over classes
00296     //
00297     // [dummy,i] = max(M(:));
00298     // [m,n] = ind2sub(size(M),i);
00299     //
00300     // % opt sigma
00301     // vout.sigma = opt.sigmamin*(q^m);
00302 
00303 
00304     paramsel->removeOpt("sigma");
00305     paramsel->addOpt("sigma", new OptNumber( sigmamin * pow(q,m) ));
00306 
00307     // % opt lambda
00308     // vout.lambdas = guesses(m,n)*ones(1,T);
00309 
00310     gMat2D<T> *LAMBDA = new gMat2D<T>(1, t);
00311     set(LAMBDA->getData(), guess, t);
00312 
00313     paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*LAMBDA));
00314 
00315     return paramsel;
00316 
00317 }
00318 
00319 
00320 }
00321 
00322 #endif // _GURLS_SIGLAM_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends