Batch size when running a tf.distribute.Strategy vs. .batch()


I want to display the batch size when running a tf.distribute strategy. I do this by creating a custom Keras layer as follows:

import tensorflow as tf

class DebugLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        # Print the runtime shape of the batch this layer receives.
        print_op = tf.print("******Shape is:", tf.shape(inputs), name='shapey')
        # Tie the output to the print so the print is not pruned in graph mode.
        with tf.control_dependencies([print_op]):
            return tf.identity(inputs)

Q1: Number of examples per worker per batch

If I run with one worker, it prints 128 as the batch size, which is what I set in my dataset pipeline with .batch(128).

If I run with two workers, each worker prints 128. How many examples are being run on each worker, and how many are being run simultaneously?

Q2: Correct steps_per_epoch

In my call to fit(), I specify steps_per_epoch and have a .repeat() in my input pipeline. If my training set consists of 1024 samples, I have 2 workers, and my .batch() is set to 128, what should steps_per_epoch be set to for one epoch?


When building the input pipeline, a .batch() method is typically applied to the data. Let's say that value is 128. That is the total number of examples run per batch, regardless of the number of workers. If:

  • 1 worker is used, it will run 128 examples per training step.
  • 2 workers are used, each will run 64 examples per training step.
  • 3 workers are used, each will run roughly 42 examples per training step.

For the 3-worker case, I'm not sure of the exact number, since 128/3 is not an integer.
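To make the arithmetic concrete, here is a small pure-Python sketch of splitting one global batch across workers. It assumes the remainder is handed out one extra example at a time to the first workers, so 128 over 3 workers becomes 43, 43, 42. That is an illustrative assumption, not a statement of how TensorFlow actually splits an uneven batch. The helper name `split_global_batch` is hypothetical.

```python
def split_global_batch(global_batch_size, num_workers):
    """Hypothetical helper: split one global batch as evenly as possible.

    Assumption: leftover examples go one each to the first workers.
    TensorFlow's own handling of an uneven split may differ.
    """
    base, remainder = divmod(global_batch_size, num_workers)
    return [base + 1 if i < remainder else base for i in range(num_workers)]


for workers in (1, 2, 3):
    sizes = split_global_batch(128, workers)
    # Each line shows the per-worker sizes and confirms they add up to 128.
    print(workers, "worker(s):", sizes, "total:", sum(sizes))
```

In TensorFlow code, the number of replicas is available as `strategy.num_replicas_in_sync`. The distributed-training guides use it to derive a global batch size from a per-replica one. Note that a replica is not always a worker, for example when a worker has several GPUs.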

To set steps_per_epoch, divide the total number of samples by the batch size you set in .batch(). For the example in my question, that is 1024/128 = 8.
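The same calculation can be sketched with the numbers from the question: 1024 samples and `.batch(128)`. The loop takes the answer's claim that `.batch()` sets the total examples per step across all workers. Under that claim, it shows that adding a worker changes each worker's share but not the step count.

```python
num_samples = 1024        # training set size from the question
global_batch_size = 128   # value passed to .batch()

steps_per_epoch = num_samples // global_batch_size
print("steps_per_epoch:", steps_per_epoch)

# Per the answer, more workers shrink each worker's share of a step,
# but the total examples per step (and so the steps per epoch) stay the same.
for num_workers in (1, 2):
    per_worker = global_batch_size // num_workers
    print(num_workers, "worker(s):", per_worker, "examples each per step,",
          steps_per_epoch * global_batch_size, "examples per epoch")
```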

This is somewhat inconvenient: you need to know the number of training examples, and if it changes you need to adjust steps_per_epoch. Also, if the number of samples is not an integer multiple of the batch size, you need to decide whether to round, floor, or ceil the steps_per_epoch value.
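The rounding choice can be sketched with a hypothetical training set of 1000 samples, a number picked only because it is not a multiple of 128. Flooring keeps each epoch within one pass over the data. Ceiling adds one more step so every sample is reached. Depending on whether `.repeat()` comes before or after `.batch()`, that extra step either pulls some examples from the next pass or consumes a smaller final batch. The helper name is hypothetical.

```python
import math


def steps_per_epoch(num_samples, batch_size, rounding="floor"):
    """Hypothetical helper: steps needed for (about) one pass over the data."""
    if rounding == "floor":
        # Stay within one pass; the last num_samples % batch_size
        # examples are not seen within this epoch.
        return num_samples // batch_size
    if rounding == "ceil":
        # Reach every sample; the last step is filled partly by a smaller
        # final batch or by examples from the next pass, depending on order.
        return math.ceil(num_samples / batch_size)
    raise ValueError("rounding must be 'floor' or 'ceil'")


print(steps_per_epoch(1000, 128, "floor"))  # 7
print(steps_per_epoch(1000, 128, "ceil"))   # 8
```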

Answered By – Robert Lugg

Answer Checked By – Gilberto Lyons (AngularFixing Admin)

