I can’t believe how long it took me to get an LSTM to work in PyTorch!
There are many ways it can fail. Sometimes you get a network that predicts values way too close to zero.
In this post, we’re going to walk through implementing an LSTM for time series prediction in PyTorch. We’re going to use PyTorch’s nn module, so it’ll be pretty simple, but in case it doesn’t work on your computer, you can try the tips I’ve listed at the end that have helped me fix wonky LSTMs in the past.
What is an LSTM?
A Long Short-Term Memory network (LSTM) is a type of recurrent neural network designed to overcome problems of basic RNNs so the network can learn long-term dependencies. Specifically, it tackles vanishing and exploding gradients: when you backpropagate through too many time steps, the gradient either vanishes (goes to zero) or explodes (gets very large), because it becomes a product of many numbers that are all less than one or all greater than one. You can learn more about LSTMs from Chris Olah’s excellent blog post. You can also read Hochreiter and Schmidhuber’s original paper (1997), which identifies the vanishing and exploding gradient problems and proposes the LSTM as a way of tackling them.
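To make the "product of numbers" point concrete, here is a toy sketch. It assumes every time step contributes the same scalar factor, which is a big simplification of what really happens in an RNN (the real factors are Jacobian matrices that change at every step). The helper name gradient_scale is made up for this illustration.

```python
# Toy illustration: backpropagating through many time steps multiplies
# the gradient by one factor per step. Here every factor is the same
# scalar, which is a simplification of the real Jacobian products.

def gradient_scale(factor: float, steps: int) -> float:
    """Return the product of `steps` copies of `factor`."""
    scale = 1.0
    for _ in range(steps):
        scale *= factor
    return scale


shrinking = gradient_scale(0.9, 100)  # factors below one: the product vanishes
growing = gradient_scale(1.1, 100)    # factors above one: the product explodes

print(f"0.9 ** 100 = {shrinking:.2e}")
print(f"1.1 ** 100 = {growing:.2e}")
```

Even with factors this close to one, 100 steps is enough to shrink the gradient to a few hundred-thousandths of its size or blow it up by four orders of magnitude.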
Next, let’s build the network.
In PyTorch, you usually build your network as a class inheriting from nn.Module. You need to implement the forward(.) method, which is the forward pass. You then run the forward pass like this:

    # Define model
    model = LSTM(...)

    # Forward pass
    ypred = model(X_batch)  # this calls model.forward(X_batch) under the hood
You can implement the LSTM from scratch, but here we’re going to use nn.LSTM. torch.nn is a bit like Keras: it’s a wrapper around lower-level PyTorch code that makes it faster to build models by giving you common layers so you don’t have to implement them yourself.
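Most of my LSTM bugs come down to tensor shapes, so before wrapping nn.LSTM in a class it can help to run it on random data and look at what comes out. This is a small sketch with made-up sizes; it relies on the default batch_first=False, where the input is laid out as (seq_len, batch, input_dim).

```python
import torch
import torch.nn as nn

# Illustrative sizes only, not the values used elsewhere in this post
seq_len, batch_size, input_dim, hidden_dim, num_layers = 20, 4, 1, 8, 2

lstm = nn.LSTM(input_dim, hidden_dim, num_layers)  # batch_first=False by default
x = torch.randn(seq_len, batch_size, input_dim)

# No initial state is passed, so nn.LSTM starts from zeros
out, (h_n, c_n) = lstm(x)

print(out.shape)  # torch.Size([20, 4, 8]): top-layer output at every time step
print(h_n.shape)  # torch.Size([2, 4, 8]): final hidden state of each layer
print(c_n.shape)  # torch.Size([2, 4, 8]): final cell state of each layer

# For a unidirectional LSTM, the last time step of `out` is the
# final hidden state of the top layer.
assert torch.allclose(out[-1], h_n[-1])
```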
    import torch
    import torch.nn as nn

    # Here we define our model as a class
    class LSTM(nn.Module):

        def __init__(self, input_dim, hidden_dim, batch_size, output_dim=1, num_layers=2):
            super(LSTM, self).__init__()
            self.input_dim = input_dim
            self.hidden_dim = hidden_dim
            self.batch_size = batch_size
            self.num_layers = num_layers

            # Define the LSTM layer
            self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers)

            # Define the output layer
            self.linear = nn.Linear(self.hidden_dim, output_dim)

        def init_hidden(self):
            # This is what we'll initialise our hidden state as
            return (torch.zeros(self.num_layers, self.batch_size, self.hidden_dim),
                    torch.zeros(self.num_layers, self.batch_size, self.hidden_dim))

        def forward(self, input):
            # Forward pass through LSTM layer
            # shape of lstm_out: [seq_len, batch_size, hidden_dim]
            # shape of self.hidden: (a, b), where a and b both
            # have shape (num_layers, batch_size, hidden_dim).
            # No initial state is passed in here, so nn.LSTM starts from zeros.
            lstm_out, self.hidden = self.lstm(input.view(len(input), self.batch_size, -1))

            # Only take the output from the final timestep
            # Can pass on the entirety of lstm_out to the next layer if it is a seq2seq prediction
            y_pred = self.linear(lstm_out[-1].view(self.batch_size, -1))
            return y_pred.view(-1)

    # lstm_input_size, h1, num_train, output_dim and num_layers are
    # assumed to have been set earlier in the script
    model = LSTM(lstm_input_size, h1, batch_size=num_train, output_dim=output_dim, num_layers=num_layers)
Training the LSTM
After defining the model, we define the loss function and optimiser and train the model:
    import numpy as np

    # reduction="sum" replaces the deprecated size_average=False:
    # the squared errors are summed rather than averaged
    loss_fn = torch.nn.MSELoss(reduction="sum")

    optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)

    #####################
    # Train model
    #####################

    hist = np.zeros(num_epochs)

    for t in range(num_epochs):
        # Clear stored gradient
        model.zero_grad()

        # Initialise hidden state
        # Don't do this if you want your LSTM to be stateful
        model.hidden = model.init_hidden()

        # Forward pass
        y_pred = model(X_train)

        loss = loss_fn(y_pred, y_train)
        if t % 100 == 0:
            print("Epoch ", t, "MSE: ", loss.item())
        hist[t] = loss.item()

        # Zero out gradient, else they will accumulate between epochs
        optimiser.zero_grad()

        # Backward pass
        loss.backward()

        # Update parameters
        optimiser.step()
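The training loop above assumes X_train and y_train already exist with the right shapes. Here is a minimal sketch of one way to build them, assuming a univariate series cut into sliding windows where each window predicts the next value. The sine-wave data and the make_windows helper are made up for this illustration and may differ from how the linked code prepares its data.

```python
import numpy as np
import torch


def make_windows(series, seq_len):
    """Slice a 1-D series into (input window, next value) pairs."""
    n = len(series) - seq_len
    X = np.stack([series[i:i + seq_len] for i in range(n)])  # (n, seq_len)
    y = series[seq_len:]                                     # (n,)
    return X, y


# Hypothetical toy data: four periods of a sine wave
series = np.sin(np.linspace(0, 8 * np.pi, 200)).astype(np.float32)
X, y = make_windows(series, seq_len=20)

# nn.LSTM (batch_first=False) expects (seq_len, batch, input_dim),
# so put time first and add a trailing feature dimension of size 1
X_train = torch.from_numpy(X).t().unsqueeze(-1).contiguous()  # (20, 180, 1)
y_train = torch.from_numpy(y)                                 # (180,)

print(X_train.shape, y_train.shape)
```

With data shaped like this, num_train would be 180 and lstm_input_size would be 1 in the model definition.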
Debugging RNNs in PyTorch
Setting up and training models can be very simple in PyTorch. However, sometimes RNNs can predict values very close to zero even when the data isn’t distributed like that. I’ve found the following tricks have helped:
- Try decreasing your learning rate if your loss is increasing, or increasing your learning rate if the loss is not decreasing.
- Try removing model.zero_grad() if you’re using that.
- Use nn.LSTMCell instead of nn.LSTM. (This is a weird one but it’s worked before.)
- Use more data if you can.
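For the nn.LSTMCell tip, the main difference is that you loop over the time steps yourself instead of handing nn.LSTM the whole sequence. Here is a rough single-layer sketch; the class name LSTMCellRegressor and the sizes are made up for this example, and it is not a drop-in replacement for the two-layer model above.

```python
import torch
import torch.nn as nn


class LSTMCellRegressor(nn.Module):  # hypothetical name, single layer only
    def __init__(self, input_dim, hidden_dim, output_dim=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.cell = nn.LSTMCell(input_dim, hidden_dim)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x has shape (seq_len, batch, input_dim)
        batch = x.size(1)
        # Start from a zero hidden state and zero cell state
        h = torch.zeros(batch, self.hidden_dim, dtype=x.dtype, device=x.device)
        c = torch.zeros_like(h)
        # Step through the sequence one time step at a time
        for x_t in x:
            h, c = self.cell(x_t, (h, c))
        # Predict from the hidden state after the final time step
        return self.linear(h).view(-1)


model = LSTMCellRegressor(input_dim=1, hidden_dim=8)
y_pred = model(torch.randn(20, 4, 1))
print(y_pred.shape)  # torch.Size([4])
```

Because the loop is explicit, it is also easy to print h at each step and see where the values start collapsing towards zero.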
Hope this helps and all the best with your machine learning endeavours!
- LSTM for Time Series in PyTorch code
- Chris Olah’s blog post on understanding LSTMs
- LSTM paper (Hochreiter and Schmidhuber, 1997)
- An example of an LSTM implemented using nn.LSTMCell (from pytorch/examples)
- Feature Image Cartoon ‘Short-Term Memory’ by ToxicPaprika.