GURLS++
2.0.00
C++ Implementation of GURLS Matlab Toolbox
SplitHo is the subclass of Split that splits data into one or more pairs of training and test samples.
#include <splitho.h>


Public Member Functions | |
| GurlsOptionsList * | execute (const gMat2D< T > &X, const gMat2D< T > &Y, const GurlsOptionsList &opt) throw (gException) |
| Splits data into one or more pairs of training and test samples, to be used for cross-validation. | |
Static Public Member Functions | |
| static Split< T > * | factory (const std::string &id) throw (BadSplitCreation) |
| Factory function returning a pointer to the newly created object. | |
| GurlsOptionsList * gurls::SplitHo< T >::execute | ( | const gMat2D< T > & | X, |
| const gMat2D< T > & | Y, | ||
| const GurlsOptionsList & | opt | ||
| ) | throw (gException) [virtual] |
The fraction of each class's samples assigned to the validation set is specified in the field `hoproportion` of `opt`, and the number of pairs is specified in the field `nholdouts` of `opt`.
| X | not used |
| Y | labels matrix |
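| opt | options with the following fields: `nholdouts` (number of training/validation pairs) and `hoproportion` (fraction of each class's samples used for validation; expected to lie in [0, 1]) |
Returns: a newly allocated GurlsOptionsList named `split` with the fields `indices` (an nholdouts x n matrix of sample indices) and `lasts` (an nholdouts x 1 matrix holding, for every pair, the number of training samples).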
Implements gurls::Split< T >.
Definition at line 78 of file splitho.h.
{
    // Stratified hold-out split. For every class, floor(hoproportion * size)
    // of its samples go to the validation part and the remaining ones to the
    // training part; this is repeated independently nholdouts times.
    //
    // Result (a GurlsOptionsList named "split"):
    //   indices : nSplits x n matrix; for each split the first n-nva entries
    //             are training sample indices and the last nva entries are
    //             validation sample indices (both grouped by class).
    //   lasts   : nSplits x 1 matrix, every entry equal to n-nva, i.e. the
    //             number of training samples in each split.

    // nSplits = opt.nholdouts;
    const int nSplits = static_cast<int>(opt.getOptAsNumber("nholdouts"));
    // fraction = opt.hoproportion;
    const double fraction = opt.getOptAsNumber("hoproportion");

    // Reject invalid options before allocating anything (so throwing here
    // cannot leak). A negative split count is not a valid matrix dimension.
    // A proportion outside [0, 1] makes nva exceed n (fraction > 1) or makes
    // `last` exceed the class size (fraction < 0): in both cases the strided
    // copies below would write outside the indices matrix. The negated
    // comparison also rejects NaN.
    if(nSplits < 0)
        throw gException("Illegal argument value: nholdouts must be non-negative");
    if(!(fraction >= 0.0 && fraction <= 1.0))
        throw gException("Illegal argument value: hoproportion must be in [0, 1]");

    // [n,T] = size(y);
    const int n = Y.rows();
    const int t = Y.cols();

    // [dummy, y] = max(y,[],2);
    // y[r] receives the (0-based) column index of the maximum of row r of Y,
    // used as the class label of sample r; `work` is scratch for the helper.
    T* work = new T[Y.getSize()];
    unsigned long* y = new unsigned long[n];
    indicesOfMax(Y.getData(), n, t, y, work, 2);
    delete[] work;

    // for t = 1:T,
    //     classes{t}.idx = find(y == t);
    //     nSamples(t) = numel(classes{t}.idx);
    //     nva = nva + floor(fraction*nSamples(t));
    // end
    // idx holds the sample indices of class 0, then class 1, ... back to back;
    // nSamples[i] is the length of the slice belonging to class i.
    int* nSamples = new int[t];
    int nva = 0;
    unsigned long* idx = new unsigned long[n];
    unsigned long* it_idx = idx;
    for(int i=0; i<t; ++i)
    {
        indicesOfEqualsTo<unsigned long>(y, n, i, it_idx, nSamples[i]);
        nva += static_cast<int>(std::floor(fraction*nSamples[i]));
        it_idx += nSamples[i];
    }
    delete[] y;

    gMat2D<unsigned long>* m_indices = new gMat2D<unsigned long>(nSplits, n);
    unsigned long* indices = m_indices->getData();

    // for state = 1:nSplits,
    for(int state=0; state<nSplits; ++state)
    {
        // Write cursors for this split: training indices fill positions
        // [0, n-nva), validation indices fill positions [n-nva, n).
        int count_tr = 0;
        int count_va = n-nva;
        it_idx = idx;

        // %% Shuffle each class
        for(int i=0; i<t; ++i)
        {
            const int nsamples = nSamples[i];
            if(nsamples == 0)
                continue;

            // Shuffle this class's slice of idx, then send its first `last`
            // entries to training and the remaining
            // floor(nsamples*fraction) entries to validation.
            randperm(nsamples, it_idx, false);
            const int last = nsamples - static_cast<int>(std::floor(nsamples*fraction));

            // Position j of split `state` is stored at offset
            // state + nSplits*j, hence the destination stride of nSplits.
            // NOTE(review): this assumes gMat2D stores data column-major.
            copy(indices+(nSplits*count_tr)+state, it_idx, last, nSplits, 1);
            copy(indices+(nSplits*count_va)+state, it_idx+last, nsamples-last, nSplits, 1);

            count_tr += last;
            count_va += nsamples-last;
            it_idx += nsamples;
        }
    }
    delete[] nSamples;
    delete[] idx;

    // Every split has the same number of training samples.
    gMat2D<unsigned long>* m_lasts = new gMat2D<unsigned long>(nSplits, 1);
    set(m_lasts->getData(), static_cast<unsigned long>(n-nva), nSplits);

    GurlsOptionsList* split = new GurlsOptionsList("split");
    split->addOpt("indices", new OptMatrix<gMat2D<unsigned long> >(*m_indices));
    split->addOpt("lasts", new OptMatrix<gMat2D<unsigned long> >(*m_lasts));
    return split;
}
| static Split<T>* gurls::Split< T >::factory | ( | const std::string & | id | ) | throw (BadSplitCreation) [inline, static, inherited] |