Using generators in Python to train machine learning models

Jessica Yung · Machine Learning, Programming

If you want to train a machine learning model on a large dataset such as ImageNet, especially if you want to use GPUs, you’ll need to think about how you can stay within your GPU or CPU’s memory limits. Generators are a great way of doing this in Python.

What is a generator?

A generator is a function that behaves like an iterator. An iterator loops (iterates) through elements of an object, like items in a list or keys in a dictionary. A generator is often used like an array, but there are a few differences:

  • It does not hold its results in memory,
  • It may take longer to run (it trades extra time for less memory),
  • It is ‘lazy’: it does not compute results until you need them,
    • That is, the sequence is constructed piece by piece, with each element computed only when you ask for it.
  • You can only iterate over it once.

You’ll get a better feel for what generators are as we go through examples in this post.

How to code a generator

The first and more tedious way of coding a generator is defining a function that loops over elements in an object and yields elements as it loops.

Method 1:

yield is used like return, but (1) the function returns a generator, and (2) when you call the generator function, the function body does not run straight through. [1] The call just returns a generator object. Every time you call next() on that object, the generator runs from where it last stopped to the next occurrence of yield.
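For example, here is a small generator function (a hypothetical `squares` function, used purely for illustration) that yields square numbers one at a time:

```python
def squares(n):
    """Yield the squares 0, 1, 4, ... up to (n-1)**2, one at a time."""
    for i in range(n):
        yield i * i  # pause here and hand back one value

gen = squares(4)   # calling the function returns a generator object;
                   # no squares have been computed yet
print(next(gen))   # 0 -- runs the body up to the first yield
print(next(gen))   # 1 -- resumes after the yield, runs to the next one
print(list(gen))   # [4, 9] -- consumes whatever is left
```

Note that calling `squares(4)` does no work at all; the squares are only computed as `next()` pulls them out.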

Method 2:

The second way of coding generators is similar to that of coding list comprehensions. It’s much more compact than the previous method:
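For instance, a squares example (illustrative names, not from any particular codebase) written both as a list comprehension and as a generator expression:

```python
# List comprehension (square brackets): all results built in memory at once.
squares_list = [i * i for i in range(4)]

# Generator expression (parentheses): results computed lazily, one at a time.
squares_gen = (i * i for i in range(4))

print(squares_list)       # [0, 1, 4, 9]
print(next(squares_gen))  # 0
print(next(squares_gen))  # 1
```

The only syntactic difference is parentheses instead of square brackets, but the generator version never holds the whole sequence in memory.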

When the generator has run out of entries, calling next() on it raises a StopIteration exception.
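For example, with a throwaway two-element generator:

```python
gen = (x for x in [1, 2])
print(next(gen))  # 1
print(next(gen))  # 2
try:
    next(gen)  # the generator is exhausted
except StopIteration:
    print("generator exhausted")
```

A for loop handles this for you: it keeps calling next() and stops quietly when StopIteration is raised.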

Example: Using generators in machine learning models

I think the hardest part of learning a new technique is figuring out when to incorporate the technique into your code. Examples are a great way to accelerate that learning.

An anti-example: Range in Python 3

Before we go into an example of a generator, let’s look at what isn’t a generator.

You’ve likely come across range in Python 3 (or xrange in Python 2) when making a for loop:
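A typical loop looks like this (here over the numbers 0 to 10):

```python
for i in range(11):
    print(i)  # prints 0, 1, ..., 10, one per iteration
```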

This loops over the numbers 0, 1, ..., 10 in turn.

You may have heard that range in Python 3 is now a generator. It acts like a generator in that it doesn’t produce the entire list [0,1,...,10] in memory, but it really isn’t one! You can check that it isn’t a generator by trying to call next(range(10)). For more details, see Oleh Prypin’s answer on StackOverflow.

Using generators in machine learning models

Recall that a big benefit of using generators is saving memory. So generators are most useful in applications that process a lot of data but where you want to keep memory usage low.

One example is training machine learning models on large datasets using GPUs. GPUs have relatively little memory, so you can easily run into MemoryErrors. One way out is to use a generator to read in images batch by batch and feed them to the model.

The outline of the generator goes like this (the code is heavily adapted from Udacity starter code):
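A minimal sketch of such a batch generator (the names `load_image` and `batch_generator`, the 32×32 image size, and the batching details are illustrative assumptions, not the original project code):

```python
import numpy as np

def load_image(path):
    # Stand-in for real image loading (e.g. with PIL or OpenCV);
    # here we just return a blank "image" so the sketch runs end to end.
    return np.zeros((32, 32, 3))

def batch_generator(image_paths, labels, batch_size=32):
    """Yield (X, y) batches forever, holding only one batch of
    images in memory at a time."""
    n = len(image_paths)
    while True:  # loop indefinitely; the training loop decides when to stop
        indices = np.random.permutation(n)  # reshuffle once per pass
        for start in range(0, n, batch_size):
            idx = indices[start:start + batch_size]
            X = np.array([load_image(image_paths[i]) for i in idx])
            y = np.array([labels[i] for i in idx])
            yield X, y
```

With Keras (which the Udacity project uses), a generator like this can be passed to `model.fit_generator` (or, in newer versions, directly to `model.fit`), so only one batch of images ever needs to be loaded at a time.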

The full code in its original context can be found on GitHub as part of my attempt on the Behavioural Cloning project in Udacity’s Self-Driving Car Engineer Nanodegree.

Using a generator, you only need to keep the images for your training batch in memory as opposed to all your training images. Note that you may still get MemoryErrors from, for example, having too many parameters in your network.
