![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * author: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 #include "gurls++/gmath.h" 00043 00044 #include "gurls++/gmat2d.h" 00045 #include "gurls++/gvec.h" 00046 #include "gurls++/exports.h" 00047 00048 namespace gurls { 00049 00053 template <> 00054 GURLS_EXPORT float dot(const gVec<float>& x, const gVec<float>& y) 00055 { 00056 if ( x.getSize() != y.getSize() ) 00057 throw gException(gurls::Exception_Inconsistent_Size); 00058 00059 int n = x.getSize(); 00060 int incr = 1; 00061 00062 return sdot_(&n, const_cast<float*>(x.getData()), &incr, const_cast<float*>(y.getData()), &incr); 00063 } 00064 00068 template <> 00069 GURLS_EXPORT double dot(const gVec<double>& x, const gVec<double>& y) 00070 { 00071 00072 if ( x.getSize() != y.getSize() ) 00073 throw gException(gurls::Exception_Inconsistent_Size); 00074 00075 int n = x.getSize(); 00076 int incr = 1; 00077 00078 return ddot_(&n, const_cast<double*>(x.getData()), &incr, const_cast<double*>(y.getData()), &incr); 00079 } 00080 00081 00082 // ============ OUT-OF-PLACE MATRIX MULTIPLICATION ================== 00086 template <> 00087 GURLS_EXPORT void dot(const gMat2D<float>& A, const gMat2D<float>& B, gMat2D<float>& C) 00088 { 00089 00090 dot(A.getData(), B.getData(), C.getData(), 00091 A.rows(), A.cols(), 00092 B.rows(), B.cols(), 00093 C.rows(), C.cols(), 00094 CblasNoTrans, CblasNoTrans, CblasColMajor); 00095 } 00096 00100 template <> 00101 GURLS_EXPORT void dot(const gMat2D<double>& A, const gMat2D<double>& B, gMat2D<double>& C) 00102 { 00103 00104 dot(A.getData(), B.getData(), C.getData(), 00105 A.rows(), A.cols(), 00106 B.rows(), B.cols(), 00107 C.rows(), C.cols(), 00108 CblasNoTrans, CblasNoTrans, CblasColMajor); 00109 } 00110 00111 00112 // ============ OUT-OF-PLACE MATRIX-VECTOR MULTIPLICATION ================== 00113 00117 template <> 00118 GURLS_EXPORT void dot(const gMat2D<float>& A, const gVec<float>& x, gVec<float>& y) 00119 { 00120 if ( (A.cols() != x.getSize()) || (A.rows() != y.getSize())) 00121 throw gException(Exception_Inconsistent_Size); 00122 00123 00124 // y = alpha*A*x + beta*y 00125 float alpha = 1.0f; 00126 float beta = 0.0f; 00127 00128 char transA = 'N'; 00129 00130 int m = A.rows(); 00131 int n = A.cols(); 00132 int lda = m; 00133 int inc = 1; 00134 00135 sgemv_(&transA, &m, &n, &alpha, const_cast<float*>(A.getData()), &lda, 00136 const_cast<float*>(x.getData()), &inc, &beta, y.getData(), &inc); 00137 } 00138 00142 template <> 00143 GURLS_EXPORT void dot(const gMat2D<double>& A, const gVec<double>& x, gVec<double>& y){ 00144 00145 if ( (A.cols() != x.getSize()) || (A.rows() != y.getSize())) 00146 throw gException(Exception_Inconsistent_Size); 00147 00148 00149 // y = alpha*A*x + beta*y 00150 double alpha = 1.0; 00151 double beta = 0.0; 00152 00153 char transA = 'N'; 00154 00155 int m = A.rows(); 00156 int n = A.cols(); 00157 int lda = m; 00158 int inc = 1; 00159 00160 dgemv_(&transA, &m, &n, &alpha, const_cast<double*>(A.getData()), &lda, 00161 const_cast<double*>(x.getData()), &inc, &beta, 00162 y.getData(), &inc); 00163 00164 } 00165 00169 template <> 00170 GURLS_EXPORT void lu(gMat2D<float>& A, gVec<int>& pv) 00171 { 00172 unsigned int k = std::min(A.cols(), A.rows()); 00173 00174 if (pv.getSize() != k) 00175 throw gException("The lenghth of pv must be equal to the minimun dimension of A"); 00176 00177 int info; 00178 int m = A.rows(); 00179 int n = A.cols(); 00180 int lda = A.rows(); 00181 00182 sgetrf_(&m, &n, A.getData(), &lda, pv.getData(), &info); 00183 00184 if(info <0) 00185 throw gException("LU factorization failed"); 00186 } 00187 00191 template <> 00192 GURLS_EXPORT void lu(gMat2D<float>& A) 00193 { 00194 gVec<int> pv(std::min(A.cols(), A.rows())); 00195 lu(A, pv); 00196 } 00197 00201 template <> 00202 GURLS_EXPORT void inv(const gMat2D<float>& A, gMat2D<float>& Ainv, InversionAlgorithm alg) 00203 { 00204 Ainv = A; 00205 int k = std::min(Ainv.cols(), Ainv.rows()); 00206 00207 int info; 00208 int* ipiv = new int[k]; 00209 00210 int m = Ainv.rows(); 00211 int n = Ainv.cols(); 00212 int lda = Ainv.rows(); 00213 00214 sgetrf_(&m, &n, Ainv.getData(), &lda, ipiv, &info); 00215 00216 float* work = new float[n]; 00217 00218 sgetri_(&m, Ainv.getData(), &lda, ipiv, work, &n, &info); 00219 00220 delete[] ipiv; 00221 delete[] work; 00222 } 00223 00227 template <> 00228 GURLS_EXPORT void pinv(const gMat2D<float>& A, gMat2D<float>& Ainv, float RCOND) 00229 { 00230 int r, c; 00231 float* inv = pinv(A.getData(), A.rows(), A.cols(), r, c, &RCOND); 00232 00233 Ainv.resize(r, c); 00234 gurls::copy(Ainv.getData(), inv, r*c); 00235 00236 delete[] inv; 00237 } 00238 00242 template <> 00243 GURLS_EXPORT void svd(const gMat2D<float>& A, gMat2D<float>& U, gVec<float>& W, gMat2D<float>& Vt) 00244 { 00245 float* Ubuf; 00246 float* Sbuf; 00247 float* Vtbuf; 00248 00249 int Urows, Ucols; 00250 int Slen; 00251 int Vtrows, Vtcols; 00252 00253 gurls::svd(A.getData(), Ubuf, Sbuf, Vtbuf, 00254 A.rows(), A.cols(), 00255 Urows, Ucols, Slen, Vtrows, Vtcols); 00256 00257 00258 U.resize(Urows, Ucols); 00259 copy(U.getData(), Ubuf, U.getSize()); 00260 00261 W.resize(Slen); 00262 copy(W.getData(), Sbuf, Slen); 00263 00264 Vt.resize(Vtrows, Vtcols); 00265 copy(Vt.getData(), Vtbuf, Vt.getSize()); 00266 00267 delete [] Ubuf; 00268 delete [] Sbuf; 00269 delete [] Vtbuf; 00270 } 00271 00275 template <> 00276 GURLS_EXPORT void eig(const gMat2D<float>& A, gMat2D<float>& V, gVec<float>& Wr, gVec<float>& Wi) 00277 { 00278 if (A.cols() != A.rows()) 00279 throw gException("The input matrix A must be squared"); 00280 00281 float* Atmp = new float[A.getSize()]; 00282 copy(Atmp, A.getData(), A.getSize()); 00283 00284 char jobvl = 'N', jobvr = 'V'; 00285 int n = A.cols(), lda = A.cols(), ldvl = 1, ldvr = A.cols(); 00286 int info, lwork = 4*n; 00287 float* work = new float[lwork]; 00288 00289 sgeev_(&jobvl, &jobvr, &n, Atmp, &lda, Wr.getData(), Wi.getData(), NULL, &ldvl, V.getData(), &ldvr, work, &lwork, &info); 00290 00291 delete[] Atmp; 00292 delete[] work; 00293 00294 if(info != 0) 00295 { 00296 std::stringstream str; 00297 str << "Eigenvalues/eigenVectors computation failed, error code " << info << ";" << std::endl; 00298 throw gException(str.str()); 00299 } 00300 } 00301 00305 template <> 00306 GURLS_EXPORT void eig(const gMat2D<float>& A, gMat2D<float>& V, gVec<float>& W) 00307 { 00308 gVec<float> tmp(W.getSize()); 00309 tmp = 0; 00310 00311 eig(A, V, W, tmp); 00312 } 00313 00317 template <> 00318 GURLS_EXPORT void eig(const gMat2D<float>& A, gVec<float>& Wr, gVec<float>& Wi) 00319 { 00320 if (A.cols() != A.rows()) 00321 throw gException("The input matrix A must be squared"); 00322 00323 float* Atmp = new float[A.getSize()]; 00324 copy(Atmp, A.getData(), A.getSize()); 00325 00326 char jobvl = 'N', jobvr = 'N'; 00327 int n = A.cols(), lda = A.cols(), ldvl = 1, ldvr = 1; 00328 int info, lwork = 4*n; 00329 float* work = new float[lwork]; 00330 00331 sgeev_(&jobvl, &jobvr, &n, Atmp, &lda, Wr.getData(), Wi.getData(), NULL, &ldvl, NULL, &ldvr, work, &lwork, &info); 00332 00333 delete[] Atmp; 00334 delete[] work; 00335 00336 if(info != 0) 00337 { 00338 std::stringstream str; 00339 str << "Eigenvalues/eigenVectors computation failed, error code " << info << ";" << std::endl; 00340 throw gException(str.str()); 00341 } 00342 } 00343 00347 template <> 00348 GURLS_EXPORT void eig(const gMat2D<float>& A, gVec<float>& W) 00349 { 00350 gVec<float> tmp = W; 00351 eig(A, W, tmp); 00352 } 00353 00354 00358 template <> 00359 GURLS_EXPORT void cholesky(const gMat2D<float>& A, gMat2D<float>& L, bool upper) 00360 { 00361 cholesky<float>(A.getData(), A.rows(), A.cols(), L.getData(), upper); 00362 } 00363 00367 template<> 00368 void GURLS_EXPORT set(float* buffer, const float value, const int size, const int incr) 00369 { 00370 int incx = 0; 00371 00372 scopy_(const_cast<int*>(&size), const_cast<float*>(&value), &incx, buffer, const_cast<int*>(&incr)); 00373 } 00374 00378 template<> 00379 void GURLS_EXPORT set(float* buffer, const float value, const int size) 00380 { 00381 set<float>(buffer, value, size, 1); 00382 } 00383 00387 template<> 00388 void GURLS_EXPORT set(double* buffer, const double value, const int size, const int incr) 00389 { 00390 int incx = 0; 00391 00392 dcopy_(const_cast<int*>(&size), const_cast<double*>(&value), &incx, buffer, const_cast<int*>(&incr)); 00393 00394 } 00395 00399 template<> 00400 void GURLS_EXPORT set(double* buffer, const double value, const int size) 00401 { 00402 set<double>(buffer, value, size, 1); 00403 } 00404 00408 template<> 00409 void GURLS_EXPORT copy(float* dst, const float* src, const int size, const int dstIncr, const int srcIncr) 00410 { 00411 scopy_(const_cast<int*>(&size), const_cast<float*>(src), const_cast<int*>(&srcIncr), dst, const_cast<int*>(&dstIncr)); 00412 } 00413 00417 template<> 00418 void GURLS_EXPORT copy(float* dst, const float* src, const int size) 00419 { 00420 int incr = 1; 00421 00422 scopy_(const_cast<int*>(&size), const_cast<float*>(src), &incr, dst, &incr); 00423 } 00424 00428 template<> 00429 void GURLS_EXPORT copy(double* dst, const double* src, const int size, const int dstIncr, const int srcIncr) 00430 { 00431 dcopy_(const_cast<int*>(&size), const_cast<double*>(src), const_cast<int*>(&srcIncr), dst, const_cast<int*>(&dstIncr)); 00432 } 00433 00437 template<> 00438 void GURLS_EXPORT copy(double* dst, const double* src, const int size) 00439 { 00440 int incr = 1; 00441 00442 dcopy_(const_cast<int*>(&size), const_cast<double*>(src), &incr, dst, &incr); 00443 } 00444 00446 // * Specialized version of pinv for float buffers 00447 // */ 00448 //template<> 00449 //GURLS_EXPORT float* pinv(const float* A, int rows, int cols, int& res_rows, int& res_cols, float* RCOND) 00450 //{ 00451 // int M = rows; 00452 // int N = cols; 00453 00454 // float* a = new float[rows*cols]; 00455 // copy<float>(a, A, rows*cols); 00456 00457 // int LDA = M; 00458 // int LDB = std::max(M, N); 00459 // int NRHS = LDB; 00460 00461 00462 // // float* b = eye(LDB).getData() 00463 00464 // const int b_size = LDB*NRHS; 00465 // float *b = new float[LDB*NRHS]; 00466 // set<float>(b, 0.f, b_size); 00467 // set<float>(b, 1.f, std::min(LDB, NRHS), NRHS+1); 00468 00469 // float* S = new float[std::min(M,N)]; 00471 00472 // float rcond = (RCOND == NULL)? (std::max(rows, cols)*FLT_EPSILON): *RCOND; 00473 00476 00477 // int RANK = -1; // std::min(M,N); 00478 // int LWORK = -1; //2 * (3*LDB + std::max( 2*std::min(M,N), LDB)); 00479 // float* WORK = new float[1]; 00480 00481 // /* 00482 00483 // subroutine SGELSS ( INTEGER M, 00484 // INTEGER N, 00485 // INTEGER NRHS, 00486 // REAL,dimension( lda, * ) A, 00487 // INTEGER LDA, 00488 // REAL,dimension( ldb, * ) B, 00489 // INTEGER LDB, 00490 // REAL,dimension( * ) S, 00491 // REAL RCOND, 00492 // INTEGER RANK, 00493 // REAL,dimension( * ) WORK, 00494 // INTEGER LWORK, 00495 // INTEGER INFO 00496 // ) 00497 00498 // */ 00499 00500 // /* 00501 // INFO: 00502 // = 0: successful exit 00503 // < 0: if INFO = -i, the i-th argument had an illegal value. 00504 // > 0: the algorithm for computing the SVD failed to converge; 00505 // if INFO = i, i off-diagonal elements of an intermediate 00506 // bidiagonal form did not converge to zero. 00507 // */ 00508 // int INFO; 00509 00510 // /* Query and allocate the optimal workspace */ 00511 // sgelss_( &M, &N, &NRHS, a, &LDA, b, &LDB, S, &rcond, &RANK, WORK, &LWORK, &INFO); 00512 // LWORK = static_cast<int>(WORK[0]); 00513 // delete [] WORK; 00514 // WORK = new float[LWORK]; 00515 00516 // sgelss_( &M, &N, &NRHS, a, &LDA, b, &LDB, S, &rcond, &RANK, WORK, &LWORK, &INFO); 00517 00518 // // TODO: check INFO on exit 00519 // //condnum = S[0]/(S[std::min(M, N)]-1); 00520 00521 // if(INFO != 0) 00522 // { 00523 // std::stringstream str; 00524 // str << "Pinv failed, error code " << INFO << ";" << std::endl; 00525 // throw gException(str.str()); 00526 // } 00527 00528 // delete [] S; 00529 // delete [] WORK; 00530 // delete [] a; 00531 00532 // res_rows = LDB; 00533 // res_cols = NRHS; 00534 // return b; 00535 00536 //} 00537 00541 template<> 00542 GURLS_EXPORT bool eq(double val1, double val2) 00543 { 00544 return (val1 >= val2-DBL_EPSILON && val1 <= val2+DBL_EPSILON ); 00545 } 00546 00550 template<> 00551 GURLS_EXPORT bool eq(float val1, float val2) 00552 { 00553 return ( val1 >= val2-FLT_EPSILON && val1 <= val2+FLT_EPSILON ); 00554 } 00555 00559 template<> 00560 GURLS_EXPORT bool gt(double a, double b) 00561 { 00562 return ((a - b) > ( std::min(fabs(a), fabs(b))* std::numeric_limits<double>::epsilon())); 00563 } 00564 00568 template<> 00569 GURLS_EXPORT bool gt(float a, float b) 00570 { 00571 return ((a - b) > ( std::min(fabs(a), fabs(b))* std::numeric_limits<float>::epsilon())); 00572 } 00573 00577 template<> 00578 GURLS_EXPORT bool lt(double a, double b) 00579 { 00580 return ((b - a) > ( std::max(fabs(a), fabs(b))* std::numeric_limits<double>::epsilon())); 00581 } 00582 00586 template<> 00587 GURLS_EXPORT bool lt(float a, float b) 00588 { 00589 return ((b - a) > ( std::max(fabs(a), fabs(b))* std::numeric_limits<float>::epsilon())); 00590 } 00591 00592 }