GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
predgp.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-2013, Matteo Santoro
00005  * All rights reserved.
00006  *
00007  * author:  M. Santoro
00008  * email:   matteo.santoro@gmail.com
00009  *
00010  * Redistribution and use in source and binary forms, with or without
00011  * modification, are permitted provided that the following conditions
00012  * are met:
00013  *
00014  *     * Redistributions of source code must retain the above
00015  *       copyright notice, this list of conditions and the following
00016  *       disclaimer.
00017  *     * Redistributions in binary form must reproduce the above
00018  *       copyright notice, this list of conditions and the following
00019  *       disclaimer in the documentation and/or other materials
00020  *       provided with the distribution.
00021  *     * Neither the name(s) of the copyright holders nor the names
00022  *       of its contributors or of the Massachusetts Institute of
00023  *       Technology or of the Italian Institute of Technology may be
00024  *       used to endorse or promote products derived from this software
00025  *       without specific prior written permission.
00026  *
00027  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00028  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00029  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00030  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00031  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00032  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00033  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00034  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00035  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00036  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00037  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00038  * POSSIBILITY OF SUCH DAMAGE.
00039  */
00040 
00041 #ifndef _GURLS_PREDGP_H
00042 #define _GURLS_PREDGP_H
00043 
#include <algorithm>
#include <cmath>
#include <vector>

#include "gurls++/pred.h"

#include "gurls++/gmath.h"
#include "gurls++/gmat2d.h"
#include "gurls++/options.h"
#include "gurls++/optlist.h"
00052 
00053 
00054 namespace gurls {
00055 
00061 template <typename T>
00062 class PredGPRegr: public Prediction<T> {
00063 
00064 public:
00078     GurlsOptionsList *execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt);
00079 };
00080 
00081 template <typename T>
00082 GurlsOptionsList *PredGPRegr<T>::execute(const gMat2D<T>& X, const gMat2D<T>& /*Y*/, const GurlsOptionsList &opt)
00083 {
00084 //    pred.means = opt.predkernel.K*opt.rls.alpha;
00085 
00086     const GurlsOptionsList* predkernel = opt.getOptAs<GurlsOptionsList>("predkernel");
00087 
00088     const gMat2D<T> &K = predkernel->getOptValue<OptMatrix<gMat2D<T> > >("K");
00089 
00090     const unsigned long kr = K.rows();
00091     const unsigned long kc = K.cols();
00092 
00093 
00094     const GurlsOptionsList* rls = opt.getOptAs<GurlsOptionsList>("optimizer");
00095 
00096     const gMat2D<T> &L = rls->getOptValue<OptMatrix<gMat2D<T> > >("L");
00097 
00098     const unsigned long lr = L.rows();
00099     const unsigned long lc = L.cols();
00100 
00101     const gMat2D<T> &alpha = rls->getOptValue<OptMatrix<gMat2D<T> > >("alpha");
00102 
00103 
00104     gMat2D<T>* means_mat = new gMat2D<T>(kr, alpha.cols());
00105     dot(K.getData(), alpha.getData(), means_mat->getData(), kr, kc, alpha.rows(), alpha.cols(), kr, alpha.cols(), CblasNoTrans, CblasNoTrans, CblasColMajor);
00106 
00107 
00108     const unsigned long n = X.rows();
00109 
00110 //    pred.vars = zeros(n,1);
00111     gMat2D<T> *vars_mat = new gMat2D<T>(n, 1);
00112     T* vars = vars_mat->getData();
00113 
00114     T* v = new T[std::max(kc, n)];
00115 
00116     for(unsigned long i = 0; i<n; ++i)
00117     {
00118         getRow(K.getData(), kr, kc, i, v);
00119 
00121         mldivide_squared(L.getData(), v, lr, lc, kc, 1, CblasTrans);
00122 
00124         vars[i] =  dot(kc, v, 1, v, 1);
00125     }
00126 
00127 
00128 //    pred.vars = opt.predkernel.Ktest - pred.vars;
00129     const gMat2D<T> &Ktest = predkernel->getOptValue<OptMatrix<gMat2D<T> > >("Ktest");
00130     copy(v, Ktest.getData(), n);
00131     axpy(n, (T)-1.0, vars, 1, v, 1);
00132     copy(vars, v, n);
00133 
00134     delete[] v;
00135 
00136     GurlsOptionsList* pred = new GurlsOptionsList("pred");
00137 
00138     pred->addOpt("means", new OptMatrix<gMat2D<T> >(*means_mat));
00139     pred->addOpt("vars", new OptMatrix<gMat2D<T> >(*vars_mat));
00140 
00141     return pred;
00142 }
00143 
00144 }
00145 
00146 #endif // _GURLS_PREDGP_H
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends