GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
gmath.cpp
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-1013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * author:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massacusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
#include "gurls++/gmath.h"

#include <sstream>
#include <vector>

#include "gurls++/gmat2d.h"
#include "gurls++/gvec.h"
#include "gurls++/exports.h"
00047 
00048 namespace gurls {
00049 
00053 template <>
00054 GURLS_EXPORT float dot(const gVec<float>& x, const gVec<float>& y)
00055 {
00056     if ( x.getSize() != y.getSize() )
00057         throw gException(gurls::Exception_Inconsistent_Size);
00058 
00059     int n = x.getSize();
00060     int incr = 1;
00061 
00062     return sdot_(&n, const_cast<float*>(x.getData()), &incr, const_cast<float*>(y.getData()), &incr);
00063 }
00064 
00068 template <>
00069 GURLS_EXPORT double dot(const gVec<double>& x, const gVec<double>& y)
00070 {
00071 
00072     if ( x.getSize() != y.getSize() )
00073         throw gException(gurls::Exception_Inconsistent_Size);
00074 
00075     int n = x.getSize();
00076     int incr = 1;
00077 
00078     return ddot_(&n, const_cast<double*>(x.getData()), &incr, const_cast<double*>(y.getData()), &incr);
00079 }
00080 
00081 
00082 // ============ OUT-OF-PLACE MATRIX MULTIPLICATION ==================
00086 template <>
00087 GURLS_EXPORT void dot(const gMat2D<float>& A, const gMat2D<float>& B, gMat2D<float>& C)
00088 {
00089 
00090     dot(A.getData(), B.getData(), C.getData(),
00091         A.rows(), A.cols(),
00092         B.rows(), B.cols(),
00093         C.rows(), C.cols(),
00094         CblasNoTrans, CblasNoTrans, CblasColMajor);
00095 }
00096 
00100 template <>
00101 GURLS_EXPORT void dot(const gMat2D<double>& A, const gMat2D<double>& B, gMat2D<double>& C)
00102 {
00103 
00104     dot(A.getData(), B.getData(), C.getData(),
00105         A.rows(), A.cols(),
00106         B.rows(), B.cols(),
00107         C.rows(), C.cols(),
00108         CblasNoTrans, CblasNoTrans, CblasColMajor);
00109 }
00110 
00111 
00112 // ============ OUT-OF-PLACE MATRIX-VECTOR MULTIPLICATION ==================
00113 
00117 template <>
00118 GURLS_EXPORT void dot(const gMat2D<float>& A, const gVec<float>& x, gVec<float>& y)
00119 {
00120     if ( (A.cols() != x.getSize()) ||  (A.rows() != y.getSize()))
00121         throw gException(Exception_Inconsistent_Size);
00122 
00123 
00124     // y = alpha*A*x + beta*y
00125     float alpha = 1.0f;
00126     float beta = 0.0f;
00127 
00128     char transA = 'N';
00129 
00130     int m = A.rows();
00131     int n = A.cols();
00132     int lda = m;
00133     int inc = 1;
00134 
00135     sgemv_(&transA, &m, &n, &alpha, const_cast<float*>(A.getData()), &lda,
00136           const_cast<float*>(x.getData()), &inc, &beta, y.getData(), &inc);
00137 }
00138 
00142 template <>
00143 GURLS_EXPORT void dot(const gMat2D<double>& A, const gVec<double>& x, gVec<double>& y){
00144 
00145     if ( (A.cols() != x.getSize()) ||  (A.rows() != y.getSize()))
00146         throw gException(Exception_Inconsistent_Size);
00147 
00148 
00149     // y = alpha*A*x + beta*y
00150     double alpha = 1.0;
00151     double beta = 0.0;
00152 
00153     char transA = 'N';
00154 
00155     int m = A.rows();
00156     int n = A.cols();
00157     int lda = m;
00158     int inc = 1;
00159 
00160     dgemv_(&transA, &m, &n, &alpha, const_cast<double*>(A.getData()), &lda,
00161           const_cast<double*>(x.getData()), &inc, &beta,
00162           y.getData(), &inc);
00163 
00164 }
00165 
00169 template <>
00170 GURLS_EXPORT void lu(gMat2D<float>& A, gVec<int>& pv)
00171 {
00172     unsigned int k = std::min(A.cols(), A.rows());
00173 
00174     if (pv.getSize() != k)
00175         throw gException("The lenghth of pv must be equal to the minimun dimension of A");
00176 
00177     int info;
00178     int m = A.rows();
00179     int n = A.cols();
00180     int lda = A.rows();
00181 
00182     sgetrf_(&m, &n, A.getData(), &lda, pv.getData(), &info);
00183 
00184     if(info <0)
00185         throw gException("LU factorization failed");
00186 }
00187 
00191 template <>
00192 GURLS_EXPORT void lu(gMat2D<float>& A)
00193 {
00194     gVec<int> pv(std::min(A.cols(), A.rows()));
00195     lu(A, pv);
00196 }
00197 
00201 template <>
00202 GURLS_EXPORT void inv(const gMat2D<float>& A, gMat2D<float>& Ainv, InversionAlgorithm alg)
00203 {
00204     Ainv = A;
00205     int k = std::min(Ainv.cols(), Ainv.rows());
00206 
00207     int info;
00208     int* ipiv = new int[k];
00209 
00210     int m = Ainv.rows();
00211     int n = Ainv.cols();
00212     int lda = Ainv.rows();
00213 
00214     sgetrf_(&m, &n, Ainv.getData(), &lda, ipiv, &info);
00215 
00216     float* work = new float[n];
00217 
00218     sgetri_(&m, Ainv.getData(), &lda, ipiv, work, &n, &info);
00219 
00220     delete[] ipiv;
00221     delete[] work;
00222 }
00223 
00227 template <>
00228 GURLS_EXPORT void pinv(const gMat2D<float>& A, gMat2D<float>& Ainv, float RCOND)
00229 {
00230     int r, c;
00231     float* inv = pinv(A.getData(), A.rows(), A.cols(), r, c, &RCOND);
00232 
00233     Ainv.resize(r, c);
00234     gurls::copy(Ainv.getData(), inv, r*c);
00235 
00236     delete[] inv;
00237 }
00238 
00242 template <>
00243 GURLS_EXPORT void svd(const gMat2D<float>& A, gMat2D<float>& U, gVec<float>& W, gMat2D<float>& Vt)
00244 {
00245     float* Ubuf;
00246     float* Sbuf;
00247     float* Vtbuf;
00248 
00249     int Urows, Ucols;
00250     int Slen;
00251     int Vtrows, Vtcols;
00252 
00253     gurls::svd(A.getData(), Ubuf, Sbuf, Vtbuf,
00254              A.rows(), A.cols(),
00255              Urows, Ucols, Slen, Vtrows, Vtcols);
00256 
00257 
00258     U.resize(Urows, Ucols);
00259     copy(U.getData(), Ubuf, U.getSize());
00260 
00261     W.resize(Slen);
00262     copy(W.getData(), Sbuf, Slen);
00263 
00264     Vt.resize(Vtrows, Vtcols);
00265     copy(Vt.getData(), Vtbuf, Vt.getSize());
00266 
00267     delete [] Ubuf;
00268     delete [] Sbuf;
00269     delete [] Vtbuf;
00270 }
00271 
00275 template <>
00276 GURLS_EXPORT void eig(const gMat2D<float>& A, gMat2D<float>& V, gVec<float>& Wr, gVec<float>& Wi)
00277 {
00278     if (A.cols() != A.rows())
00279         throw gException("The input matrix A must be squared");
00280 
00281     float* Atmp = new float[A.getSize()];
00282     copy(Atmp, A.getData(), A.getSize());
00283 
00284     char jobvl = 'N', jobvr = 'V';
00285     int n = A.cols(), lda = A.cols(), ldvl = 1, ldvr = A.cols();
00286     int info, lwork = 4*n;
00287     float* work = new float[lwork];
00288 
00289     sgeev_(&jobvl, &jobvr, &n, Atmp, &lda, Wr.getData(), Wi.getData(), NULL, &ldvl, V.getData(), &ldvr, work, &lwork, &info);
00290 
00291     delete[] Atmp;
00292     delete[] work;
00293 
00294     if(info != 0)
00295     {
00296         std::stringstream str;
00297         str << "Eigenvalues/eigenVectors computation failed, error code " << info << ";" << std::endl;
00298         throw gException(str.str());
00299     }
00300 }
00301 
00305 template <>
00306 GURLS_EXPORT void eig(const gMat2D<float>& A, gMat2D<float>& V, gVec<float>& W)
00307 {
00308     gVec<float> tmp(W.getSize());
00309     tmp = 0;
00310 
00311     eig(A, V, W, tmp);
00312 }
00313 
00317 template <>
00318 GURLS_EXPORT void eig(const gMat2D<float>& A, gVec<float>& Wr, gVec<float>& Wi)
00319 {
00320     if (A.cols() != A.rows())
00321         throw gException("The input matrix A must be squared");
00322 
00323     float* Atmp = new float[A.getSize()];
00324     copy(Atmp, A.getData(), A.getSize());
00325 
00326     char jobvl = 'N', jobvr = 'N';
00327     int n = A.cols(), lda = A.cols(), ldvl = 1, ldvr = 1;
00328     int info, lwork = 4*n;
00329     float* work = new float[lwork];
00330 
00331     sgeev_(&jobvl, &jobvr, &n, Atmp, &lda, Wr.getData(), Wi.getData(), NULL, &ldvl, NULL, &ldvr, work, &lwork, &info);
00332 
00333     delete[] Atmp;
00334     delete[] work;
00335 
00336     if(info != 0)
00337     {
00338         std::stringstream str;
00339         str << "Eigenvalues/eigenVectors computation failed, error code " << info << ";" << std::endl;
00340         throw gException(str.str());
00341     }
00342 }
00343 
00347 template <>
00348 GURLS_EXPORT void eig(const gMat2D<float>& A, gVec<float>& W)
00349 {
00350     gVec<float> tmp = W;
00351     eig(A, W, tmp);
00352 }
00353 
00354 
00358 template <>
00359 GURLS_EXPORT void cholesky(const gMat2D<float>& A, gMat2D<float>& L, bool upper)
00360 {
00361     cholesky<float>(A.getData(), A.rows(), A.cols(), L.getData(), upper);
00362 }
00363 
00367 template<>
00368 void GURLS_EXPORT set(float* buffer, const float value, const int size, const int incr)
00369 {
00370     int incx = 0;
00371 
00372     scopy_(const_cast<int*>(&size), const_cast<float*>(&value), &incx, buffer, const_cast<int*>(&incr));
00373 }
00374 
00378 template<>
00379 void GURLS_EXPORT set(float* buffer, const float value, const int size)
00380 {
00381     set<float>(buffer, value, size, 1);
00382 }
00383 
00387 template<>
00388 void GURLS_EXPORT set(double* buffer, const double value, const int size, const int incr)
00389 {
00390     int incx = 0;
00391 
00392     dcopy_(const_cast<int*>(&size), const_cast<double*>(&value), &incx, buffer, const_cast<int*>(&incr));
00393 
00394 }
00395 
00399 template<>
00400 void GURLS_EXPORT set(double* buffer, const double value, const int size)
00401 {
00402     set<double>(buffer, value, size, 1);
00403 }
00404 
00408 template<>
00409 void GURLS_EXPORT copy(float* dst, const float* src, const int size, const int dstIncr, const int srcIncr)
00410 {
00411     scopy_(const_cast<int*>(&size), const_cast<float*>(src), const_cast<int*>(&srcIncr), dst, const_cast<int*>(&dstIncr));
00412 }
00413 
00417 template<>
00418 void GURLS_EXPORT copy(float* dst, const float* src, const int size)
00419 {
00420     int incr = 1;
00421 
00422     scopy_(const_cast<int*>(&size), const_cast<float*>(src), &incr, dst, &incr);
00423 }
00424 
00428 template<>
00429 void GURLS_EXPORT copy(double* dst, const double* src, const int size, const int dstIncr, const int srcIncr)
00430 {
00431     dcopy_(const_cast<int*>(&size), const_cast<double*>(src), const_cast<int*>(&srcIncr), dst, const_cast<int*>(&dstIncr));
00432 }
00433 
00437 template<>
00438 void GURLS_EXPORT copy(double* dst, const double* src, const int size)
00439 {
00440     int incr = 1;
00441 
00442     dcopy_(const_cast<int*>(&size), const_cast<double*>(src), &incr, dst, &incr);
00443 }
00444 
00446 //  * Specialized version of pinv for float buffers
00447 //  */
00448 //template<>
00449 //GURLS_EXPORT float* pinv(const float* A, int rows, int cols, int& res_rows, int& res_cols, float* RCOND)
00450 //{
00451 //    int M = rows;
00452 //    int N = cols;
00453 
00454 //    float* a = new float[rows*cols];
00455 //    copy<float>(a, A, rows*cols);
00456 
00457 //    int LDA = M;
00458 //    int LDB = std::max(M, N);
00459 //    int NRHS = LDB;
00460 
00461 
00462 //    // float* b = eye(LDB).getData()
00463 
00464 //    const int b_size = LDB*NRHS;
00465 //    float *b = new float[LDB*NRHS];
00466 //    set<float>(b, 0.f, b_size);
00467 //    set<float>(b, 1.f, std::min(LDB, NRHS), NRHS+1);
00468 
00469 //    float* S = new float[std::min(M,N)];
00471 
00472 //    float rcond = (RCOND == NULL)? (std::max(rows, cols)*FLT_EPSILON): *RCOND;
00473 
00476 
00477 //    int RANK = -1; // std::min(M,N);
00478 //    int LWORK = -1; //2 * (3*LDB + std::max( 2*std::min(M,N), LDB));
00479 //    float* WORK = new float[1];
00480 
00481 //    /*
00482 
00483 //    subroutine SGELSS     (   INTEGER     M,
00484 //      INTEGER     N,
00485 //      INTEGER     NRHS,
00486 //      REAL,dimension( lda, * )    A,
00487 //      INTEGER     LDA,
00488 //      REAL,dimension( ldb, * )    B,
00489 //      INTEGER     LDB,
00490 //      REAL,dimension( * )     S,
00491 //      REAL    RCOND,
00492 //      INTEGER     RANK,
00493 //      REAL,dimension( * )     WORK,
00494 //      INTEGER     LWORK,
00495 //      INTEGER     INFO
00496 //     )
00497 
00498 //    */
00499 
00500 //    /*
00501 //   INFO:
00502 //   = 0:   successful exit
00503 //   < 0:   if INFO = -i, the i-th argument had an illegal value.
00504 //   > 0:   the algorithm for computing the SVD failed to converge;
00505 //     if INFO = i, i off-diagonal elements of an intermediate
00506 //     bidiagonal form did not converge to zero.
00507 //   */
00508 //    int INFO;
00509 
00510 //    /* Query and allocate the optimal workspace */
00511 //    sgelss_( &M, &N, &NRHS, a, &LDA, b, &LDB, S, &rcond, &RANK, WORK, &LWORK, &INFO);
00512 //    LWORK = static_cast<int>(WORK[0]);
00513 //    delete [] WORK;
00514 //    WORK = new float[LWORK];
00515 
00516 //    sgelss_( &M, &N, &NRHS, a, &LDA, b, &LDB, S, &rcond, &RANK, WORK, &LWORK, &INFO);
00517 
00518 //    // TODO: check INFO on exit
00519 //    //condnum = S[0]/(S[std::min(M, N)]-1);
00520 
00521 //    if(INFO != 0)
00522 //    {
00523 //        std::stringstream str;
00524 //        str << "Pinv failed, error code " << INFO << ";" << std::endl;
00525 //        throw gException(str.str());
00526 //    }
00527 
00528 //    delete [] S;
00529 //    delete [] WORK;
00530 //    delete [] a;
00531 
00532 //    res_rows = LDB;
00533 //    res_cols = NRHS;
00534 //    return b;
00535 
00536 //}
00537 
00541 template<>
00542 GURLS_EXPORT bool eq(double val1, double val2)
00543 {
00544     return (val1 >= val2-DBL_EPSILON && val1 <= val2+DBL_EPSILON );
00545 }
00546 
00550 template<>
00551 GURLS_EXPORT bool eq(float val1, float val2)
00552 {
00553     return ( val1 >= val2-FLT_EPSILON && val1 <= val2+FLT_EPSILON );
00554 }
00555 
00559 template<>
00560 GURLS_EXPORT bool gt(double a, double b)
00561 {
00562     return ((a - b) > ( std::min(fabs(a), fabs(b))* std::numeric_limits<double>::epsilon()));
00563 }
00564 
00568 template<>
00569 GURLS_EXPORT bool gt(float a, float b)
00570 {
00571     return ((a - b) > ( std::min(fabs(a), fabs(b))* std::numeric_limits<float>::epsilon()));
00572 }
00573 
00577 template<>
00578 GURLS_EXPORT bool lt(double a, double b)
00579 {
00580     return ((b - a) > ( std::max(fabs(a), fabs(b))* std::numeric_limits<double>::epsilon()));
00581 }
00582 
00586 template<>
00587 GURLS_EXPORT bool lt(float a, float b)
00588 {
00589     return ((b - a) > ( std::max(fabs(a), fabs(b))* std::numeric_limits<float>::epsilon()));
00590 }
00591 
00592 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends