GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
loocvprimal.h
00001 /*
00002   * The GURLS Package in C++
00003   *
00004   * Copyright (C) 2011-2013, IIT@MIT Lab
00005   * All rights reserved.
00006   *
00007   * authors:  M. Santoro
00008   * email:   msantoro@mit.edu
00009   * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010   *
00011   * Redistribution and use in source and binary forms, with or without
00012   * modification, are permitted provided that the following conditions
00013   * are met:
00014   *
00015   *     * Redistributions of source code must retain the above
00016   *       copyright notice, this list of conditions and the following
00017   *       disclaimer.
00018   *     * Redistributions in binary form must reproduce the above
00019   *       copyright notice, this list of conditions and the following
00020   *       disclaimer in the documentation and/or other materials
00021   *       provided with the distribution.
00022   *     * Neither the name(s) of the copyright holders nor the names
00023   *       of its contributors or of the Massachusetts Institute of
00024   *       Technology or of the Italian Institute of Technology may be
00025   *       used to endorse or promote products derived from this software
00026   *       without specific prior written permission.
00027   *
00028   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029   * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030   * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031   * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032   * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033   * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034   * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036   * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037   * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038   * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039   * POSSIBILITY OF SUCH DAMAGE.
00040   */
00041 
00042 
00043 
00044 #ifndef _GURLS_LOOCVPRIMAL_H_
00045 #define _GURLS_LOOCVPRIMAL_H_
00046 
#include <cstdio>
#include <cstring>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <set>
#include <vector>
00053 
00054 #include "gurls++/options.h"
00055 #include "gurls++/optlist.h"
00056 #include "gurls++/optmatrix.h"
00057 
00058 #include "gurls++/gmat2d.h"
00059 #include "gurls++/gvec.h"
00060 #include "gurls++/gmath.h"
00061 
00062 #include "gurls++/paramsel.h"
00063 
00064 #include "gurls++/precisionrecall.h"
00065 #include "gurls++/macroavg.h"
00066 #include "gurls++/rmse.h"
00067 
00068 namespace gurls {
00069 
00075 template <typename T>
00076 class ParamSelLoocvPrimal: public ParamSelection<T>
00077 {
00078 
00079 public:
00093     GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt);
00094 };
00095 
00096 template <typename T>
00097 GurlsOptionsList *ParamSelLoocvPrimal<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt)
00098 {
00099     typename std::set<T*> garbage;
00100 
00101     try
00102     {
00103         //[n,T]  = size(y);
00104         const unsigned long n = Y.rows();
00105         const unsigned long t = Y.cols();
00106 
00107 
00108         //  K = X'*X;
00109         const unsigned long xc = X.cols();
00110         const unsigned long xr = X.rows();
00111 
00112         if(xr != n)
00113             throw gException(Exception_Inconsistent_Size);
00114 
00115         T* K = new T[xc*xc];
00116         garbage.insert(K);
00117         dot(X.getData(), X.getData(), K, xr, xc, xr, xc, xc, xc, CblasTrans, CblasNoTrans, CblasColMajor);
00118 
00119 
00120         // tot = opt.nlambda;
00121         int tot = static_cast<int>(std::ceil( opt.getOptAsNumber("nlambda")));
00122 
00123         //  [Q,L] = eig(K);
00124         //  L = diag(L);
00125 
00126 //        T* Q, *L;
00127 
00128 //        eig(K, Q, L, xc, xc);
00129 //        garbage.insert(Q);
00130 //        garbage.insert(L);
00131 
00132         T* Q = K;
00133         T* L = new T[xc];
00134         garbage.insert(L);
00135 
00136         eig_sm(Q, L, xc);
00137 
00138 
00139         T* filtered = L;
00140         T* guesses = lambdaguesses(filtered, xc, std::min(xc,xr), xr, tot, (T)(opt.getOptAsNumber("smallnumber")));
00141         garbage.insert(guesses);
00142 //        set(guesses, 0.f, tot);
00143 
00144         T* LOOSQE = new T[tot*t];
00145         garbage.insert(LOOSQE);
00146         set(LOOSQE, (T)0.0, tot*t);
00147 
00148         //  LEFT = X*Q;
00149         //  RIGHT = Q'*X'*y;
00150         T* LEFT = new T[xr*xc];
00151         garbage.insert(LEFT);
00152         dot(X.getData(), Q, LEFT, xr, xc, xc, xc, xr, xc, CblasNoTrans, CblasNoTrans, CblasColMajor);
00153 
00154         T* tmp = new T[xc*t];
00155         garbage.insert(tmp);
00156         dot(X.getData(), Y.getData(), tmp, xr, xc, n, t, xc, t, CblasTrans, CblasNoTrans, CblasColMajor);
00157 
00158         T* RIGHT = new T[xc*t];
00159         garbage.insert(RIGHT);
00160         dot(Q, tmp, RIGHT, xc, xc, xc, t, xc, t, CblasTrans, CblasNoTrans, CblasColMajor);
00161 
00162         delete[] tmp;
00163         garbage.erase(tmp);
00164 
00165         //  right = Q'*X';
00166         T* right = new T[xc*xr];
00167         garbage.insert(right);
00168         dot(Q, X.getData(), right, xc, xc, xr, xc, xc, xr, CblasTrans, CblasTrans, CblasColMajor);
00169 
00170         delete[] Q;
00171         garbage.erase(Q);
00172 
00173         T* den = new T[n];
00174 //        T* Le = new T[n*t];
00175         garbage.insert(den);
00176 //        garbage.insert(Le);
00177 
00178         T* tmpvec = new T[xc];
00179         garbage.insert(tmpvec);
00180         T* tmp1 = new T[xr*t];
00181         garbage.insert(tmp1);
00182         T* LL = new T[xc*xc];
00183         garbage.insert(LL);
00184         tmp = new T[xc*t];
00185         garbage.insert(tmp);
00186         T* num = new T[xr*t];
00187         garbage.insert(num);
00188         T* row = new T[xc];
00189         garbage.insert(row);
00190         T* num_div_den = new T[n];
00191         garbage.insert(num_div_den);
00192 
00193         GurlsOptionsList* nestedOpt = new GurlsOptionsList("nested");
00194 
00195         gMat2D<T>* pred = new gMat2D<T>(n, t);
00196         OptMatrix<gMat2D<T> >* pred_opt = new OptMatrix<gMat2D<T> >(*pred);
00197         nestedOpt->addOpt("pred", pred_opt);
00198 
00199 //        const int pred_size = pred->getSize();
00200 
00201         Performance<T>* perfClass = Performance<T>::factory(opt.getOptAsString("hoperf"));
00202 
00203         gMat2D<T>* perf = new gMat2D<T>(tot, t);
00204         T* ap = perf->getData();
00205 
00206 
00207         //  for i = 1:tot
00208         for(int s = 0; s < tot; ++s)
00209         {
00210 
00211             //      LL = L + (n*guesses(i));
00212             set(tmpvec, n*guesses[s] , xc);
00213             axpy(xc, (T)1.0, L, 1, tmpvec, 1);
00214 
00215             //      LL = LL.^(-1)
00216             setReciprocal(tmpvec, xc);
00217             //      LL = diag(LL);
00218             diag(tmpvec, xc, LL);
00219 
00220 
00221             //      num = y - LEFT*LL*RIGHT;
00222 
00223             dot(LL, RIGHT, tmp, xc, xc, xc, t, xc, t, CblasNoTrans, CblasNoTrans, CblasColMajor);
00224             dot(LEFT, tmp, tmp1, xr, xc, xc, t, xr, t, CblasNoTrans, CblasNoTrans, CblasColMajor);
00225 
00226 
00227             copy(num, Y.getData(), xr*t);
00228             axpy(xr*t, (T)-1.0, tmp1, 1, num, 1);
00229 
00230 
00231             // den = zeros(n,1);
00232             set(den, (T)0.0, n);
00233 
00234             //      for j = 1:n
00235             //          den(j) = 1-LEFT(j,:)*LL*right(:,j);
00236             //      end
00237 
00238 
00239             for (unsigned long j = 0; j < n; ++j)
00240             {
00241                 dot(LL, right +(xc*j), tmp, xc, xc, xc, 1, xc, 1, CblasNoTrans, CblasNoTrans, CblasColMajor);
00242 
00243                 //extract j-th row from LEFT
00244                 copy(row, LEFT + j, xc, 1, xr);
00245 
00246                 den[j] =  ((T) 1.0) - dot (xc, row, 1, tmp, 1);
00247             }
00248 
00249 
00250     //        opt.pred = zeros(n,T);
00251 //            set(pred->getData(), (T)0.0, pred_size);
00252 
00253 
00254 
00255     //        for t = 1:T
00256             for(unsigned long j = 0; j< t; ++j)
00257             {
00258                 rdivide(num + (n*j), den, num_div_den, n);
00259 
00260     //            opt.pred(:,t) = y(:,t) - (num(:,t)./den);
00261                 copy(pred->getData()+(n*j), Y.getData() + (n*j), n);
00262                 axpy(n, (T)-1.0, num_div_den, 1, pred->getData()+(n*j), 1);
00263             }
00264 
00265     //        opt.perf = opt.hoperf([],y,opt);
00266             const gMat2D<T> dummy;
00267             GurlsOptionsList* perf = perfClass->execute(dummy, Y, *nestedOpt);
00268 
00269             gMat2D<T> &forho_vec = perf->getOptValue<OptMatrix<gMat2D<T> > >("forho");
00270 
00271     //        for t = 1:T
00272             copy(ap+s, forho_vec.getData(), forho_vec.getSize(), tot, 1);
00273 
00274             delete perf;
00275         }
00276 
00277 
00278         delete perfClass;
00279 
00280         delete[] row;
00281         garbage.erase(row);
00282         delete[] num;
00283         garbage.erase(num);
00284         delete[] tmp;
00285         garbage.erase(tmp);
00286         delete[] tmp1;
00287         garbage.erase(tmp1);
00288         delete[] tmpvec;
00289         garbage.erase(tmpvec);
00290         delete[] LL;
00291         garbage.erase(LL);
00292         delete [] num_div_den;
00293         garbage.erase(num_div_den);
00294 
00295         delete nestedOpt;
00296         delete[] L;
00297         garbage.erase(L);
00298         delete[] LEFT;
00299         garbage.erase(LEFT);
00300         delete[] RIGHT;
00301         garbage.erase(RIGHT);
00302         delete[] right;
00303         garbage.erase(right);
00304         delete[] den;
00305         garbage.erase(den);
00306         delete [] LOOSQE;
00307         garbage.erase(LOOSQE);
00308 //        delete[] Le;
00309 //        garbage.erase(Le);
00310 
00311         //[dummy,idx] = max(ap,[],1);
00312         unsigned long* idx = new unsigned long[t];
00313         T* work = NULL;
00314         indicesOfMax(ap, tot, t, idx, work, 1);
00315 
00316         //vout.lambdas =    guesses(idx);
00317         gMat2D<T> *LAMBDA = new gMat2D<T>(1, t);
00318         copyLocations(idx, guesses, t, tot, LAMBDA->getData());
00319 
00320         delete[] idx;
00321 
00322         GurlsOptionsList* paramsel;
00323 
00324         if(opt.hasOpt("paramsel"))
00325         {
00326             GurlsOptionsList* tmp_opt = new GurlsOptionsList("tmp");
00327             tmp_opt->copyOpt("paramsel", opt);
00328 
00329             paramsel = GurlsOptionsList::dynacast(tmp_opt->getOpt("paramsel"));
00330             tmp_opt->removeOpt("paramsel", false);
00331             delete tmp_opt;
00332 
00333             paramsel->removeOpt("guesses");
00334             paramsel->removeOpt("perf");
00335             paramsel->removeOpt("lambdas");
00336         }
00337         else
00338             paramsel = new GurlsOptionsList("paramsel");
00339 
00340 
00341 
00342         paramsel->addOpt("lambdas", new OptMatrix<gMat2D<T> >(*LAMBDA));
00343         paramsel->addOpt("perf", new OptMatrix<gMat2D<T> >(*perf));
00344 
00345         //vout.guesses =    guesses;
00346         gMat2D<T> *guesses_mat = new gMat2D<T>(guesses, 1, tot, true);
00347         paramsel->addOpt("guesses", new OptMatrix<gMat2D<T> >(*guesses_mat));
00348 
00349         delete[] guesses;
00350 
00351         return paramsel;
00352     }
00353     catch( gException& e)
00354     {
00355 
00356         for(typename std::set<T*>::iterator it = garbage.begin(); it != garbage.end(); ++it)
00357             delete[] (*it);
00358 
00359         throw e;
00360     }
00361 }
00362 
00363 }
00364 
00365 #endif // _GURLS_LOOCVPRIMAL_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends