GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
gurls::SplitHo< T > Class Template Reference

SplitHo is the subclass of Split that splits data into one or more pairs of training and test samples.

#include <splitho.h>


Public Member Functions

GurlsOptionsList * execute (const gMat2D< T > &X, const gMat2D< T > &Y, const GurlsOptionsList &opt) throw (gException)
 Splits data into one or more pairs of training and test samples, to be used for cross-validation.

Static Public Member Functions

static Split< T > * factory (const std::string &id) throw (BadSplitCreation)
 Factory function returning a pointer to the newly created object.

Detailed Description

template<typename T>
class gurls::SplitHo< T >

Definition at line 58 of file splitho.h.


Member Function Documentation

template<typename T >
GurlsOptionsList * gurls::SplitHo< T >::execute ( const gMat2D< T > &  X,
const gMat2D< T > &  Y,
const GurlsOptionsList &  opt 
) throw (gException) [virtual]

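The fraction of samples assigned to the validation set is given by the field hoproportion of opt, and the number of training/validation pairs by the field nholdouts of opt. The proportion is applied to each class separately, so every split is stratified: a class with m samples contributes floor(hoproportion * m) samples to the validation set.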
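
Because hoproportion is applied per class and rounded down, the validation set can be somewhat smaller than hoproportion * n, and a class with fewer than 1/hoproportion samples contributes no validation samples at all. A minimal sketch of this size computation follows; HoldoutSizes and holdoutSizes are hypothetical helpers, not GURLS++ API, that mirror the floor arithmetic in execute.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Sizes of one holdout split (hypothetical type, not part of GURLS++).
struct HoldoutSizes {
    int nva;  // validation samples per holdout
    int ntr;  // training samples per holdout: n - nva, the value stored in lasts
};

// Each class with m samples contributes floor(fraction * m) samples to
// validation; the remaining samples of that class go to training.
HoldoutSizes holdoutSizes(const std::vector<int>& classSizes, double fraction) {
    assert(fraction >= 0.0 && fraction <= 1.0);
    HoldoutSizes s{0, 0};
    int n = 0;
    for (int m : classSizes) {
        n += m;
        s.nva += static_cast<int>(std::floor(fraction * m));
    }
    s.ntr = n - s.nva;
    return s;
}
```

For classes of 10, 7 and 3 samples with hoproportion = 0.25, this gives nva = 2 + 1 + 0 = 3 validation samples and 17 training samples per holdout.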

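Parameters:
X  not used
Y  labels matrix, n x T with one column per class; each sample is assigned to the class of the largest entry in its row
opt  options with the following fields:
  • hoproportion (default)
  • nholdouts (default)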
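
The conversion from the labels matrix Y to one class per sample (the [dummy, y] = max(y,[],2) step in execute) can be sketched as below. It assumes Y is stored column-major, which is consistent with the strided copy calls in execute; classOfEachSample is a hypothetical helper, and its tie-breaking (first maximum, as in Matlab's max) is an assumption, since this page does not document how indicesOfMax resolves ties.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Recover a 0-based class index per sample from an n x T label matrix stored
// column-major (entry (i, j) at data[j * n + i]) by picking, in each row, the
// column holding the largest entry. Ties keep the lowest column index.
std::vector<std::size_t> classOfEachSample(const std::vector<double>& data,
                                           std::size_t n, std::size_t T) {
    assert(T > 0 && data.size() == n * T);
    std::vector<std::size_t> cls(n, 0);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 1; j < T; ++j) {
            if (data[j * n + i] > data[cls[i] * n + i]) {
                cls[i] = j;
            }
        }
    }
    return cls;
}
```

For example, with n = 3 samples and T = 2 classes, the column-major buffer {-1, 1, -1, 1, -1, 1} assigns samples 0 and 2 to class 1 and sample 1 to class 0.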
Returns:
a new GurlsOptionsList named split, containing the following fields:
  • indices = nholdouts x n matrix; each row is a permutation of the sample indices, with the training samples first and the validation samples after them
  • lasts = nholdouts x 1 array; each entry contains the number of training samples (n minus the validation size), so the training set is built from the first lasts elements of the corresponding row of indices, and the remaining indices are used for validation.
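
A consumer of split can recover the training and validation sets of one holdout as sketched below. The sketch assumes that gMat2D stores the nholdouts x n indices matrix column-major, as the strided copy calls in execute suggest (row state, column k at offset k * nholdouts + state); Holdout and readHoldout are hypothetical names, not GURLS++ API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Training and validation sample indices of one holdout (hypothetical type).
struct Holdout {
    std::vector<unsigned long> train;
    std::vector<unsigned long> validation;
};

// Extract row `state` of a column-major nholdouts x n indices matrix: the
// first `last` entries form the training set, the rest the validation set.
Holdout readHoldout(const std::vector<unsigned long>& indices,
                    std::size_t nholdouts, std::size_t n,
                    std::size_t state, std::size_t last) {
    assert(indices.size() == nholdouts * n && state < nholdouts && last <= n);
    Holdout h;
    for (std::size_t k = 0; k < n; ++k) {
        const unsigned long sample = indices[k * nholdouts + state];
        if (k < last) {
            h.train.push_back(sample);
        } else {
            h.validation.push_back(sample);
        }
    }
    return h;
}
```

Here last would be the entry of lasts for the same holdout; every row uses the same value, since execute fills lasts with n - nva.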

Implements gurls::Split< T >.

Definition at line 78 of file splitho.h.

{
//    nSplits = opt.nholdouts;
    const int nSplits = static_cast<int>(opt.getOptAsNumber("nholdouts"));

//    fraction = opt.hoproportion;
    const double fraction = opt.getOptAsNumber("hoproportion");

//    [n,T] = size(y);
    const int n = Y.rows();
    const int t = Y.cols();

//    [dummy, y] = max(y,[],2);
    T* work = new T[Y.getSize()];
    unsigned long* y = new unsigned long[n];

//    maxPerRow(Y.getData(), n, t, y);
    indicesOfMax(Y.getData(), n, t, y, work, 2);

    delete[] work;

//    for t = 1:T,
//        classes{t}.idx = find(y == t);
//    end

    int* nSamples = new int[t];
    int nva = 0;
    unsigned long* idx = new unsigned long[n];
    unsigned long* it_idx = idx;

//    for t = 1:T,
    for(int i=0; i<t; ++i)
    {
//        nSamples(t) = numel(classes{t}.idx);
        indicesOfEqualsTo<unsigned long>(y, n, i, it_idx, nSamples[i]);

//        nva = nva + floor(fraction*nSamples(t));
        nva += static_cast<int>(std::floor(fraction*nSamples[i]));

        it_idx += nSamples[i];
    }

    delete[] y;

    gMat2D<unsigned long>* m_indices = new gMat2D<unsigned long>(nSplits, n);
    unsigned long* indices = m_indices->getData();

    int count_tr;
    int count_va;

//    for state = 1:nSplits,
    for(int state=0; state<nSplits; ++state)
    {
//    %% Shuffle each class
        count_tr = 0;
        count_va = n-nva;

        it_idx = idx;

//        for t = 1:T,
        for(int i=0; i<t; ++i)
        {
            const int nsamples = nSamples[i];

            if(nsamples == 0)
                continue;

            randperm(nsamples, it_idx, false);

            int last = nsamples - static_cast<int>(std::floor(nsamples*fraction));


            copy(indices+(nSplits*count_tr)+state, it_idx, last, nSplits, 1);

            copy(indices+(nSplits*count_va)+state, it_idx+last, nsamples-last, nSplits, 1);

            count_tr += last;
            count_va += nsamples-last;
            it_idx += nsamples;

        }

    }

    delete[] nSamples;
    delete[] idx;


    gMat2D<unsigned long>* m_lasts = new gMat2D<unsigned long>(nSplits, 1);
    set(m_lasts->getData(), (unsigned long) (n-nva), nSplits);


    GurlsOptionsList* split = new GurlsOptionsList("split");
    split->addOpt("indices", new OptMatrix<gMat2D<unsigned long> >(*m_indices));
    split->addOpt("lasts", new OptMatrix<gMat2D<unsigned long> >(*m_lasts));

    return split;
}
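
The listing above implements a stratified holdout: within each split every class is shuffled and cut at nsamples - floor(nsamples*fraction), training samples are packed from position 0 and validation samples from position n-nva. The following self-contained sketch reproduces that layout with std::shuffle and std::mt19937 substituted for GURLS++'s randperm; HoldoutSplits and stratifiedHoldouts are hypothetical names, and the input is assumed to be the 0-based class index of each sample.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Result of a stratified holdout run (hypothetical type, not GURLS++ API).
struct HoldoutSplits {
    std::vector<unsigned long> indices;  // nSplits x n, column-major: (row, col) at col * nSplits + row
    std::size_t last;                    // training samples per row, i.e. n - nva
};

// cls[i] is the 0-based class of sample i, T the number of classes and
// fraction the validation proportion (hoproportion).
HoldoutSplits stratifiedHoldouts(const std::vector<std::size_t>& cls, std::size_t T,
                                 double fraction, std::size_t nSplits,
                                 std::mt19937& rng) {
    assert(fraction >= 0.0 && fraction <= 1.0);
    const std::size_t n = cls.size();

    // Group sample indices by class: the counterpart of classes{t}.idx.
    std::vector<std::vector<unsigned long>> groups(T);
    for (std::size_t i = 0; i < n; ++i) {
        assert(cls[i] < T);
        groups[cls[i]].push_back(static_cast<unsigned long>(i));
    }

    // Validation size: floor(fraction * classSize), summed over classes.
    std::size_t nva = 0;
    for (const auto& g : groups)
        nva += static_cast<std::size_t>(std::floor(fraction * g.size()));

    HoldoutSplits out{std::vector<unsigned long>(nSplits * n), n - nva};
    for (std::size_t state = 0; state < nSplits; ++state) {
        std::size_t countTr = 0;        // next training slot in this row
        std::size_t countVa = n - nva;  // next validation slot in this row
        for (auto& g : groups) {
            // Reshuffle the class in place, as randperm does on idx.
            std::shuffle(g.begin(), g.end(), rng);
            const std::size_t last =
                g.size() - static_cast<std::size_t>(std::floor(fraction * g.size()));
            for (std::size_t k = 0; k < g.size(); ++k) {
                const std::size_t pos = (k < last) ? countTr++ : countVa++;
                out.indices[pos * nSplits + state] = g[k];
            }
        }
    }
    return out;
}
```

Because each class is reshuffled in place and reused for the next split, as in the original loop, successive holdouts are independent reshuffles rather than disjoint partitions, so the validation sets of different holdouts may overlap.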
template<typename T>
static Split<T>* gurls::Split< T >::factory ( const std::string &  id) throw (BadSplitCreation) [inline, static, inherited]
Warning:
The returned pointer is a plain, unmanaged pointer. The calling function is responsible for deallocating the object.

Definition at line 95 of file split.h.

    {
        if(id == "ho")
            return new SplitHo<T>;

        throw BadSplitCreation(id);
    }
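
The ownership rule in the warning above can be enforced by handing the returned pointer to std::unique_ptr immediately. The sketch below uses simplified stand-ins (namespace sketch) for Split, SplitHo and BadSplitCreation so that it compiles on its own; they are not the GURLS++ classes, and makeSplit is a hypothetical helper. With the real library the same wrapping applies, assuming Split<T> declares a virtual destructor, which this page does not show.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>

// Simplified stand-ins for the GURLS++ classes, only enough to show the
// factory/ownership pattern.
namespace sketch {

struct BadSplitCreation : std::invalid_argument {
    explicit BadSplitCreation(const std::string& id)
        : std::invalid_argument("cannot create split: " + id) {}
};

template <typename T>
class Split {
public:
    virtual ~Split() = default;  // virtual, so deleting through Split<T>* is safe
    virtual std::string name() const = 0;
    static Split<T>* factory(const std::string& id);  // returns an owning raw pointer
};

template <typename T>
class SplitHo : public Split<T> {
public:
    std::string name() const override { return "ho"; }
};

template <typename T>
Split<T>* Split<T>::factory(const std::string& id) {
    if (id == "ho")
        return new SplitHo<T>;
    throw BadSplitCreation(id);
}

// Take ownership at once, so the object is deleted even if later code throws.
template <typename T>
std::unique_ptr<Split<T>> makeSplit(const std::string& id) {
    return std::unique_ptr<Split<T>>(Split<T>::factory(id));
}

}  // namespace sketch
```

An unknown id still propagates BadSplitCreation, and no object is allocated in that case, so there is nothing to leak.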
