GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
rlsgp.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-2013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * authors:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massachusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
00042 
00043 #ifndef _GURLS_RLSGP_H_
00044 #define _GURLS_RLSGP_H_
00045 
00046 #include <cmath>
00047 
00048 #include "gurls++/optimization.h"
00049 
00050 #include "gurls++/gmath.h"
00051 #include "gurls++/gmat2d.h"
00052 #include "gurls++/options.h"
00053 #include "gurls++/optlist.h"
00054 
00055 namespace gurls {
00056 
00062 template <typename T>
00063 class RLSGPRegr: public Optimizer<T>{
00064 
00065 public:
00083     GurlsOptionsList *execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList &opt);
00084 };
00085 
00086 
00087 template <typename T>
00088 GurlsOptionsList* RLSGPRegr<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt)
00089 {
00090     //    noise = opt.singlelambda(opt.paramsel.lambdas);
00091     const gMat2D<T> &ll = opt.getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas");
00092     T noiselevel = opt.getOptAs<OptFunction>("singlelambda")->getValue(ll.getData(), ll.getSize());
00093 
00094 
00095     const gMat2D<T> &K_mat = opt.getOptValue<OptMatrix<gMat2D<T> > >("kernel.K");
00096 
00097     T* K = new T[K_mat.getSize()];
00098     copy(K, K_mat.getData(), K_mat.getSize());
00099 
00100     //n = size(opt.kernel.K,1);
00101     const unsigned long n = K_mat.rows();
00102 
00103     //T = size(y,2);
00104     const unsigned long t = Y.cols();
00105 
00106 
00107     //    cfr.L = chol(opt.kernel.K + noise^2*eye(n));
00108     const T coeff = std::pow(noiselevel, 2);
00109     unsigned long i=0;
00110     for(T* it = K; i<n; ++i, it += n+1)
00111         *it += coeff;
00112 
00113     T* retL = new T[n*n];
00114     cholesky(K, n, n, retL);
00115 
00116     //    cfr.alpha = cfr.L\(cfr.L'\y);
00117     gMat2D<T>* alpha = new gMat2D<T>(n, t);
00118     copy(alpha->getData(), Y.getData(), Y.getSize());
00119 
00120     mldivide_squared(retL, alpha->getData(), n, n, n, t, CblasTrans);
00121     mldivide_squared(retL, alpha->getData(), n, n, n, t, CblasNoTrans);
00122 
00123 
00124     GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer");
00125 
00126 //           optimizer.L = L;
00127     gMat2D<T>* L = new gMat2D<T>(n, n);
00128     copy(L->getData(), retL, L->getSize());
00129     optimizer->addOpt("L", new OptMatrix<gMat2D<T> >(*L));
00130 
00131     delete[] retL;
00132 
00133 //           optimizer.alpha = alpha;
00134     optimizer->addOpt("alpha", new OptMatrix<gMat2D<T> >(*alpha));
00135 
00136 //    cfr.X = X;
00137     gMat2D<T>* optX = new gMat2D<T>(X);
00138     optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*optX));
00139 
00140     return optimizer;
00141 }
00142 
00143 }
00144 #endif // _GURLS_RLSGP_H_
00145 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends