GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
rlsdual.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-2013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * authors:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massachusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
00042 
00043 #ifndef _GURLS_RLSDUAL_H_
00044 #define _GURLS_RLSDUAL_H_
00045 
#include "gurls++/optimization.h"

#include <set>
#include <vector>
00049 
00050 namespace gurls {
00051 
00057 template <typename T>
00058 class RLSDual: public Optimizer<T>{
00059 
00060 public:
00078    GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt);
00079 };
00080 
00081 
00082 template <typename T>
00083 GurlsOptionsList* RLSDual<T>::execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt)
00084 {
00085    //   lambda = opt.singlelambda(opt.paramsel.lambdas);
00086    const gMat2D<T> &ll = opt.getOptValue<OptMatrix<gMat2D<T> > >("paramsel.lambdas");
00087    T lambda = opt.getOptAs<OptFunction>("singlelambda")->getValue(ll.getData(), ll.getSize());
00088 
00089    const GurlsOptionsList* kernel = opt.getOptAs<GurlsOptionsList>("kernel");
00090    const gMat2D<T>& K_mat = kernel->getOptValue<OptMatrix<gMat2D<T> > >("K");
00091 
00092    T* K = new T[K_mat.getSize()];
00093    copy(K, K_mat.getData(), K_mat.getSize());
00094 
00095     //n = size(opt.kernel.K,1);
00096    const long n = K_mat.rows();
00097 
00098    //T = size(y,2);
00099    const long t = Y.cols();
00100 
00101 
00102 //    std::cout << "Solving dual RLS... " << std::endl;
00103 
00104     const T coeff = n*static_cast<T>(lambda);
00105     long i=0;
00106     for(T* it = K; i<n; ++i, it+=n+1)
00107         *it += coeff;
00108 
00109 
00110    std::set<T*> garbage;
00111 
00112    gMat2D<T>* retC = NULL;
00113 
00114    try // Try solving it with cholesky first.
00115    {
00116 //        R = chol(K);
00117         T* R = new T[n*n];
00118         garbage.insert(R);
00119         cholesky(K, n, n, R);
00120 
00121 //        cfr.C = R\(R'\y);
00122         retC = new gMat2D<T>(Y.rows(), t);
00123 
00124         copy(retC->getData(), Y.getData(), Y.getSize());
00125         mldivide_squared(R, retC->getData(), n, n, retC->rows(), retC->cols(), CblasTrans);
00126         mldivide_squared(R, retC->getData(), n, n, retC->rows(), retC->cols(), CblasNoTrans);
00127 
00128         delete[] R;
00129         garbage.erase(R);
00130    }
00131    catch (gException& /*gex*/)
00132    {
00133        for(typename std::set<T*>::iterator it = garbage.begin(); it != garbage.end(); ++it)
00134            delete[] (*it);
00135 
00136        garbage.clear();
00137 
00138        if(retC != NULL)
00139            delete retC;
00140 
00141 
00142 //           [Q,L,V] = svd(K);
00143 //           Q = double(Q);
00144 //           L = double(diag(L));
00145        T *Q, *L, *V;
00146        int Q_rows, Q_cols;
00147        int L_len;
00148        int V_rows, V_cols;
00149 
00150        svd(K, Q, L, V, n, n, Q_rows, Q_cols, L_len, V_rows, V_cols);
00151 
00152 
00153 //           cfr.C = rls_eigen(Q,L,y,lambda,n);
00154        retC = new gMat2D<T>(Q_rows, Y.cols());
00155 
00156        T* Qty = new T[Q_cols*Y.cols()];
00157        dot(Q, Y.getData(), Qty, Q_rows, Q_cols, Y.rows(), Y.cols(), Q_cols, Y.cols(), CblasTrans, CblasNoTrans, CblasColMajor);
00158 
00159        T* work = new T[L_len*(Q_rows+1)];
00160        rls_eigen(Q, L, Qty, retC->getData(), lambda, n, Q_rows, Q_cols, L_len, Q_cols, Y.cols(), work);
00161 
00162        delete [] work;
00163        delete [] Qty;
00164        delete [] Q;
00165        delete [] L;
00166        delete [] V;
00167    }
00168 
00169    delete[] K;
00170 
00171    GurlsOptionsList* optimizer = new GurlsOptionsList("optimizer");
00172 
00173 //       if strcmp(opt.kernel.type, 'linear')
00174    if(kernel->getOptAsString("type") == "linear")
00175    {
00176 //           cfr.W = X'*cfr.C;
00177        gMat2D<T>* W  = new gMat2D<T>(X.cols(), retC->cols());
00178        dot(X.getData(), retC->getData(), W->getData(), X.rows(), X.cols(), retC->rows(), retC->cols(), W->rows(), W->cols(), CblasTrans, CblasNoTrans, CblasColMajor);
00179        optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*W));
00180 
00181 //           cfr.C = [];
00182        gMat2D<T>* emptyC = new gMat2D<T>();
00183        optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*emptyC));
00184 
00185 //           cfr.X = [];
00186        gMat2D<T>* emptyX = new gMat2D<T>();
00187        optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*emptyX));
00188 
00189        delete retC;
00190    }
00191    else
00192    {
00193 //           cfr.W = [];
00194        gMat2D<T>* emptyW = new gMat2D<T>();
00195        optimizer->addOpt("W", new OptMatrix<gMat2D<T> >(*emptyW));
00196 
00197 //           cfr.C = retC;
00198        optimizer->addOpt("C", new OptMatrix<gMat2D<T> >(*retC));
00199 
00200 //           cfr.X = X;
00201        gMat2D<T>* optX = new gMat2D<T>(X);
00202        optimizer->addOpt("X", new OptMatrix<gMat2D<T> >(*optX));
00203    }
00204 
00205     return optimizer;
00206 }
00207 
00208 }
00209 #endif // _GURLS_RLSDUAL_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends