Scikit-learn train_test_split into PyTorch DataLoader

Issue

I have a binary classification dataset of PNG files named as shown in the attachment below, where the first 0 or 1 in the file name determines the class. They’re in a folder called "annotation_class", and I have a small script that loads them and splits them into training and test sets:

import glob

import cv2
import numpy as np
from sklearn.model_selection import train_test_split

filelist = glob.glob('annotation_class' + '/*.png')
size_row, size_col = 256, 256
X, y = [], []

for name in filelist:
    img = cv2.imread(name)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows)
    img = cv2.resize(img, (size_col, size_row))
    X.append(img)
    # label = second '_'-separated token of the file name (Windows path separator)
    y.append(int(name.split('\\')[-1].split('_')[1]))


x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, train_size=0.8, random_state=4)
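The label parsing above splits on '\\', so it only works with Windows paths. A minimal, more portable sketch, assuming (as the script does) that the class digit is the second '_'-separated token of the file name; `label_from_filename` is a hypothetical helper name and the example file names are made up. Stripping the extension first is an extra assumption that also covers names where the label is the last token:

```python
import os


def label_from_filename(path: str) -> int:
    """Return the class label encoded in a file name.

    Assumes the label is the second '_'-separated token of the name,
    e.g. 'img_1_0042.png' -> 1 (hypothetical naming scheme).
    """
    # os.path.basename uses the running OS's separator rules, which match
    # the paths glob.glob returns on that OS.
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.split('_')[1])


print(label_from_filename(os.path.join('annotation_class', 'img_1_0042.png')))  # prints 1
```

Inside the loop, `y.append(label_from_filename(name))` would replace the inline parsing.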

The return values are all lists. I’m using PyTorch for this project and would like to write a custom Dataset so I can use DataLoader, but I’m not sure how best to plug in these lists after calling train_test_split. Should I scrap that altogether and use something else? I’d like to end up with two DataLoaders, one for training and one for testing.

[Image: sample file titles]

Solution

You don’t have to rewrite anything. You can reuse your core data loading logic inside a PyTorch Dataset:

import cv2, glob
import numpy as np
from sklearn.model_selection import train_test_split

from torch.utils.data import Dataset, DataLoader

class MyCoolDataset(Dataset):

    def __init__(self, root_dir, train=True):
        filelist = glob.glob(root_dir + '/*.png')
        ...
        # all your data loading logic using cv2, glob ... (builds the lists X and y)
        # a fixed random_state gives the train and test instances the same split
        x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, train_size=0.8, random_state=4)

        # two modes - train and test
        if train:
            self.x_data, self.y_data = x_train, y_train
        else:
            self.x_data, self.y_data = x_test, y_test

    def __len__(self):
        # needed by DataLoader's default samplers
        return len(self.x_data)

    def __getitem__(self, i):
        return self.x_data[i], self.y_data[i]
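Because the `...` stands for your loading code, the class above can't run on its own. A self-contained sketch of the same idea, assuming the images are already in memory as HWC `uint8` NumPy arrays (the layout `cv2` returns), for example the lists from `train_test_split`. `ImageListDataset` is a hypothetical name, and converting each image to a CHW float tensor in [0, 1] is an assumption about what your model expects:

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class ImageListDataset(Dataset):
    """Map-style dataset over in-memory images and labels (hypothetical helper)."""

    def __init__(self, images, labels):
        assert len(images) == len(labels), "images and labels must align"
        self.images = images
        self.labels = labels

    def __len__(self):
        # DataLoader's default samplers need the dataset length.
        return len(self.images)

    def __getitem__(self, i):
        # HWC uint8 (OpenCV layout) -> CHW float32 in [0, 1] (PyTorch conv layout).
        arr = np.ascontiguousarray(self.images[i])
        img = torch.from_numpy(arr).permute(2, 0, 1).float() / 255.0
        label = torch.tensor(self.labels[i], dtype=torch.long)
        return img, label


# Dummy data standing in for x_train / y_train from train_test_split.
x_train = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(4)]
y_train = [0, 1, 0, 1]
train_ds = ImageListDataset(x_train, y_train)
img, label = train_ds[1]
print(img.shape, label.item())  # torch.Size([3, 256, 256]) 1
```

With this variant the split stays outside the Dataset, so the data is loaded from disk once and `ImageListDataset(x_test, y_test)` gives the test set.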

Then use a DataLoader as usual (one instance with train=True, one with train=False):

dl = DataLoader(MyCoolDataset(...), batch_size=...)
for X, Y in dl:
    pass
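If you'd rather not depend on scikit-learn for the split, PyTorch's own `torch.utils.data.random_split` can divide a single dataset into training and test subsets. A sketch, assuming the whole dataset fits in a `TensorDataset` (built from small dummy tensors here); the 80/20 ratio and seed 4 mirror the question, and the batch size is arbitrary:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Dummy stand-ins for the loaded images (N, C, H, W) and binary labels;
# 32x32 keeps the demo small.
images = torch.rand(10, 3, 32, 32)
labels = torch.randint(0, 2, (10,))
full_ds = TensorDataset(images, labels)

# 80/20 split with a fixed seed so the split is reproducible.
n_train = int(0.8 * len(full_ds))
n_test = len(full_ds) - n_train
train_ds, test_ds = random_split(
    full_ds, [n_train, n_test], generator=torch.Generator().manual_seed(4)
)

# Shuffle only the training data; keep the test order fixed.
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=4, shuffle=False)

for xb, yb in train_dl:
    print(xb.shape, yb.shape)
```

Note that seed 4 here does not reproduce the split from `random_state=4` in scikit-learn, and `random_split` does not stratify; if you need to preserve the class ratio, `train_test_split(..., stratify=y)` does that.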

Answered By – ayandas

Answer Checked By – Timothy Miller (AngularFixing Admin)
