LSTMs for Time Series in PyTorch

Jessica Yung

I can’t believe how long it took me to get an LSTM to work in PyTorch!

There are many ways it can fail. Sometimes you get a network that predicts values way too close to zero.

In this post, we’re going to walk through implementing an LSTM for time series prediction in PyTorch. We’re going to use PyTorch’s nn module, so it’ll be pretty simple, but if your network doesn’t train well, you can try the tips I’ve listed at the end that have helped me fix wonky LSTMs in the past.

What is an LSTM?

A Long Short-Term Memory network (LSTM) is a type of recurrent neural network designed to overcome the problems of basic RNNs so that the network can learn long-term dependencies. Specifically, it tackles vanishing and exploding gradients: when you backpropagate through many time steps, the gradients either vanish (go to zero) or explode (become very large) because the gradient is a product of many factors that are all smaller or all larger than one. You can learn more about LSTMs from Chris Olah’s excellent blog post. You can also read Hochreiter and Schmidhuber’s original paper (1997), which identifies the vanishing and exploding gradient problems and proposes the LSTM as a way of tackling them.

Generating data

First, let’s prepare some data. For this example I have generated some AR(5) data. I’ve included the details in my post on generating AR data. You can find the code to generate the data here.
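The linked code isn’t reproduced here, but a minimal sketch of an AR(5) generator might look like the following. The coefficients and noise scale are illustrative choices (their absolute values sum to less than one, so the process stays stationary), not the ones from the linked post:

```python
import numpy as np

def generate_ar_data(coeffs, n_samples, noise_std=1.0, seed=0):
    """Generate samples from an AR(p) process:
    x_t = coeffs[0]*x_{t-1} + ... + coeffs[p-1]*x_{t-p} + noise."""
    rng = np.random.default_rng(seed)
    p = len(coeffs)
    x = np.zeros(n_samples + p)  # first p entries are zero-padding for start-up
    for t in range(p, n_samples + p):
        # x[t-p:t][::-1] is [x_{t-1}, ..., x_{t-p}], matching coeffs' order
        x[t] = np.dot(coeffs, x[t - p:t][::-1]) + rng.normal(scale=noise_std)
    return x[p:]

# Example: AR(5) data with small (illustrative) coefficients
data = generate_ar_data([0.2, -0.1, 0.1, -0.05, 0.05], n_samples=1000)
```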

LSTM code

Next, let’s build the network.

In PyTorch, you usually build your network as a class inheriting from nn.Module. You need to implement the forward() method, which defines the forward pass. You then run the forward pass by calling the model on your input like this:
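For example, here is a minimal (hypothetical) module with a single linear layer; calling the model object runs its forward() method:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical example network
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)  # 10 inputs -> 1 output

    def forward(self, x):
        return self.linear(x)

model = TinyNet()
x = torch.randn(4, 10)  # a batch of 4 input vectors
y = model(x)            # calling the model runs forward(x); shape (4, 1)
```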

You can implement the LSTM from scratch, but here we’re going to use the torch.nn.LSTM module. torch.nn is a bit like Keras: it’s a wrapper around lower-level PyTorch code that makes it faster to build models by giving you common layers so you don’t have to implement them yourself.
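A sketch of such a model for one-step-ahead prediction might look like this (the hidden size and layer count are arbitrary illustrative choices, not the post’s original hyperparameters):

```python
import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    """Hypothetical LSTM for one-step-ahead time series prediction."""
    def __init__(self, input_size=1, hidden_size=16, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        out, _ = self.lstm(x)           # out: (batch, seq_len, hidden_size)
        return self.linear(out[:, -1])  # predict from the last time step

model = LSTMModel()
x = torch.randn(8, 20, 1)  # batch of 8 sequences, 20 time steps each
pred = model(x)            # shape: (8, 1)
```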

Training the LSTM

With the model defined, we choose a loss function and optimiser and train the model:
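A minimal training loop sketch, assuming mean squared error loss and the Adam optimiser (the learning rate, epoch count, and toy data below are illustrative, not the post’s originals):

```python
import torch
import torch.nn as nn

# Stand-in model so the sketch is self-contained: an LSTM plus a linear head
lstm = nn.LSTM(input_size=1, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)

criterion = nn.MSELoss()
optimiser = torch.optim.Adam(
    list(lstm.parameters()) + list(head.parameters()), lr=1e-2)

inputs = torch.randn(32, 20, 1)  # toy data: 32 sequences of length 20
targets = torch.randn(32, 1)     # toy next-step targets

for epoch in range(50):
    optimiser.zero_grad()        # clear gradients from the previous step
    out, _ = lstm(inputs)
    preds = head(out[:, -1])     # predict from the final hidden state
    loss = criterion(preds, targets)
    loss.backward()              # backpropagate
    optimiser.step()             # update the weights
```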

Debugging RNNs in PyTorch

Setting up and training models can be very simple in PyTorch. However, sometimes RNNs can predict values very close to zero even when the data isn’t distributed like that. I’ve found the following tricks have helped:

  1. Try decreasing your learning rate if your loss is increasing, or increasing your learning rate if the loss is not decreasing.
  2. Try removing model.zero_grad() if you’re using it.
  3. Use nn.LSTMCell instead of nn.LSTM. (This is a weird one but it’s worked before.)
  4. Use more data if you can.
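On tip 3: nn.LSTMCell processes one time step at a time, so you unroll the sequence yourself. A sketch of what that swap looks like (sizes are illustrative):

```python
import torch
import torch.nn as nn

# Manually unrolling the sequence with nn.LSTMCell instead of nn.LSTM
cell = nn.LSTMCell(input_size=1, hidden_size=16)
x = torch.randn(8, 20, 1)      # (batch, seq_len, input_size)
h = torch.zeros(8, 16)         # initial hidden state
c = torch.zeros(8, 16)         # initial cell state
for t in range(x.size(1)):     # step through the sequence yourself
    h, c = cell(x[:, t], (h, c))
# h now holds the hidden state after the final time step, shape (8, 16)
```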

Hope this helps and all the best with your machine learning endeavours!
