GURLS++  2.0.00
C++ Implementation of GURLS Matlab Toolbox
splitho.h
00001 /*
00002  * The GURLS Package in C++
00003  *
00004  * Copyright (C) 2011-2013, IIT@MIT Lab
00005  * All rights reserved.
00006  *
00007  * authors:  M. Santoro
00008  * email:   msantoro@mit.edu
00009  * website: http://cbcl.mit.edu/IIT@MIT/IIT@MIT.html
00010  *
00011  * Redistribution and use in source and binary forms, with or without
00012  * modification, are permitted provided that the following conditions
00013  * are met:
00014  *
00015  *     * Redistributions of source code must retain the above
00016  *       copyright notice, this list of conditions and the following
00017  *       disclaimer.
00018  *     * Redistributions in binary form must reproduce the above
00019  *       copyright notice, this list of conditions and the following
00020  *       disclaimer in the documentation and/or other materials
00021  *       provided with the distribution.
00022  *     * Neither the name(s) of the copyright holders nor the names
00023  *       of its contributors or of the Massachusetts Institute of
00024  *       Technology or of the Italian Institute of Technology may be
00025  *       used to endorse or promote products derived from this software
00026  *       without specific prior written permission.
00027  *
00028  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00029  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00030  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00031  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00032  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00033  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00034  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00035  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00036  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00037  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00038  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00039  * POSSIBILITY OF SUCH DAMAGE.
00040  */
00041 
00042 
00043 #ifndef _GURLS_SPLITHO_H_
00044 #define _GURLS_SPLITHO_H_
00045 
00046 
#include <cmath>
#include <vector>

#include "gurls++/split.h"
#include "gurls++/gmath.h"
00049 
00050 namespace gurls {
00051 
00057 template <typename T>
00058 class SplitHo: public Split<T>
00059 {
00060 public:
00074     GurlsOptionsList* execute(const gMat2D<T>& X, const gMat2D<T>& Y, const GurlsOptionsList& opt) throw(gException);
00075 };
00076 
00077 template<typename T>
00078 GurlsOptionsList *SplitHo<T>::execute(const gMat2D<T>& /*X*/, const gMat2D<T>& Y, const GurlsOptionsList &opt) throw(gException)
00079 {
00080 //    nSplits = opt.nholdouts;
00081     const int nSplits = static_cast<int>(opt.getOptAsNumber("nholdouts"));
00082 
00083 //    fraction = opt.hoproportion;
00084     const double fraction = opt.getOptAsNumber("hoproportion");
00085 
00086 //    [n,T] = size(y);
00087     const int n = Y.rows();
00088     const int t = Y.cols();
00089 
00090 //    [dummy, y] = max(y,[],2);
00091     T* work = new T[Y.getSize()];
00092     unsigned long* y = new unsigned long[n];
00093 
00094 //    maxPerRow(Y.getData(), n, t, y);
00095     indicesOfMax(Y.getData(), n, t, y, work, 2);
00096 
00097     delete[] work;
00098 
00099 //    for t = 1:T,
00100 //        classes{t}.idx = find(y == t);
00101 //    end
00102 
00103     int* nSamples = new int[t];
00104     int nva = 0;
00105     unsigned long* idx = new unsigned long[n];
00106     unsigned long* it_idx = idx;
00107 
00108 //    for t = 1:T,
00109     for(int i=0; i<t; ++i)
00110     {
00111 //        nSamples(t) = numel(classes{t}.idx);
00112         indicesOfEqualsTo<unsigned long>(y, n, i, it_idx, nSamples[i]);
00113 
00114 //        nva = nva + floor(fraction*nSamples(t));
00115         nva += static_cast<int>(std::floor(fraction*nSamples[i]));
00116 
00117         it_idx += nSamples[i];
00118     }
00119 
00120     delete[] y;
00121 
00122     gMat2D<unsigned long>* m_indices = new gMat2D<unsigned long>(nSplits, n);
00123     unsigned long* indices = m_indices->getData();
00124 
00125     int count_tr;
00126     int count_va;
00127 
00128 //    for state = 1:nSplits,
00129     for(int state=0; state<nSplits; ++state)
00130     {
00131 //    %% Shuffle each class
00132         count_tr = 0;
00133         count_va = n-nva;
00134 
00135         it_idx = idx;
00136 
00137 //        for t = 1:T,
00138         for(int i=0; i<t; ++i)
00139         {
00140             const int nsamples = nSamples[i];
00141 
00142             if(nsamples == 0)
00143                 continue;
00144 
00145             randperm(nsamples, it_idx, false);
00146 
00147             int last = nsamples - static_cast<int>(std::floor(nsamples*fraction));
00148 
00149 
00150             copy(indices+(nSplits*count_tr)+state, it_idx, last, nSplits, 1);
00151 
00152             copy(indices+(nSplits*count_va)+state, it_idx+last, nsamples-last, nSplits, 1);
00153 
00154             count_tr += last;
00155             count_va += nsamples-last;
00156             it_idx += nsamples;
00157 
00158         }
00159 
00160     }
00161 
00162     delete[] nSamples;
00163     delete[] idx;
00164 
00165 
00166     gMat2D<unsigned long>* m_lasts = new gMat2D<unsigned long>(nSplits, 1);
00167     set(m_lasts->getData(), (unsigned long) (n-nva), nSplits);
00168 
00169 
00170     GurlsOptionsList* split = new GurlsOptionsList("split");
00171     split->addOpt("indices", new OptMatrix<gMat2D<unsigned long> >(*m_indices));
00172     split->addOpt("lasts", new OptMatrix<gMat2D<unsigned long> >(*m_lasts));
00173 
00174     return split;
00175 }
00176 
00177 }
00178 
00179 #endif //_GURLS_SPLITHO_H_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends