![]() |
GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
|
SplitHo is the subclass of Split that splits data into one or more pairs of training and validation samples.
#include <splitho.h>
Public Member Functions | |
GurlsOptionsList * | execute (const gMat2D< T > &X, const gMat2D< T > &Y, const GurlsOptionsList &opt) throw (gException) |
Splits data into one or more pairs of training and test samples, to be used for cross-validation. | |
Static Public Member Functions | |
static Split< T > * | factory (const std::string &id) throw (BadSplitCreation) |
Factory function returning a pointer to the newly created object. |
GurlsOptionsList * gurls::SplitHo< T >::execute | ( | const gMat2D< T > & | X, |
const gMat2D< T > & | Y, | ||
const GurlsOptionsList & | opt | ||
) | throw (gException) [virtual] |
The fraction of samples assigned to the validation set is specified in the field `hoproportion` of `opt`, and the number of training/validation pairs is specified in the field `nholdouts` of `opt`.
X | not used |
Y | labels matrix, one column per class (the class of each sample is the column holding its row maximum) |
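opt | options with the following fields:
- nholdouts: number of training/validation splits to generate
- hoproportion: fraction of each class's samples assigned to the validation set (expected in [0, 1]) |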
Implements gurls::Split< T >.
Definition at line 78 of file splitho.h.
{
    // Builds nSplits random holdout splits. Each class is shuffled and
    // floor(fraction * classSize) of its samples go to validation; the rest go
    // to training. The returned "split" list holds:
    //   indices: nSplits x n matrix of sample indices (training first)
    //   lasts:   nSplits x 1 matrix, each entry = number of training samples

    // nSplits = opt.nholdouts;
    const int nSplits = static_cast<int>(opt.getOptAsNumber("nholdouts"));
    // fraction = opt.hoproportion;
    const double fraction = opt.getOptAsNumber("hoproportion");

    // Validate before allocating anything. A fraction outside [0, 1] makes the
    // per-class training/validation counts below negative or larger than n,
    // and the copy() calls would then write outside the indices buffer.
    // The negated test also rejects NaN.
    if(!(fraction >= 0.0 && fraction <= 1.0))
        throw gException("SplitHo: hoproportion must be in the range [0, 1]");
    if(nSplits < 1)
        throw gException("SplitHo: nholdouts must be at least 1");

    // [n,T] = size(y);
    const int n = Y.rows();
    const int t = Y.cols();

    // Label of each sample = column index of its row maximum
    // ([dummy, y] = max(y,[],2);)
    T* work = new T[Y.getSize()];
    unsigned long* y = new unsigned long[n];
    indicesOfMax(Y.getData(), n, t, y, work, 2);
    delete[] work;

    // Group sample indices by class: idx holds the indices of class 0, then
    // those of class 1, and so on; nSamples[i] is the size of class i.
    // for t = 1:T, classes{t}.idx = find(y == t); end
    int* nSamples = new int[t];
    int nva = 0;  // total number of validation samples
    unsigned long* idx = new unsigned long[n];
    unsigned long* it_idx = idx;
    for(int i = 0; i < t; ++i)
    {
        // nSamples(t) = numel(classes{t}.idx);
        indicesOfEqualsTo<unsigned long>(y, n, i, it_idx, nSamples[i]);
        // nva = nva + floor(fraction*nSamples(t));
        nva += static_cast<int>(std::floor(fraction * nSamples[i]));
        it_idx += nSamples[i];
    }
    delete[] y;

    // Split `state` writes its n entries at offsets state + nSplits*j
    // (stride nSplits): training indices for j < n-nva, then validation ones.
    gMat2D<unsigned long>* m_indices = new gMat2D<unsigned long>(nSplits, n);
    unsigned long* indices = m_indices->getData();

    // for state = 1:nSplits,
    for(int state = 0; state < nSplits; ++state)
    {
        int count_tr = 0;        // next training slot
        int count_va = n - nva;  // next validation slot
        it_idx = idx;

        // %% Shuffle each class
        for(int i = 0; i < t; ++i)
        {
            const int nsamples = nSamples[i];
            if(nsamples == 0)
                continue;

            // NOTE(review): assumes randperm(..., false) permutes the values
            // already stored in it_idx in place -- confirm.
            randperm(nsamples, it_idx, false);

            // First `last` shuffled samples train; the rest validate.
            const int last = nsamples - static_cast<int>(std::floor(nsamples * fraction));
            copy(indices + (nSplits * count_tr) + state, it_idx, last, nSplits, 1);
            copy(indices + (nSplits * count_va) + state, it_idx + last, nsamples - last, nSplits, 1);

            count_tr += last;
            count_va += nsamples - last;
            it_idx += nsamples;
        }
    }
    delete[] nSamples;
    delete[] idx;

    // Every split has the same number of training samples: n - nva.
    gMat2D<unsigned long>* m_lasts = new gMat2D<unsigned long>(nSplits, 1);
    set(m_lasts->getData(), static_cast<unsigned long>(n - nva), nSplits);

    GurlsOptionsList* split = new GurlsOptionsList("split");
    split->addOpt("indices", new OptMatrix<gMat2D<unsigned long> >(*m_indices));
    split->addOpt("lasts", new OptMatrix<gMat2D<unsigned long> >(*m_lasts));
    return split;
}
static Split<T>* gurls::Split< T >::factory | ( | const std::string & | id | ) | throw (BadSplitCreation) [inline, static, inherited] |