Saving a TensorFlow model and loading it for further training

Issue

I am training a CNN model I built with TensorFlow in Python on a relatively large dataset (27 GB). Because my RAM cannot hold all of the data at once, I instead do something like this:

  1. read dataset(0) through dataset(100)
  2. do the data processing
  3. train the model for 20 epochs
  4. save the model
  5. read dataset(101) through dataset(200)

and then repeat the processing and training for the following data. I use the model.save(filepath) function, which saves the entire model (weights, optimizer state and…).

The following is simplified code that loads and saves the model around each training session:

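from tensorflow.keras.models import load_model

history = []
for chunk in chunks:  # placeholder: one slice, e.g. dataset(0) to dataset(100)
    # data processing for this chunk goes here
    training_data, training_labels = processed_data()

    mod = load_model('Mytf.h5')
    history.append(mod.fit(training_data, training_labels, batch_size=10, epochs=40))
    mod.save('Mytf.h5')

    del training_data
    del training_labels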

However, each training session on new data starts at roughly the same loss (MSE), and after 20 epochs (i.e., at the end of training) it ends at roughly the same loss as the previous session did.

Is this approach correct, or am I missing a fundamental concept?

If this is incorrect, does TensorFlow allow a program to train up to a certain point, then generate new processed data and feed it into the model within the same epoch? For example: train on the processed data of dataset(0) to dataset(100) until roughly a third of the way through the epoch, stop training, let the program process new data, and then feed that data into the model in exactly the state it was left in.
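Streaming of this kind is usually handled with the tf.data API rather than by stopping and restarting fit(). The following is a minimal sketch, not the asker's pipeline: `load_chunk(i)` is a hypothetical stand-in for "read and process dataset(i)" (here it just returns random arrays with toy shapes), and `NUM_CHUNKS` and `INPUT_SHAPE` are made-up values. `tf.data.Dataset.from_generator` pulls one chunk into memory at a time, so a single epoch walks through every chunk without the whole dataset ever being loaded.

```python
import numpy as np
import tensorflow as tf

NUM_CHUNKS = 3           # hypothetical number of dataset slices on disk
INPUT_SHAPE = (8, 8, 1)  # toy image shape, an assumption for the sketch


def load_chunk(i):
    """Hypothetical placeholder for 'read and process dataset(i)'.

    Returns random features/labels so the sketch runs on its own.
    """
    rng = np.random.default_rng(i)
    x = rng.random((100, *INPUT_SHAPE), dtype=np.float32)
    y = rng.random((100, 1), dtype=np.float32)
    return x, y


def sample_generator():
    # Only one processed chunk is held in memory at any time.
    for i in range(NUM_CHUNKS):
        x, y = load_chunk(i)
        for xi, yi in zip(x, y):
            yield xi, yi


def make_dataset(batch_size=10):
    ds = tf.data.Dataset.from_generator(
        sample_generator,
        output_signature=(
            tf.TensorSpec(shape=INPUT_SHAPE, dtype=tf.float32),
            tf.TensorSpec(shape=(1,), dtype=tf.float32),
        ),
    )
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


# Toy stand-in for the asker's CNN.
model = tf.keras.Sequential([
    tf.keras.Input(shape=INPUT_SHAPE),
    tf.keras.layers.Conv2D(4, 3, activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# Each epoch re-runs sample_generator, i.e. one full pass over all chunks.
history = model.fit(make_dataset(), epochs=2, verbose=0)
print(history.history["loss"])
```

Because the generator function is called afresh for every epoch, the optimizer state and weights simply carry on from chunk to chunk, with no save/load round trip in between.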

P.S. I made sure that I'm saving the model correctly by loading it and checking that it gives the same accuracy/loss it had when training on that specific dataset ended.
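That check can be scripted. A minimal sketch, assuming toy random data and a small stand-in model (neither is the asker's): train briefly, save, reload with `tf.keras.models.load_model`, and confirm that both copies report the same loss on the same data. It writes to a temporary `.keras` path; the question's 'Mytf.h5' path should work similarly.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.random((50, 8, 8, 1), dtype=np.float32)  # toy data, an assumption
y = rng.random((50, 1), dtype=np.float32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 1)),
    tf.keras.layers.Conv2D(4, 3, activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
model.fit(x, y, batch_size=10, epochs=2, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)  # architecture, weights and optimizer state
restored = tf.keras.models.load_model(path)

# Same data, same weights: the two losses should match.
loss_before = model.evaluate(x, y, verbose=0)
loss_after = restored.evaluate(x, y, verbose=0)
print(loss_before, loss_after)
```

Matching losses confirm that the weights survive the round trip; they do not rule out the overfitting-per-chunk behaviour described in the question.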

Solution

I would rather do the following:

Repeat 20 times:
     read next 100 datasets/datapoints
     do the data processing
     train the model for 1 epoch
     save the model

When you run 20 epochs on n datapoints, you may be overfitting to them, so when the model sees the next n datapoints it has to learn them afresh. Instead, run only 1 epoch on each set and repeat the whole cycle multiple times.
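The outline above can be sketched as a loop that keeps one model in memory and gives each chunk a single epoch per pass. This is a minimal sketch under stated assumptions: `load_chunk` is a hypothetical placeholder for reading and processing one slice (random data here), the model is a toy stand-in, the chunk order is shuffled on every pass (an addition not in the outline), and the model is checkpointed once per pass rather than after every chunk.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def load_chunk(i):
    """Hypothetical placeholder for 'read and process the i-th slice'."""
    rng = np.random.default_rng(i)
    x = rng.random((100, 8, 8, 1), dtype=np.float32)
    y = rng.random((100, 1), dtype=np.float32)
    return x, y


def build_model():
    # Toy stand-in for the asker's CNN.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(8, 8, 1)),
        tf.keras.layers.Conv2D(4, 3, activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


def train_in_passes(model, num_passes=20, num_chunks=3, path=None):
    """Run num_passes passes; each pass trains 1 epoch on every chunk."""
    order_rng = np.random.default_rng(42)
    losses = []
    for _ in range(num_passes):
        # Assumption: shuffling the chunk order keeps the model from
        # always finishing a pass on the same slice.
        for i in order_rng.permutation(num_chunks):
            x, y = load_chunk(int(i))
            hist = model.fit(x, y, batch_size=10, epochs=1, verbose=0)
            losses.append(hist.history["loss"][0])
            del x, y  # free this chunk before loading the next one
        if path is not None:
            model.save(path)  # checkpoint after each full pass
    return losses


model = build_model()
ckpt = os.path.join(tempfile.mkdtemp(), "Mytf.keras")
# num_passes=2 keeps the demo fast; the outline uses 20.
losses = train_in_passes(model, num_passes=2, num_chunks=3, path=ckpt)
print(losses)
```

Keeping the model in memory between chunks avoids reloading it from disk each time; the checkpoint is only there so training can resume after an interruption.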

Also, if you stay with your current approach, use early stopping instead of always fitting for 20 epochs, to avoid overfitting.
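Keras provides early stopping as the `tf.keras.callbacks.EarlyStopping` callback. A minimal sketch on toy random data (an assumption standing in for one processed chunk): hold out 20% of the chunk for validation and stop once `val_loss` has not improved for 3 epochs, restoring the best weights seen. The patience value and validation split are illustrative, not tuned.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.random((200, 8, 8, 1), dtype=np.float32)  # toy stand-in for one chunk
y = rng.random((200, 1), dtype=np.float32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 1)),
    tf.keras.layers.Conv2D(4, 3, activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",         # watch held-out loss, not training loss
    patience=3,                 # epochs without improvement before stopping
    restore_best_weights=True,  # roll back to the best epoch's weights
)

history = model.fit(
    x, y,
    batch_size=10,
    epochs=20,             # upper bound; training may stop earlier
    validation_split=0.2,  # last 20% of this chunk used for validation
    callbacks=[early_stop],
    verbose=0,
)
print("epochs run:", len(history.history["loss"]))
```

In the chunked setup, the callback would be passed to each chunk's fit() call, so every chunk stops as soon as its validation loss stalls.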

Answered By – mujjiga

Answer Checked By – Gilberto Lyons (AngularFixing Admin)
