How to split a DataSetIterator into testing and training iterators?

Issue

I am using Deeplearning4j and DataVec, and I have a DataSetIterator object that represents all of my data, which is a time series. How can I split this into training and testing iterators? I checked, and the relevant methods of the DataSetIterator class are deprecated. Thank you.

Solution

Iterate through your DataSetIterator and, for each DataSet it returns, create two new DataSets: one for training and one for testing.

The key is the splitTestAndTrain method, which accepts a double fractionTrain specifying the fraction of the data to use for training (the rest is used for testing). The method has several overloads, so you can choose the one that best fits your needs. If you want to gather all the train and test DataSets behind a common iterator each, you could store them in two separate Lists and get the corresponding iterators afterwards. Something like:

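import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;

List<DataSet> trainList = new ArrayList<>();
List<DataSet> testList = new ArrayList<>();

// Split every DataSet the iterator returns into a train part and a test part.
while (yourDataSetIterator.hasNext())
{
    DataSet ds = yourDataSetIterator.next();
    SplitTestAndTrain splData = ds.splitTestAndTrain(0.5); // half for each
    trainList.add(splData.getTrain());
    testList.add(splData.getTest());
    // (...)
}

// Plain java.util iterators over the collected parts.
Iterator<DataSet> trainIterator = trainList.iterator();
Iterator<DataSet> testIterator = testList.iterator();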

As I don’t really know the specific details of this library, the example just creates "basic" iterators. This may be customized so you create DataSetIterators instead.

Note that you may also need to shuffle the DataSet before splitting it (ds.shuffle()).
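To make the shuffle-then-split idea concrete, here is a minimal plain-Java sketch that does not use DL4J at all. The class ShuffleThenSplit, its Split holder and the shuffleAndSplit method are hypothetical names made up for illustration; they only mimic the idea behind ds.shuffle() followed by splitTestAndTrain(fractionTrain), and are not the library's implementation. A fixed seed is assumed so that the same split can be reproduced.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class ShuffleThenSplit {

    /** Hypothetical holder for the two parts of a split (a stand-in, not DL4J's SplitTestAndTrain). */
    public static final class Split<T> {
        public final List<T> train;
        public final List<T> test;

        Split(List<T> train, List<T> test) {
            this.train = train;
            this.test = test;
        }
    }

    /** Shuffles a copy of the examples with a fixed seed, then cuts it at fractionTrain. */
    public static <T> Split<T> shuffleAndSplit(List<T> examples, double fractionTrain, long seed) {
        if (fractionTrain < 0.0 || fractionTrain > 1.0) {
            throw new IllegalArgumentException("fractionTrain must be in [0, 1]");
        }
        List<T> copy = new ArrayList<>(examples);      // leave the caller's list untouched
        Collections.shuffle(copy, new Random(seed));   // same seed gives the same order
        int numTrain = (int) (copy.size() * fractionTrain);
        List<T> train = new ArrayList<>(copy.subList(0, numTrain));
        List<T> test = new ArrayList<>(copy.subList(numTrain, copy.size()));
        return new Split<>(train, test);
    }

    public static void main(String[] args) {
        List<Integer> examples = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        Split<Integer> split = shuffleAndSplit(examples, 0.8, 42L);
        System.out.println("train size = " + split.train.size());
        System.out.println("test size  = " + split.test.size());
    }
}
```

Without the shuffle, the cut would always send the first examples of each batch to the training part, which matters if the data is ordered in some way.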


If you wish to split it in a specific way, you could label the different entries and find the index at which the split should fall; then call the splitTestAndTrain(int) overload, which splits the dataset at that position. The sortByLabel method could also be helpful here.
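The sort-then-cut idea can be sketched in plain Java as follows. Everything here (the SortAndCut class, the Example type and the sortAndCut method) is a hypothetical stand-in for illustration and not part of DL4J. The sketch deliberately calls the two parts "head" and "tail": which side of the cut the library treats as training data is worth checking against the DL4J Javadoc.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class SortAndCut {

    /** Hypothetical labelled example; stands in for one entry of a DataSet. */
    public static final class Example {
        public final int label;
        public final double value;

        public Example(int label, double value) {
            this.label = label;
            this.value = value;
        }
    }

    /**
     * Sorts a copy of the examples by label, then cuts it at cutIndex.
     * Returns two lists: index 0 is the head [0, cutIndex), index 1 is the tail.
     */
    public static List<List<Example>> sortAndCut(List<Example> examples, int cutIndex) {
        if (cutIndex < 0 || cutIndex > examples.size()) {
            throw new IllegalArgumentException("cutIndex out of range: " + cutIndex);
        }
        List<Example> sorted = new ArrayList<>(examples);
        sorted.sort(Comparator.comparingInt(e -> e.label)); // stable sort by label

        List<List<Example>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(sorted.subList(0, cutIndex)));
        parts.add(new ArrayList<>(sorted.subList(cutIndex, sorted.size())));
        return parts;
    }

    public static void main(String[] args) {
        List<Example> data = Arrays.asList(
                new Example(2, 0.2), new Example(0, 0.0),
                new Example(1, 0.1), new Example(0, 0.5));
        List<List<Example>> parts = sortAndCut(data, 2);
        System.out.println("head size = " + parts.get(0).size());
        System.out.println("tail size = " + parts.get(1).size());
    }
}
```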


Adam Gibson made a great suggestion in the comments about another mechanism for splitting the DataSetIterator, which also seems to be a "more natural" way to do it: the DataSetIteratorSplitter.

It offers the getTrainIterator() and getTestIterator() methods, which return the library’s specific iterator, DataSetIterator.
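To illustrate the general idea of handing out a train iterator and a test iterator over a single underlying iterator, here is a plain-Java sketch. RatioSplitter, its constructor parameters and its trainIterator()/testIterator() methods are hypothetical and are not the DL4J DataSetIteratorSplitter; the sketch simply assumes the total number of batches is known up front and that the first totalBatches * ratio batches go to training.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/** Hypothetical stand-in: splits one pass over a base iterator into train and test views. */
public class RatioSplitter<T> {
    private final Iterator<T> base;
    private final long numTrain;   // number of leading batches that belong to training
    private long consumed = 0;     // training batches already taken from the base iterator

    public RatioSplitter(Iterator<T> base, long totalBatches, double ratio) {
        if (ratio <= 0.0 || ratio >= 1.0) {
            throw new IllegalArgumentException("ratio must be in (0, 1)");
        }
        this.base = base;
        this.numTrain = (long) (totalBatches * ratio);
    }

    /** Yields the first numTrain batches of the base iterator. */
    public Iterator<T> trainIterator() {
        return new Iterator<T>() {
            @Override
            public boolean hasNext() {
                return consumed < numTrain && base.hasNext();
            }

            @Override
            public T next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                consumed++;
                return base.next();
            }
        };
    }

    /** Yields every batch after the training portion. */
    public Iterator<T> testIterator() {
        return new Iterator<T>() {
            @Override
            public boolean hasNext() {
                // Skip training batches that have not been consumed yet.
                while (consumed < numTrain && base.hasNext()) {
                    base.next();
                    consumed++;
                }
                return base.hasNext();
            }

            @Override
            public T next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                return base.next();
            }
        };
    }

    public static void main(String[] args) {
        List<String> batches = Arrays.asList("b0", "b1", "b2", "b3", "b4");
        RatioSplitter<String> splitter = new RatioSplitter<>(batches.iterator(), batches.size(), 0.8);
        splitter.trainIterator().forEachRemaining(b -> System.out.println("train " + b));
        splitter.testIterator().forEachRemaining(b -> System.out.println("test  " + b));
    }
}
```

Because both views read from the same single-pass iterator, this sketch only supports one epoch and expects the training view to be consumed before the test view. For real use, the library's own splitter is the better choice.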

Answered By – aran

Answer Checked By – Robin (AngularFixing Admin)
