Training a single model jointly over multiple datasets in TensorFlow


I want to train a single variational autoencoder (VAE), or even a standard autoencoder, jointly over many datasets (e.g. MNIST, CIFAR, SVHN, with all images resized to the same input shape). I am using the VAE tutorial in TensorFlow as a starting point.

To train the model, at each gradient update step I want to sample (choose) one dataset from my set of datasets and then draw a batch of images from that dataset. I could combine all the datasets into one big dataset, but I want to use the fact that all images in a batch come from the same dataset as side information. (I'm still working out that part, but the details aren't important here because my question is about the data pipeline.)

I am not sure how to set up the data pipeline. The tutorial defines it as follows:

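train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                 .shuffle(train_size).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)
                .shuffle(test_size).batch(batch_size))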

Here train_images and test_images are the preprocessed MNIST data. The code creates a TensorFlow dataset, shuffles the entire dataset, and groups it into batches of size batch_size. In my case, I assume I would create a separate train/test dataset pair for each dataset in my set (e.g. cifar_train_dataset/cifar_test_dataset, mnist_train_dataset/mnist_test_dataset, etc.).
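One way to organize the per-dataset pipelines is a dictionary keyed by dataset name, holding a train and a test pipeline for each. The sketch below is plain Python so it runs anywhere: the class ShuffledBatches (a hypothetical name) mimics what from_tensor_slices(x).shuffle(len(x)).batch(batch_size) does, namely a full shuffle that is redone on every new pass (tf.data reshuffles each iteration by default) and a final batch that may be smaller (batch() keeps partial batches by default). The toy integer lists are placeholders for the real, resized images; in TensorFlow each entry would hold a tf.data.Dataset built like the tutorial's.

```python
import random


class ShuffledBatches:
    """Re-iterable stand-in for
    tf.data.Dataset.from_tensor_slices(x).shuffle(len(x)).batch(batch_size)."""

    def __init__(self, examples, batch_size, seed=None):
        self.examples = list(examples)
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __iter__(self):
        # Full shuffle on every pass, then consecutive slices of batch_size;
        # the last slice is shorter when the size does not divide evenly.
        order = self.examples[:]
        self.rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            yield order[start:start + self.batch_size]


# Hypothetical placeholders for each dataset's (train, test) images.
raw = {
    "mnist": (list(range(0, 50)), list(range(50, 60))),
    "cifar": (list(range(100, 140)), list(range(140, 150))),
}
BATCH_SIZE = 8

# datasets["mnist"]["train"] plays the role of mnist_train_dataset, etc.
datasets = {
    name: {"train": ShuffledBatches(train, BATCH_SIZE, seed=0),
           "test": ShuffledBatches(test, BATCH_SIZE, seed=1)}
    for name, (train, test) in raw.items()
}
```

Keeping the pipelines in one dictionary means the training loop can pick a dataset by name (or by index into a list of names) without a separate variable per dataset.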

When it comes to training, they specify the procedure as follows:

for epoch in range(1, epochs + 1):
  for train_x in train_dataset:
    train_step(model, train_x, optimizer)

  loss = tf.keras.metrics.Mean()
  for test_x in test_dataset:
    loss(compute_loss(model, test_x))
  elbo = -loss.result()
  print('Epoch: {}, Test set ELBO: {}'.format(epoch, elbo))

Instead of specifying epochs, I could specify a total number of training steps (e.g. 500,000). At each training step, I would sample one dataset from the set (with equal probabilities) instead of using a single training dataset as above.
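A minimal sketch of a step-based loop with uniform dataset sampling, assuming each dataset can be wrapped as an endless stream of batches (the role iter(dataset.repeat().batch(batch_size)) would play in TensorFlow). The generator endless_batches, the placeholder train_step, and the small NUM_STEPS are all hypothetical. The dataset index returned by the sampler is the side information that every example in the batch comes from the same source.

```python
import random
from collections import Counter
from itertools import count


def endless_batches(name, batch_size):
    """Hypothetical infinite batch stream for one dataset. Each example is
    tagged (source name, batch number, position) so it is easy to inspect."""
    for i in count():
        yield [(name, i, j) for j in range(batch_size)]


NAMES = ["mnist", "cifar", "svhn"]
streams = {name: endless_batches(name, batch_size=4) for name in NAMES}


def train_step(batch, dataset_id):
    """Placeholder for the real update. dataset_id is the side information:
    every example in `batch` comes from NAMES[dataset_id]."""
    assert all(src == NAMES[dataset_id] for src, _, _ in batch)


NUM_STEPS = 3000  # e.g. 500_000 in a real run
rng = random.Random(123)
steps_per_dataset = Counter()

for step in range(NUM_STEPS):
    dataset_id = rng.randrange(len(NAMES))  # uniform choice of dataset
    name = NAMES[dataset_id]
    batch = next(streams[name])             # exactly one batch from it
    train_step(batch, dataset_id)
    steps_per_dataset[name] += 1
```

Because each stream is infinite, the step budget alone decides when training stops; no dataset ever "runs out" mid-run.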

Now comes the part I'm not sure about. The line for train_x in train_dataset loops over the entire dataset in batches. Instead, I want to obtain a single batch from the dataset I have sampled, make one model update, and repeat. Do datasets set up as described above allow this? Is there a way to index into a dataset, or to obtain a single batch, rather than iterating over all of them?
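The difference between looping over a dataset and pulling single batches comes down to Python's iterator protocol: iter() creates an iterator and each next() call returns one element. In this minimal illustration a plain list of batches stands in for a batched tf.data dataset (which in eager mode supports iter() and next() the same way). It also shows why repeating the data matters when you sample batches indefinitely: a finite iterator eventually raises StopIteration.

```python
from itertools import cycle

# A list of batches stands in for a batched dataset.
batches = [[1, 2], [3, 4], [5, 6]]

# A for loop walks through every batch in one go.
seen = []
for batch in batches:
    seen.append(batch)

# iter()/next() hands out one batch per call, so you decide when to stop.
it = iter(batches)
first = next(it)     # [1, 2]
second = next(it)    # [3, 4]
third = next(it)     # [5, 6]

# A finite iterator eventually runs out ...
try:
    next(it)
    exhausted = False
except StopIteration:
    exhausted = True  # there is no fourth batch

# ... whereas a repeated stream never does (the role dataset.repeat() plays).
endless = cycle(batches)
pulls = [next(endless) for _ in range(4)]  # wraps around to [1, 2]
```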

In summary, I want to train a single model over multiple datasets by sampling a batch of images from a given dataset at each training step when making model updates. I am completely open to other suggestions and approaches that address this problem. Thanks!


If I understand your question correctly, you want to control how many batches you pull from your train and test sets, instead of iterating over them completely. You can turn a dataset into an iterator by wrapping it in iter() and then call next() on that iterator to grab the next batch. Calling repeat() on the dataset first makes it infinite, so next() never raises StopIteration.


import numpy as np
import tensorflow as tf

# fake mnist data
train_imgs = tf.random.normal([100, 28, 28, 1])
test_imgs = tf.random.normal([100, 28, 28, 1])
train_labels = tf.one_hot(
    tf.random.uniform([100,], minval=0, maxval=10, dtype=tf.int64), 10)
test_labels = tf.one_hot(
    tf.random.uniform([100,], minval=0, maxval=10, dtype=tf.int64), 10)

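# create train/test datasets; repeat() makes them infinite so next() never runs out
train_ds = tf.data.Dataset.from_tensor_slices((train_imgs, train_labels))
train_ds = train_ds.repeat().shuffle(1 << 6).batch(8)
test_ds = tf.data.Dataset.from_tensor_slices((test_imgs, test_labels))
test_ds = test_ds.repeat().shuffle(1 << 6).batch(8)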

# simple mnist network
x_in = tf.keras.Input((28, 28, 1))
x = tf.keras.layers.Flatten()(x_in)
x = tf.keras.layers.Dense(100)(x)
x_out = tf.keras.layers.Dense(10)(x)

# simple mnist model
model = tf.keras.Model(x_in, x_out)

# make datasets iterators
train_iter = iter(train_ds)
test_iter = iter(test_ds)

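# loss: the model's last layer outputs raw logits (no softmax),
# so the cross-entropy must be told from_logits=True
ce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

def xent_loss(y_true, y_pred):
    return ce(y_true, y_pred)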
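# simple training loop where you control the number of batches per epoch
# for your train and test datasets (example values)
NUM_EPOCHS = 2
NUM_TRAIN_BATCHES_PER_EPOCH = 10
NUM_TEST_BATCHES_PER_EPOCH = 5
optimizer = tf.keras.optimizers.Adam(1e-3)

for epoch in range(NUM_EPOCHS):
    # train
    train_losses = []
    for _ in range(NUM_TRAIN_BATCHES_PER_EPOCH):
        X_train, y_train = next(train_iter)  # pull exactly one batch
        with tf.GradientTape() as tape:
            y_hat = model(X_train, training=True)
            loss = xent_loss(y_train, y_hat)
        # gradient update on this single batch
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_losses.append(float(loss))
    # report train loss
    print(f"epoch: {epoch}\ttrain_loss: {np.mean(train_losses):.4f}")
    # validate (no gradient updates)
    test_losses = []
    for _ in range(NUM_TEST_BATCHES_PER_EPOCH):
        X_test, y_test = next(test_iter)
        y_hat = model(X_test, training=False)
        test_losses.append(float(xent_loss(y_test, y_hat)))
    # report validation loss
    print(f"epoch: {epoch}\ttest_loss: {np.mean(test_losses):.4f}")
    print('-' * 40)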

# epoch: 0  train_loss: ...
# epoch: 0  test_loss: ...
# ----------------------------------------
# epoch: 1  train_loss: ...
# epoch: 1  test_loss: ...
# ----------------------------------------
# (the loss values vary from run to run because the data is random)
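To apply this to the multi-dataset setup in the question, you could keep one repeated iterator per dataset and choose which one to call next() on at every step. The sketch below shows only that control flow: plain Python generators stand in for iter(ds.repeat().shuffle(...).batch(...)), and fake_loss is a hypothetical placeholder for xent_loss(y, model(x)). The gradient update would go where the comment indicates. Evaluating a fixed number of batches from each dataset's own test iterator also gives a per-dataset validation loss.

```python
import random
from itertools import count
from statistics import mean


def endless_batches(name, batch_size):
    """Hypothetical stand-in for iter(ds.repeat().shuffle(...).batch(batch_size))."""
    for i in count():
        yield [(name, i, j) for j in range(batch_size)]


def fake_loss(batch):
    """Hypothetical placeholder for xent_loss(y, model(x))."""
    return float(len(batch))


names = ["mnist", "cifar", "svhn"]
train_iters = {n: endless_batches(n, batch_size=8) for n in names}
test_iters = {n: endless_batches(n, batch_size=8) for n in names}

NUM_EPOCHS = 2
NUM_TRAIN_BATCHES_PER_EPOCH = 30
NUM_TEST_BATCHES_PER_EPOCH = 5
rng = random.Random(0)
history = []

for epoch in range(NUM_EPOCHS):
    train_losses = []
    for _ in range(NUM_TRAIN_BATCHES_PER_EPOCH):
        source = rng.choice(names)         # pick a dataset for this step
        batch = next(train_iters[source])  # one batch, all from `source`
        train_losses.append(fake_loss(batch))
        # ... do the gradient update on `batch` here ...
    # per-dataset validation: a fixed number of batches from each test iterator
    test_losses = {n: mean(fake_loss(next(test_iters[n]))
                           for _ in range(NUM_TEST_BATCHES_PER_EPOCH))
                   for n in names}
    history.append((mean(train_losses), test_losses))
    print(f"epoch: {epoch}\ttrain_loss: {mean(train_losses):.4f}\t"
          f"test_loss: {test_losses}")
```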

Answered By – o-90
