![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
/*
 * The GURLS Package in C++
 *
 * Copyright (C) 2011-2013, Matteo Santoro
 * All rights reserved.
 *
 * author:  M. Santoro
 * email:   matteo.santoro@gmail.com
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *     * Redistributions of source code must retain the above
 *       copyright notice, this list of conditions and the following
 *       disclaimer.
 *     * Redistributions in binary form must reproduce the above
 *       copyright notice, this list of conditions and the following
 *       disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name(s) of the copyright holders nor the names
 *       of its contributors or of the Massachusetts Institute of
 *       Technology or of the Italian Institute of Technology may be
 *       used to endorse or promote products derived from this software
 *       without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
00039 */ 00040 00041 #ifndef _GURLS_PREDGP_H 00042 #define _GURLS_PREDGP_H 00043 00044 #include <cmath> 00045 00046 #include "gurls++/pred.h" 00047 00048 #include "gurls++/gmath.h" 00049 #include "gurls++/gmat2d.h" 00050 #include "gurls++/options.h" 00051 #include "gurls++/optlist.h" 00052 00053 00054 namespace gurls { 00055 00061 template <typename T> 00062 class PredGPRegr: public Prediction<T> { 00063 00064 public: 00078 GurlsOptionsList *execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt); 00079 }; 00080 00081 template <typename T> 00082 GurlsOptionsList *PredGPRegr<T>::execute(const gMat2D<T>& X, const gMat2D<T>& /*Y*/, const GurlsOptionsList &opt) 00083 { 00084 // pred.means = opt.predkernel.K*opt.rls.alpha; 00085 00086 const GurlsOptionsList* predkernel = opt.getOptAs<GurlsOptionsList>("predkernel"); 00087 00088 const gMat2D<T> &K = predkernel->getOptValue<OptMatrix<gMat2D<T> > >("K"); 00089 00090 const unsigned long kr = K.rows(); 00091 const unsigned long kc = K.cols(); 00092 00093 00094 const GurlsOptionsList* rls = opt.getOptAs<GurlsOptionsList>("optimizer"); 00095 00096 const gMat2D<T> &L = rls->getOptValue<OptMatrix<gMat2D<T> > >("L"); 00097 00098 const unsigned long lr = L.rows(); 00099 const unsigned long lc = L.cols(); 00100 00101 const gMat2D<T> &alpha = rls->getOptValue<OptMatrix<gMat2D<T> > >("alpha"); 00102 00103 00104 gMat2D<T>* means_mat = new gMat2D<T>(kr, alpha.cols()); 00105 dot(K.getData(), alpha.getData(), means_mat->getData(), kr, kc, alpha.rows(), alpha.cols(), kr, alpha.cols(), CblasNoTrans, CblasNoTrans, CblasColMajor); 00106 00107 00108 const unsigned long n = X.rows(); 00109 00110 // pred.vars = zeros(n,1); 00111 gMat2D<T> *vars_mat = new gMat2D<T>(n, 1); 00112 T* vars = vars_mat->getData(); 00113 00114 T* v = new T[std::max(kc, n)]; 00115 00116 for(unsigned long i = 0; i<n; ++i) 00117 { 00118 getRow(K.getData(), kr, kc, i, v); 00119 00121 mldivide_squared(L.getData(), v, lr, lc, kc, 1, CblasTrans); 00122 
00124 vars[i] = dot(kc, v, 1, v, 1); 00125 } 00126 00127 00128 // pred.vars = opt.predkernel.Ktest - pred.vars; 00129 const gMat2D<T> &Ktest = predkernel->getOptValue<OptMatrix<gMat2D<T> > >("Ktest"); 00130 copy(v, Ktest.getData(), n); 00131 axpy(n, (T)-1.0, vars, 1, v, 1); 00132 copy(vars, v, n); 00133 00134 delete[] v; 00135 00136 GurlsOptionsList* pred = new GurlsOptionsList("pred"); 00137 00138 pred->addOpt("means", new OptMatrix<gMat2D<T> >(*means_mat)); 00139 pred->addOpt("vars", new OptMatrix<gMat2D<T> >(*vars_mat)); 00140 00141 return pred; 00142 } 00143 00144 } 00145 00146 #endif // _GURLS_PREDGP_H