Restructure code to avoid for loops in training loop?

Issue

I am defining a train function to which I pass a data_loader as a dict:

  • data_loader['train']: the training data
  • data_loader['val']: the validation data

I created a loop that iterates over the phases (train and val) and sets the model to model.train() or model.eval() accordingly. However, I feel there are too many nested for loops here, making it computationally expensive. Could anyone recommend a better way to structure my train function? Should I create a separate function for validation instead?

Below is what I have so far:

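```python
# Make train function (simple at first)
from tqdm import notebook  # notebook-friendly progress bar


def train_network(model, optimizer, data_loader, no_epochs):
  # `device` (e.g. torch.device("cuda")) is assumed to be defined globally
  total_epochs = notebook.tqdm(range(no_epochs))

  for epoch in total_epochs:

    for phase in ['train', 'val']:
      if phase == 'train':
        model.train()
      else:
        model.eval()

      for i, (images, g_truth) in enumerate(data_loader[phase]):
        images = images.to(device)
        g_truth = g_truth.to(device)
        # ... rest of the training/validation step not shown
```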

Solution

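The outermost (epoch) and innermost (batch) for loops are common when writing training scripts. The middle loop over phases is not what makes this expensive: it runs only twice per epoch, and either way each epoch makes one pass over the training batches and one over the validation batches, so the choice between structures is mostly about readability rather than speed.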
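To check the concern about computational cost, the sketch below counts how many batches each structure processes. Plain Python lists stand in for the loaders (a hypothetical setup with no model involved), so the only thing measured is loop iterations: both structures visit exactly the same batches.

```python
def count_with_phase_loop(data_loader, no_epochs):
    """Question's structure: epochs -> phases -> batches."""
    steps = 0
    for _ in range(no_epochs):
        for phase in ["train", "val"]:
            for _batch in data_loader[phase]:
                steps += 1
    return steps


def count_with_separate_loops(train_loader, val_loader, no_epochs):
    """Answer's structure: epochs -> training batches, then validation batches."""
    steps = 0
    for _ in range(no_epochs):
        for _batch in train_loader:
            steps += 1
        for _batch in val_loader:
            steps += 1
    return steps


# Hypothetical stand-ins: 8 training batches and 2 validation batches.
data_loader = {"train": list(range(8)), "val": list(range(2))}

print(count_with_phase_loop(data_loader, no_epochs=3))  # 30
print(count_with_separate_loops(data_loader["train"], data_loader["val"], no_epochs=3))  # 30
```

Both print 3 epochs × (8 + 2) batches = 30; the cost of an epoch is set by the number of batches, not by how many loops wrap them.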

The most common pattern I see is to do:

```python
import torch
from tqdm import notebook

total_epochs = notebook.tqdm(range(no_epochs))

for epoch in total_epochs:
    # Training
    model.train()  # set the mode once per phase, not once per batch
    for i, (images, g_truth) in enumerate(train_data_loader):
        images = images.to(device)
        g_truth = g_truth.to(device)
        ...

    # Validating
    model.eval()
    with torch.no_grad():  # no gradients are needed for validation
        for i, (images, g_truth) in enumerate(val_data_loader):
            images = images.to(device)
            g_truth = g_truth.to(device)
            ...
```
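For reference, a complete, runnable version of this pattern might look like the sketch below. It is a toy under stated assumptions: a hypothetical nn.Linear model on random tensors (the names images and g_truth are kept from the question even though the data are not images), with MSE loss and SGD chosen arbitrarily, and tqdm left out to keep dependencies minimal. It calls model.train() and model.eval() once per phase and runs validation under torch.no_grad(), so no autograd graph is built there.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy data: 64 training and 16 validation samples, 4 features each.
train_ds = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
val_ds = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))
train_data_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_data_loader = DataLoader(val_ds, batch_size=8)

model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()


def train_network(model, optimizer, train_data_loader, val_data_loader, no_epochs):
    """Returns a list of (mean train loss, mean val loss) per epoch."""
    history = []
    for epoch in range(no_epochs):
        # Training: gradients, backward pass and optimizer step.
        model.train()
        train_loss = 0.0
        for images, g_truth in train_data_loader:
            images, g_truth = images.to(device), g_truth.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), g_truth)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)

        # Validating: eval mode and no autograd graph.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, g_truth in val_data_loader:
                images, g_truth = images.to(device), g_truth.to(device)
                val_loss += criterion(model(images), g_truth).item() * images.size(0)

        history.append((train_loss / len(train_data_loader.dataset),
                        val_loss / len(val_data_loader.dataset)))
    return history


results = train_network(model, optimizer, train_data_loader, val_data_loader, no_epochs=3)
for epoch, (tr, va) in enumerate(results):
    print(f"epoch {epoch}: train {tr:.4f}, val {va:.4f}")
```

The per-batch work is what differs between the two phases (backward and step versus neither), which is why giving each phase its own loop reads more clearly than branching on the phase inside one loop.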

If you want to keep your existing data_loader dict, replace train_data_loader with data_loader["train"] and val_data_loader with data_loader["val"].

This layout is common because we generally want to do some things differently when validating than when training (for example, computing gradients and stepping the optimizer only during training). It also structures the code better and avoids the many if phase == "train" checks you might otherwise need at different points in your innermost loop. It does, however, mean that you might need to duplicate some code. This trade-off is generally accepted; your original structure becomes worth considering if there are three or more phases, such as multiple validation phases or an additional evaluation phase.
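On the question of a separate validation function: one option is to put each phase's loop in its own function and keep train_network as a thin driver that still accepts your data_loader dict. In this sketch the helper names train_one_epoch and evaluate are illustrative, not a library API; the loss function criterion and device are assumed to be passed in; and every key other than 'train' (e.g. 'val', or a hypothetical 'test') is treated as an evaluation phase, which covers the three-or-more-phases case without any if phase == "train" checks inside the batch loops.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, optimizer, criterion, loader, device):
    """One pass over the training data with gradient updates; returns mean loss."""
    model.train()
    total, n = 0.0, 0
    for images, g_truth in loader:
        images, g_truth = images.to(device), g_truth.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), g_truth)
        loss.backward()
        optimizer.step()
        total += loss.item() * images.size(0)
        n += images.size(0)
    return total / n


@torch.no_grad()  # no autograd graph is built during evaluation
def evaluate(model, criterion, loader, device):
    """One pass over held-out data without gradient updates; returns mean loss."""
    model.eval()
    total, n = 0.0, 0
    for images, g_truth in loader:
        images, g_truth = images.to(device), g_truth.to(device)
        total += criterion(model(images), g_truth).item() * images.size(0)
        n += images.size(0)
    return total / n


def train_network(model, optimizer, criterion, data_loader, no_epochs, device):
    """data_loader: dict with a 'train' loader plus any number of evaluation loaders."""
    history = []
    for epoch in range(no_epochs):
        record = {"train": train_one_epoch(model, optimizer, criterion,
                                           data_loader["train"], device)}
        for phase, loader in data_loader.items():
            if phase != "train":
                record[phase] = evaluate(model, criterion, loader, device)
        history.append(record)
    return history


# Hypothetical usage on random regression data.
torch.manual_seed(0)
device = torch.device("cpu")


def make_loader(n):
    return DataLoader(TensorDataset(torch.randn(n, 4), torch.randn(n, 1)), batch_size=8)


data_loader = {"train": make_loader(64), "val": make_loader(16), "test": make_loader(16)}
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
history = train_network(model, optimizer, nn.MSELoss(), data_loader, no_epochs=2, device=device)
print(history[-1].keys())  # dict_keys(['train', 'val', 'test'])
```

The branch on the phase name now happens once per phase per epoch in the driver, rather than inside the batch loop, so adding another evaluation set only means adding another key to the dict.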

Answered By – Zoom

Answer Checked By – Candace Johnson (AngularFixing Volunteer)
