GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
siglamloogpregr.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-1013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * authors:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massacusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
00042 
00043 #ifndef _GURLS_SIGLAMLOOGPREGR_H_
00044 #define _GURLS_SIGLAMLOOGPREGR_H_
00045 
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

#include "gurls++/options.h"
#include "gurls++/optlist.h"
#include "gurls++/gmat2d.h"
#include "gurls++/gvec.h"
#include "gurls++/gmath.h"

#include "gurls++/paramsel.h"
#include "gurls++/loogpregr.h"
#include "gurls++/rbfkernel.h"
00057 
00058 namespace gurls {
00059 
00065 template <typename T>
00066 class ParamSelSiglamLooGPRegr: public ParamSelection<T>{
00067 
00068 public:
00087    GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt);
00088 };
00089 
00090 template <typename T>
00091 GurlsOptionsList *ParamSelSiglamLooGPRegr<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt)
00092 {
00093 //    [n,T]  = size(y);
00094     const unsigned long n = Y.rows();
00095     const unsigned long t = Y.cols();
00096 
00097     const unsigned long d = X.cols();
00098 
00099     GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested");
00100     nestedOpt->copyOpt("nlambda", opt);
00101     nestedOpt->copyOpt("hoperf", opt);
00102     nestedOpt->copyOpt("singlelambda", opt);
00103 
00104 //    if ~isfield(opt,'kernel')
00105     if(!opt.hasOpt("kernel"))
00106     {
00107 //        opt.kernel.type = 'rbf';
00108         GurlsOptionsList* kernel = new GurlsOptionsList("kernel");
00109         kernel->addOpt("type", "rbf");
00110 
00111         nestedOpt->addOpt("kernel", kernel);
00112     }
00113     else
00114         nestedOpt->copyOpt("kernel", opt);
00115 
00116 
00117     GurlsOptionsList* kernel = nestedOpt->getOptAs<GurlsOptionsList>("kernel");
00118 
00119 
00120     gMat2D<T> *distance;
00121 
00122 //    if ~isfield(opt.kernel,'distance')
00123     if(!kernel->hasOpt("distance"))
00124     {
00125         distance = new gMat2D<T>(n, n);
00126 
00127 //        opt.kernel.distance = square_distance(X',X');
00128         distance_transposed(X.getData(), X.getData(), d, n, n, distance->getData());
00129 
00130         kernel->addOpt("distance", new OptMatrix<gMat2D<T> >(*distance));
00131     }
00132     else
00133         distance = &(kernel->getOptValue<OptMatrix<gMat2D<T> > >("distance"));
00134 
00135 
00136 //    if ~isfield(opt,'sigmamin')
00137     if(!opt.hasOpt("sigmamin"))
00138     {
00139         int d_len = n*(n-1)/2;
00140         T* distLinearized = new T[d_len];
00141 
00142 //        D = sort(opt.kernel.distance(tril(true(n),-1)));
00143 
00144         const unsigned long size = distance->cols();
00145         T* it = distLinearized;
00146         T* d_it = distance->getData();
00147 
00148         for(unsigned long i=1; i< size; ++i)
00149         {
00150             gurls::copy(it , d_it+i, size - i);
00151 
00152             it += size - i;
00153             d_it += size;
00154         }
00155 
00156         std::sort(distLinearized, distLinearized + d_len);
00157 
00158         //        firstPercentile = round(0.01*numel(D)+0.5);
00159         int firstPercentile = gurls::round((T)0.01 * d_len + (T)0.5)-1;
00160 
00161         //  opt.sigmamin = D(firstPercentile);
00162         nestedOpt->addOpt("sigmamin", new OptNumber(sqrt( distLinearized[firstPercentile]) ));
00163 
00164         delete [] distLinearized;
00165     }
00166     else
00167         nestedOpt->copyOpt("sigmamin", opt);
00168 
00169 
00170 //    if ~isfield(opt,'sigmamax')
00171     if(!opt.hasOpt("sigmamax"))
00172     {
00173 //        opt.sigmamax = sqrt(max(max(opt.kernel.distance)));
00174         double sigmaMax = sqrt(*std::max_element(distance->getData(), distance->getData()+distance->getSize()));
00175         nestedOpt->addOpt("sigmamax", new OptNumber(sigmaMax));
00176     }
00177     else
00178         nestedOpt->copyOpt("sigmamax", opt);
00179 
00180 
00181 //    if opt.sigmamin <= 0
00182     if( le(nestedOpt->getOptAsNumber("sigmamin"), 0.0) )
00183     {
00184 //        opt.sigmamin = eps;
00185         nestedOpt->removeOpt("sigmamin");
00186         nestedOpt->addOpt("sigmamin", new OptNumber(std::numeric_limits<T>::epsilon()));
00187     }
00188 //    if opt.sigmamax <= 0
00189     if( le(nestedOpt->getOptAsNumber("sigmamax"), 0.0) )
00190     {
00191 //        opt.sigmamax = eps;
00192         nestedOpt->removeOpt("sigmamax");
00193         nestedOpt->addOpt("sigmamax", new OptNumber(std::numeric_limits<T>::epsilon()));
00194     }
00195 
00196     const int nsigma = static_cast<int>(opt.getOptAsNumber("nsigma"));
00197     const int nlambda = static_cast<int>(opt.getOptAsNumber("nlambda"));
00198     const T sigmamin = static_cast<T>(nestedOpt->getOptAsNumber("sigmamin"));
00199 
00200 //    q = (opt.sigmamax/opt.sigmamin)^(1/(opt.nsigma-1));
00201     T q = static_cast<T>(std::pow(nestedOpt->getOptAsNumber("sigmamax")/nestedOpt->getOptAsNumber("sigmamin"), 1.0/(nsigma-1.0)));
00202 
00203 //    perf = zeros(opt.nsigma,opt.nlambda,T);
00204     T* perf = new T[nsigma*nlambda];
00205     set(perf, (T)0.0, nsigma*nlambda);
00206 
00207 //    sigmas = zeros(1,opt.nsigma);
00208 //    T* sigmas = new T[nsigma];
00209 
00210     T* guesses = new T[nsigma*nlambda];
00211 
00212     KernelRBF<T> rbf;
00213     ParamSelLooGPRegr<T> loogp;
00214 
00215     T* work = new T[t];
00216 
00217     GurlsOptionsList* paramsel_rbf = new GurlsOptionsList("paramsel");
00218     nestedOpt->addOpt("paramsel", paramsel_rbf);
00219 
00220 //    for i = 1:opt.nsigma
00221     for(int i=0; i<nsigma; ++i)
00222     {
00223 //        sigmas(i) = (opt.sigmamin*(q^(i-1)));
00224         const T sigma = sigmamin* std::pow(q, i);
00225 //        sigmas[i] = sigma;
00226 
00227 //        opt.paramsel.sigma = sigmas(i);
00228         paramsel_rbf->removeOpt("sigma");
00229         paramsel_rbf->addOpt("sigma", new OptNumber(sigma));
00230 
00231 
00232 //        opt.kernel = kernel_rbf(X,y,opt);
00233         GurlsOptionsList* rbf_kernel = rbf.execute(X, Y, *nestedOpt);
00234 
00235         nestedOpt->removeOpt("kernel");
00236         nestedOpt->addOpt("kernel", rbf_kernel);
00237 
00238 //        paramsel = paramsel_loogpregr(X,y,opt);
00239         GurlsOptionsList* paramsel_loogp = loogp.execute(X, Y, *nestedOpt);
00240 
00241 //        perf(i,:,:) = paramsel.perf;
00242         gMat2D<T> &perf_mat = paramsel_loogp->getOptValue<OptMatrix<gMat2D<T> > >("perf"); // nlambda x t
00243 
00244         T* perf_it = perf + i;
00245         for(int j=0; j<nlambda; ++j, perf_it += nsigma)
00246         {
00247             getRow(perf_mat.getData(), perf_mat.rows(), perf_mat.cols(), j, work);
00248 
00249             *perf_it = sumv(work, t);
00250         }
00251 
00252 //        guesses(i,:) = paramsel.guesses;
00253         gMat2D<T> &guesses_mat = paramsel_loogp->getOptValue<OptMatrix<gMat2D<T> > >("guesses"); // nlambda x 1
00254         copy(guesses+i, guesses_mat.getData(), nlambda, nsigma, 1);
00255 
00256         delete paramsel_loogp;
00257     }
00258 
00259     delete [] work;
00260     delete nestedOpt;
00261 
00262 //    M = sum(perf,3); % sum over classes
00263 //    [dummy,i] = max(M(:));
00264     int i = std::max_element(perf, perf +(nsigma*nlambda)) - perf;
00265 
00266 //    [m,n] = ind2sub(size(M),i);
00267     int im = i%nsigma;
00268 
00269     delete [] perf;
00270 
00271     GurlsOptionsList* paramsel = new GurlsOptionsList("paramsel");
00272 
00273 //    % opt sigma
00274 //    vout.sigma = opt.sigmamin*(q^(m-1));
00275     paramsel->addOpt("sigma", new OptNumber( sigmamin*(std::pow(q, im))) );
00276 
00277 //    % opt lambda
00278 //    vout.noises = guesses(m,n)*ones(1,T);
00279 
00280     gMat2D<T> *lambdas = new gMat2D<T>(1, t);
00281     set(lambdas->getData(), guesses[i], t);
00282 
00283     paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*lambdas));
00284 
00285     delete [] guesses;
00286 
00287     return paramsel;
00288 }
00289 
00290 
00291 }
00292 
00293 #endif // _GURLS_SIGLAMLOOGPREGR_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends