![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
00001 /* 00002 * The GURLS Package in C++ 00003 * 00004 * Copyright (C) 2011-1013, IIT@MIT Lab 00005 * All rights reserved. 00006 * 00007 * authors: M. Santoro 00008 * email: msantoro@mit.edu 00009 * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html 00010 * 00011 * Redistribution and use in source and binary forms, with or without 00012 * modification, are permitted provided that the following conditions 00013 * are met: 00014 * 00015 * * Redistributions of source code must retain the above 00016 * copyright notice, this list of conditions and the following 00017 * disclaimer. 00018 * * Redistributions in binary form must reproduce the above 00019 * copyright notice, this list of conditions and the following 00020 * disclaimer in the documentation and/or other materials 00021 * provided with the distribution. 00022 * * Neither the name(s) of the copyright holders nor the names 00023 * of its contributors or of the Massacusetts Institute of 00024 * Technology or of the Italian Institute of Technology may be 00025 * used to endorse or promote products derived from this software 00026 * without specific prior written permission. 00027 * 00028 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 00029 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 00030 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 00031 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 00032 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 00033 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 00034 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00035 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 00036 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 00037 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 00038 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00039 * POSSIBILITY OF SUCH DAMAGE. 00040 */ 00041 00042 00043 #ifndef _GURLS_SPLITHO_H_ 00044 #define _GURLS_SPLITHO_H_ 00045 00046 00047 #include "gurls++/split.h" 00048 #include "gurls++/gmath.h" 00049 00050 namespace gurls { 00051 00057 template <typename T> 00058 class SplitHo: public Split<T> 00059 { 00060 public: 00074 GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) throw(gException); 00075 }; 00076 00077 template<typename T> 00078 GurlsOptionsList *SplitHo<T>::execute(const gMat2D<T>& /*X*/, const gMat2D<T>& Y, const GurlsOptionsList &opt) throw(gException) 00079 { 00080 // nSplits = opt.nholdouts; 00081 const int nSplits = static_cast<int>(opt.getOptAsNumber("nholdouts")); 00082 00083 // fraction = opt.hoproportion; 00084 const double fraction = opt.getOptAsNumber("hoproportion"); 00085 00086 // [n,T] = size(y); 00087 const int n = Y.rows(); 00088 const int t = Y.cols(); 00089 00090 // [dummy, y] = max(y,[],2); 00091 T* work = new T[Y.getSize()]; 00092 unsigned long* y = new unsigned long[n]; 00093 00094 // maxPerRow(Y.getData(), n, t, y); 00095 indicesOfMax(Y.getData(), n, t, y, work, 2); 00096 00097 delete[] work; 00098 00099 // for t = 1:T, 00100 // classes{t}.idx = find(y == t); 00101 // end 00102 00103 int* nSamples = new int[t]; 00104 int nva = 0; 00105 unsigned long* idx = new unsigned long[n]; 00106 unsigned long* it_idx = idx; 00107 00108 // for t = 1:T, 00109 for(int i=0; i<t; ++i) 00110 { 00111 // nSamples(t) = numel(classes{t}.idx); 00112 indicesOfEqualsTo<unsigned long>(y, n, i, it_idx, nSamples[i]); 00113 00114 // nva = nva + floor(fraction*nSamples(t)); 00115 nva += static_cast<int>(std::floor(fraction*nSamples[i])); 00116 00117 it_idx += nSamples[i]; 00118 } 00119 00120 delete[] y; 00121 00122 gMat2D<unsigned long>* m_indices = new gMat2D<unsigned long>(nSplits, n); 00123 unsigned long* indices = m_indices->getData(); 00124 00125 int count_tr; 00126 int count_va; 00127 00128 // for state = 1:nSplits, 00129 for(int state=0; state<nSplits; ++state) 00130 { 00131 // %% Shuffle each class 00132 count_tr = 0; 00133 count_va = n-nva; 00134 00135 it_idx = idx; 00136 00137 // for t = 1:T, 00138 for(int i=0; i<t; ++i) 00139 { 00140 const int nsamples = nSamples[i]; 00141 00142 if(nsamples == 0) 00143 continue; 00144 00145 randperm(nsamples, it_idx, false); 00146 00147 int last = nsamples - static_cast<int>(std::floor(nsamples*fraction)); 00148 00149 00150 copy(indices+(nSplits*count_tr)+state, it_idx, last, nSplits, 1); 00151 00152 copy(indices+(nSplits*count_va)+state, it_idx+last, nsamples-last, nSplits, 1); 00153 00154 count_tr += last; 00155 count_va += nsamples-last; 00156 it_idx += nsamples; 00157 00158 } 00159 00160 } 00161 00162 delete[] nSamples; 00163 delete[] idx; 00164 00165 00166 gMat2D<unsigned long>* m_lasts = new gMat2D<unsigned long>(nSplits, 1); 00167 set(m_lasts->getData(), (unsigned long) (n-nva), nSplits); 00168 00169 00170 GurlsOptionsList* split = new GurlsOptionsList("split"); 00171 split->addOpt("indices", new OptMatrix<gMat2D<unsigned long> >(*m_indices)); 00172 split->addOpt("lasts", new OptMatrix<gMat2D<unsigned long> >(*m_lasts)); 00173 00174 return split; 00175 } 00176 00177 } 00178 00179 #endif //_GURLS_SPLITHO_H_