GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
loocvdual.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-2013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * authors:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massachusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
00042 
00043 #ifndef _GURLS_LOOCVDUAL_H_
00044 #define _GURLS_LOOCVDUAL_H_
00045 
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <set>
#include <vector>
00052 
00053 #include "gurls++/options.h"
00054 #include "gurls++/optlist.h"
00055 #include "gurls++/gmat2d.h"
00056 #include "gurls++/gvec.h"
00057 #include "gurls++/gmath.h"
00058 #include "gurls++/utils.h"
00059 
00060 #include "gurls++/paramsel.h"
00061 #include "gurls++/perf.h"
00062 
00063 namespace gurls {
00064 
00070 template <typename T>
00071 class ParamSelLoocvDual: public ParamSelection<T>{
00072 
00073 public:
00090    GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt);
00091 };
00092 
00093 template <typename T>
00094 GurlsOptionsList* ParamSelLoocvDual<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt)
00095 {
00096 //    [n,T]  = size(y);
00097     const unsigned long n = Y.rows();
00098     const unsigned long t = Y.cols();
00099 
00100     const unsigned long d = X.cols();
00101 
00102 
00103 //    tot = opt.nlambda;
00104     int tot = static_cast<int>(std::ceil( opt.getOptAsNumber("nlambda")));
00105 
00106 
00107 //    [Q,L] = eig(opt.kernel.K);
00108 //    Q = double(Q);
00109 //    L = double(diag(L));
00110     const GurlsOptionsList* kernel = opt.getOptAs<GurlsOptionsList>("kernel");
00111 
00112     const gMat2D<T> &K_mat = kernel->getOptValue<OptMatrix<gMat2D<T> > >("K");
00113 
00114     gMat2D<T> K(K_mat.rows(), K_mat.cols());
00115     copy(K.getData(), K_mat.getData(), K_mat.getSize());
00116 
00117     const unsigned long qrows = K.rows();
00118     const unsigned long qcols = K.cols();
00119     const unsigned long l_length = qrows;
00120 
00121     T *Q = K.getData();
00122     T *L = new T[l_length];
00123 
00124     eig_sm(Q, L, qrows); // qrows == qcols
00125 
00126     int r = n;
00127     if(kernel->getOptAsString("type") == "linear")
00128     {
00129         set(L, (T) 1.0e-12, l_length-d);
00130         r = std::min(n,d);
00131     }
00132 
00133 
00134 //    Qty = Q'*y;
00135 //    T* Qty = new T[qrows*qcols];
00136     T* Qty = new T[qcols*t];
00137     dot(Q, Y.getData(), Qty, qrows, qcols, n, t, qcols, t, CblasTrans, CblasNoTrans, CblasColMajor);
00138 
00139     T* guesses = lambdaguesses(L, n, r, n, tot, (T)(opt.getOptAsNumber("smallnumber")));
00140 
00141 
00142 
00143     GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested");
00144 
00145     gMat2D<T>* pred = new gMat2D<T>(n, t);
00146     OptMatrix<gMat2D<T> >* pred_opt = new OptMatrix<gMat2D<T> >(*pred);
00147     nestedOpt->addOpt("pred", pred_opt);
00148 
00149 
00150     Performance<T>* perfClass = Performance<T>::factory(opt.getOptAsString("hoperf"));
00151 
00152     gMat2D<T>* perf = new gMat2D<T>(tot, t);
00153     T* ap = perf->getData();
00154 
00155     T* C_div_Z = new T[qrows];
00156     T* C = new T[qrows*qcols];
00157     T* Z = new T[qrows];
00158     T* work = new T[std::max((qrows+1)*l_length, (qrows*qcols)+l_length)];
00159 
00160     for(int i = 0; i < tot; ++i)
00161     {
00162         rls_eigen(Q, L, Qty, C, guesses[i], n, qrows, qcols, l_length, qcols, t, work);
00163         GInverseDiagonal(Q, L, guesses+i, Z, qrows, qcols, l_length, 1, work);
00164 
00165         for(unsigned long j = 0; j< t; ++j)
00166         {
00167             rdivide(C + (qrows*j), Z, C_div_Z, qrows);
00168 
00169 //            opt.pred(:,t) = y(:,t) - (C(:,t)./Z);
00170             copy(pred->getData()+(n*j), Y.getData() + (n*j), n);
00171             axpy(n, (T)-1.0, C_div_Z, 1, pred->getData() + (n*j), 1);
00172         }
00173 
00174 //        opt.perf = opt.hoperf([],y,opt);
00175         const gMat2D<T> dummy;
00176         GurlsOptionsList* perf_opt = perfClass->execute(dummy, Y, *nestedOpt);
00177 
00178         gMat2D<T> &forho_vec = perf_opt->getOptValue<OptMatrix<gMat2D<T> > >("forho");
00179 
00180         copy(ap+i, forho_vec.getData(), t, tot, 1);
00181 
00182         delete perf_opt;
00183     }
00184 
00185     delete nestedOpt;
00186     delete[] work;
00187     delete [] C;
00188     delete [] Z;
00189     delete [] C_div_Z;
00190     delete perfClass;
00191     delete [] Qty;
00192 
00193     delete[] L;
00194     //delete[] Q;
00195 
00196     unsigned long* idx = new unsigned long[t];
00197     work = NULL;
00198     indicesOfMax(ap, tot, t, idx, work, 1);
00199 
00200 
00201     gMat2D<T> *LAMBDA = new gMat2D<T>(1, t);
00202     copyLocations(idx, guesses, t, tot, LAMBDA->getData());
00203 
00204     delete[] idx;
00205 
00206 
00207     GurlsOptionsList* paramsel;
00208 
00209     if(opt.hasOpt("paramsel"))
00210     {
00211         GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp");
00212         tmp_opt->copyOpt("paramsel", opt);
00213 
00214         paramsel = GurlsOptionsList::dynacast(tmp_opt->getOpt("paramsel"));
00215         tmp_opt->removeOpt("paramsel", false);
00216         delete tmp_opt;
00217 
00218         paramsel->removeOpt("guesses");
00219         paramsel->removeOpt("perf");
00220         paramsel->removeOpt("lambdas");
00221     }
00222     else
00223         paramsel = new GurlsOptionsList("paramsel");
00224 
00225 
00226 //     opt.addOpt("lambdas", LAMBDA);
00227     paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*LAMBDA));
00228 
00229     //vout.perf =   ap;
00230     paramsel->addOpt("perf", new OptMatrix<gMat2D<T> >(*perf));
00231 
00232     //vout.guesses =    guesses;
00233     gMat2D<T> *guesses_mat = new gMat2D<T>(guesses, 1, tot, true);
00234     paramsel->addOpt("guesses", new OptMatrix<gMat2D<T> >(*guesses_mat));
00235 
00236     delete[] guesses;
00237 
00238     return paramsel;
00239 
00240 }
00241 
00242 
00243 }
00244 
00245 #endif // _GURLS_LOOCVDUAL_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends