Using generators in Python to train machine learning models

Jessica Yung | Machine Learning, Programming

If you want to train a machine learning model on a large dataset such as ImageNet, especially if you want to use GPUs, you’ll need to think about how you can stay within the memory limits of your GPU or your machine’s RAM. Generators are a great way of doing this in Python.

What is a generator?

A generator is a function that behaves like an iterator: calling it gives you an object that hands back elements one at a time. An iterator steps through (iterates over) the elements of an object, like items in a list or keys in a dictionary. A generator is often used like a list, but there are a few differences:

  • It does not hold all of its results in memory at once,
  • It may take longer to run (it trades more time for less space),
  • It is ‘lazy’: it does not compute results until you need them,
    • That is, your ‘list’ is constructed piece by piece, with each element calculated only when you ask for it.
  • You can only iterate over it once.
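To make these differences concrete, here is a small sketch comparing a list comprehension with the equivalent generator expression. The exact byte counts reported by sys.getsizeof depend on your Python version and platform, so only the comparison between the two matters.

```python
import sys

n = 1_000_000

# A list computes and stores every element up front.
squares_list = [i * i for i in range(n)]

# A generator stores only its current state; elements are computed on demand.
squares_gen = (i * i for i in range(n))

# getsizeof counts only the container itself (not the ints it refers to),
# but even so the list is far larger than the generator.
list_bytes = sys.getsizeof(squares_list)
gen_bytes = sys.getsizeof(squares_gen)
print(f"list: {list_bytes} bytes, generator: {gen_bytes} bytes")

# Lazy: each call to next() computes exactly one more element.
first = next(squares_gen)   # 0
second = next(squares_gen)  # 1

# Single pass: once consumed, a generator is exhausted.
small_gen = (i * i for i in range(3))
first_pass = list(small_gen)   # [0, 1, 4]
second_pass = list(small_gen)  # [] because nothing is left to yield
print(first_pass, second_pass)
```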

You’ll get a better feel for what generators are as we go through examples in this post.

How to code a generator

The first, more verbose way of coding a generator is to define a function that loops over the elements of an object and yields each element as it goes.

Method 1:

input_list = [1, 2, 3, 4, 5]

def my_generator(my_list):
    print("This runs the first time you call next().")
    for i in my_list:
        yield i*i

gen1 = my_generator(input_list)

print(next(gen1))
# This runs the first time you call next(). <- printout
# 1

print(next(gen1))
# 4 (since 2*2=4)

# Full 'list' would be [1, 4, 9, 16, 25]

# Calling next(gen1) again after all five elements have been yielded raises:
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# StopIteration

yield is used like return, but (1) a function containing yield returns a generator when you call it, and (2) calling the function does not run its body at all. [1] The function just returns the generator object. Every time you call next() on the generator object, the body runs from where it last stopped to the next occurrence of yield, and next() returns the yielded value.
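The pause-and-resume behaviour is easiest to see by recording what the function body has executed. The sketch below uses a hypothetical traced_squares() generator that appends to a log list at each step, so we can inspect the log after creating the generator and after each call to next().

```python
events = []  # records what the generator body has executed so far

def traced_squares(values):
    # None of this body runs when traced_squares() is called.
    events.append("started")
    for v in values:
        events.append(f"about to yield {v * v}")
        yield v * v  # execution pauses here until the next call to next()
        events.append(f"resumed after {v * v}")
    events.append("finished")

gen = traced_squares([1, 2])
after_creation = list(events)  # []: calling the function only created the generator

a = next(gen)                  # runs up to the first yield
after_first = list(events)     # ['started', 'about to yield 1']

b = next(gen)                  # resumes after the first yield, stops at the second

exhausted = False
try:
    next(gen)                  # resumes, finishes the loop, then raises StopIteration
except StopIteration:
    exhausted = True

print(after_creation, after_first, a, b, exhausted)
print(events)
```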

Method 2:

The second way of coding generators is a generator expression, which looks like a list comprehension with round brackets instead of square ones. It’s much more compact than the previous method:

gen2 = (i*i for i in input_list)

As with Method 1, calling next() on the generator after it has run out of entries raises a StopIteration exception.
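In practice you rarely see StopIteration directly, because a for loop calls next() for you and stops quietly when the exception is raised. This short sketch shows that, plus next() with a default value, which returns the default instead of raising once the generator is exhausted.

```python
input_list = [1, 2, 3, 4, 5]

gen2 = (i * i for i in input_list)

# A for loop calls next() for you and stops cleanly on StopIteration.
total = 0
for square in gen2:
    total += square
print(total)  # 55

# gen2 is now exhausted; next() with a default returns it instead of raising.
leftover = next(gen2, None)
print(leftover)  # None

# Generator expressions can be passed straight to functions that consume
# iterables, such as sum(), without building an intermediate list.
total_again = sum(i * i for i in input_list)
print(total_again)  # 55
```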

Example: Using generators in machine learning models

I think the hardest part of learning a new technique is figuring out when to incorporate the technique into your code. Examples are a great way to accelerate that learning.

An anti-example: Range in Python 3

Before we go into an example of a generator, let’s look at what isn’t a generator.

You’ve likely come across range in Python 3 (or xrange in Python 2) when making a for loop:

for i in range(10):
    print(i)

This loops over the numbers 0, 1, ..., 9 (range stops just before its argument, 10).

You may have heard that range in Python 3 is now a generator. It acts like a generator in that it doesn’t produce the entire list [0, 1, ..., 9] in memory, but it really isn’t one! It is a lazy sequence object. You can check it isn’t a generator (or any other kind of iterator) by calling next(range(10)), which raises a TypeError. For more details, see Oleh Prypin’s answer on StackOverflow.
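Here is a quick sketch of that check, along with a few things a range can do that a generator cannot: report its length, be indexed, and be iterated more than once. Calling iter() on a range gives you an actual iterator, which next() does accept.

```python
import types

r = range(10)

print(list(r))                             # [0, 1, ..., 9]: stops before 10
print(isinstance(r, types.GeneratorType))  # False

# range is an iterable, not an iterator, so next() on it fails.
error_message = ""
try:
    next(r)
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # 'range' object is not an iterator

# Unlike a generator, a range supports len(), indexing and repeated iteration.
print(len(r), r[3], r[-1])  # 10 3 9
first_sum = sum(r)
second_sum = sum(r)         # iterating a second time still works
print(first_sum, second_sum)  # 45 45

# iter() returns an iterator over the range, which next() accepts.
it = iter(r)
a0, a1 = next(it), next(it)
print(a0, a1)  # 0 1
```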

Using generators in machine learning models

Recall that a big benefit of using generators is saving memory. So they are a natural fit for applications that would otherwise need a lot of memory, where you can’t simply load everything at once.

One example is training machine learning models on large amounts of data using GPUs. GPUs have limited memory, and loading a whole dataset at once can easily lead to MemoryErrors. One way out is to use a generator that reads in images one batch at a time and feeds them to the model.

The outline of the generator goes like this (the code is heavily adapted from Udacity’s code):

import random

import matplotlib.image as mpimg
import numpy as np

def shuffle(samples):
    # Return a shuffled copy of samples (the original list is left unchanged)
    return random.sample(samples, len(samples))

def generator(samples, batch_size=32):
    """Yields the next training batch.

    Suppose `samples` is a list [[image1_filename, label1], [image2_filename, label2], ...].
    """
    num_samples = len(samples)
    while True: # Loop forever so the generator never terminates
        # Reshuffle at the start of each pass through the data
        samples = shuffle(samples)

        # Get index to start each batch: [0, batch_size, 2*batch_size, ..., largest multiple of batch_size < num_samples]
        for offset in range(0, num_samples, batch_size):
            # Get the samples you'll use in this batch (the last batch may be smaller)
            batch_samples = samples[offset:offset+batch_size]

            # Initialise X_train and y_train arrays for this batch
            X_train = []
            y_train = []

            # For each example
            for batch_sample in batch_samples:
                # Load image (X)
                filename = './common_filepath/'+batch_sample[0]
                image = mpimg.imread(filename)
                # Read label (y)
                y = batch_sample[1]
                # Add example to arrays
                X_train.append(image)
                y_train.append(y)

            # Make sure they're numpy arrays (as opposed to lists)
            X_train = np.array(X_train)
            y_train = np.array(y_train)

            # The generator-y part: yield the next training batch
            yield X_train, y_train
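To see the batching logic on its own, here is a stripped-down sketch of the same loop that uses hypothetical (filename, label) pairs and skips image loading and shuffling. Because the generator loops forever, itertools.islice is used to take a finite number of batches.

```python
from itertools import islice

def batch_generator(samples, batch_size=2):
    """Yield (X, y) batches forever, mirroring the structure of generator() above."""
    num_samples = len(samples)
    while True:  # start again from the first sample after every pass
        for offset in range(0, num_samples, batch_size):
            batch_samples = samples[offset:offset + batch_size]
            X_batch = [name for name, _ in batch_samples]  # stand-in for loaded images
            y_batch = [label for _, label in batch_samples]
            yield X_batch, y_batch

# Hypothetical samples: five (filename, label) pairs.
samples = [("img0.jpg", 0), ("img1.jpg", 1), ("img2.jpg", 0),
           ("img3.jpg", 1), ("img4.jpg", 0)]

gen = batch_generator(samples, batch_size=2)
batches = list(islice(gen, 4))  # take 4 batches from the infinite generator
for X_batch, y_batch in batches:
    print(X_batch, y_batch)
```

With five samples and a batch size of 2, the third batch holds only one sample, and the fourth batch wraps around to the start of the data. This is why the generator never raises StopIteration: the training loop decides how many batches make up an epoch.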

# Import list of train and validation data (image filenames and image labels)
# Note this is not valid code.
train_samples = ...
validation_samples = ...

# Create generator
train_generator = generator(train_samples, batch_size=32)
validation_generator = generator(validation_samples, batch_size=32)

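# Use generator to train neural network in Keras

# Create model in Keras
from keras.models import Sequential
from keras.layers import Dense, Activation

model = Sequential([
    Dense(32, input_shape=(784,)),
    # ... remaining layers omitted; the input shape must match your images
])
# Placeholder loss and optimizer: choose ones suited to your task
model.compile(optimizer='adam', loss='mse')

# Fit model using generator (Keras 1 argument names; later Keras versions
# count in batches via steps_per_epoch, validation_steps and epochs)
model.fit_generator(train_generator, samples_per_epoch=len(train_samples),
                    validation_data=validation_generator,
                    nb_val_samples=len(validation_samples), nb_epoch=100)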
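The fit call above uses Keras 1 argument names such as nb_val_samples and nb_epoch, which count samples. Keras 2 switched to steps_per_epoch and validation_steps, which count batches, and newer versions (tf.keras in TensorFlow 2.x, Keras 3) pass the generator to model.fit instead of fit_generator. Since generator() yields one batch per offset in range(0, num_samples, batch_size), the number of steps per epoch is num_samples / batch_size rounded up. A sketch of that conversion, with hypothetical dataset sizes:

```python
import math

def steps_for(num_samples, batch_size):
    # Batches per pass through the data, including a final partial batch.
    return math.ceil(num_samples / batch_size)

batch_size = 32
num_train = 1000  # hypothetical dataset sizes
num_val = 250

steps_per_epoch = steps_for(num_train, batch_size)
validation_steps = steps_for(num_val, batch_size)
print(steps_per_epoch, validation_steps)  # 32 8

# Matches the number of offsets the generator loops over in one pass.
assert steps_per_epoch == len(range(0, num_train, batch_size))

# Assuming a recent Keras version, the training call would then look like:
# model.fit(train_generator, steps_per_epoch=steps_per_epoch,
#           validation_data=validation_generator,
#           validation_steps=validation_steps, epochs=100)
```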

The full code in its original context can be found on GitHub as part of my attempt at the Behavioural Cloning project in Udacity’s Self-Driving Car Engineer Nanodegree.

With a generator, you only need to keep the images for the current training batch in memory, rather than all of your training images. Note that you may still get MemoryErrors from other sources, for example having too many parameters in your network.
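One way to check the claim that only the current batch is read is to count image loads. This sketch uses a hypothetical load_image() function as a stand-in for mpimg.imread; it records each call and returns a small fake image, so we can see how many images have been read before and after pulling one batch from the generator.

```python
loaded = []  # records every "image" the hypothetical loader has read

def load_image(filename):
    # Hypothetical stand-in for mpimg.imread: records the call, returns a fake 4x4 image.
    loaded.append(filename)
    return [[0.0] * 4 for _ in range(4)]

def lazy_batches(filenames, batch_size):
    for offset in range(0, len(filenames), batch_size):
        # Images are only read here, when the next batch is requested.
        yield [load_image(name) for name in filenames[offset:offset + batch_size]]

filenames = [f"img{i}.jpg" for i in range(10_000)]  # hypothetical dataset
gen = lazy_batches(filenames, batch_size=32)

loaded_before = len(loaded)  # 0: creating the generator reads nothing
first_batch = next(gen)
print(loaded_before, len(first_batch), len(loaded))  # 0 32 32
```

Only 32 of the 10,000 files have been read after the first batch. In a real training loop, each batch can be garbage-collected once the model has finished with it, so peak memory is governed by the batch size rather than the dataset size.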