Deriving LSTM Gradient for Backpropagation
Recurrent Neural Networks (RNNs) have been hot in the past few years, especially with the boom of Deep Learning. An RNN can be seen as a (very) deep neural network if we “unroll” the network with respect to the time steps. Hence, with all the techniques that enable vanilla deep networks, training RNNs has become more and more feasible too.
The most popular model for RNN right now is the LSTM (Long Short-Term Memory) network. For the background theory, there are a lot of amazing resources available in Andrej Karpathy’s blog and Chris Olah’s blog.
Using modern Deep Learning libraries like TensorFlow, Torch, or Theano, building an LSTM model is a breeze, as we don't need to derive the backpropagation step analytically. However, to understand the model better, it is absolutely worthwhile, albeit optional, to derive the LSTM gradient and implement the backpropagation “manually”.
So, here, we will first implement the forward computation step according to the LSTM formulas, then derive the network gradient analytically. Finally, we will implement it using numpy.
LSTM Forward
We will follow this model for a single LSTM cell:
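For reference, these are the standard LSTM cell equations that the rest of this post assumes, where $\sigma$ is the sigmoid function, $\odot$ is elementwise multiplication, and $X_t = [h_{t-1}, x_t]$ is the concatenation of the previous hidden state and the current input:

$$
\begin{aligned}
h_f &= \sigma(X_t W_f + b_f) \\
h_i &= \sigma(X_t W_i + b_i) \\
h_o &= \sigma(X_t W_o + b_o) \\
h_c &= \tanh(X_t W_c + b_c) \\
c_t &= h_f \odot c_{t-1} + h_i \odot h_c \\
h_t &= h_o \odot \tanh(c_t) \\
y_t &= h_t W_y + b_y, \qquad p_t = \mathrm{softmax}(y_t)
\end{aligned}
$$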
Let’s implement it!
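A minimal numpy sketch of the model declaration could look like the following; the hidden size `H = 128` is an arbitrary choice, and `vocab_size` stands for the number of distinct items (here, characters) in the vocabulary:

```python
import numpy as np

H = 128         # number of hidden units in the LSTM layer
D = vocab_size  # input dimension == number of items in the vocabulary
Z = H + D       # gate input size: concatenation of h and x

# Xavier-style initialization for the weights, zeros for the biases
model = dict(
    Wf=np.random.randn(Z, H) / np.sqrt(Z / 2.),  # forget gate
    Wi=np.random.randn(Z, H) / np.sqrt(Z / 2.),  # input gate
    Wc=np.random.randn(Z, H) / np.sqrt(Z / 2.),  # candidate cell state
    Wo=np.random.randn(Z, H) / np.sqrt(Z / 2.),  # output gate
    Wy=np.random.randn(H, D) / np.sqrt(D / 2.),  # hidden-to-output layer
    bf=np.zeros((1, H)),
    bi=np.zeros((1, H)),
    bc=np.zeros((1, H)),
    bo=np.zeros((1, H)),
    by=np.zeros((1, D)),
)
```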
Above, we're declaring our LSTM net model. Notice that, following the formulas, we concatenate the old hidden state `h` with the current input `x`, hence the input to each gate has dimension `Z = H + D`. And because our LSTM layer outputs `H` neurons, each weight matrix has size `Z x H` and each bias vector has size `1 x H`.

The exceptions are `Wy` and `by`. This weight and bias are used for the fully connected layer whose output is fed to a softmax layer. The resulting output should be a probability distribution over all possible items in the vocabulary, which has size `1 x D`. Hence, `Wy`'s size must be `H x D` and `by`'s size must be `1 x D`.
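With those shapes in mind, a sketch of the forward step for a single cell might look like this. It relies on `H`, `D`, and `model` from the declaration above; the helper names and the cache layout are illustrative choices, not the original code:

```python
def sigmoid(x):
    return 1. / (1. + np.exp(-x))

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def lstm_forward(x_idx, state, model):
    Wf, Wi, Wc, Wo, Wy = model['Wf'], model['Wi'], model['Wc'], model['Wo'], model['Wy']
    bf, bi, bc, bo, by = model['bf'], model['bi'], model['bc'], model['bo'], model['by']
    h_old, c_old = state

    # One-hot encode the current input
    x = np.zeros((1, D))
    x[0, x_idx] = 1.

    # Concatenate the old hidden state with the current input: X = [h_old, x]
    X = np.column_stack((h_old, x))

    hf = sigmoid(X @ Wf + bf)   # forget gate
    hi = sigmoid(X @ Wi + bi)   # input gate
    ho = sigmoid(X @ Wo + bo)   # output gate
    hc = np.tanh(X @ Wc + bc)   # candidate cell state

    c = hf * c_old + hi * hc    # new cell state
    h = ho * np.tanh(c)         # new hidden state

    y = h @ Wy + by             # fully connected layer
    prob = softmax(y)           # distribution over the vocabulary

    cache = (X, hf, hi, ho, hc, c, c_old, h)  # everything the backward step needs
    return prob, (h, c), cache
```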
The above code is the forward step for a single LSTM cell, and it follows the formulas above exactly. The only additions are the one-hot encoding and the hidden-input concatenation.
LSTM Backward
Now, we will dive into the main point of this post: the LSTM backward computation. We will assume that the derivatives of `sigmoid` and `tanh` are already known.
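Here is a sketch of the backward step for a single cell, consistent with the forward sketch above. Again, the variable names and the cache layout are my own choices; `dsigmoid` and `dtanh` take the already-activated values as input:

```python
def dsigmoid(y):
    # derivative of sigmoid, given its output y = sigmoid(x)
    return y * (1. - y)

def dtanh(y):
    # derivative of tanh, given its output y = tanh(x)
    return 1. - y * y

def lstm_backward(prob, y_true, d_next, cache, model):
    X, hf, hi, ho, hc, c, c_old, h = cache
    dh_next, dc_next = d_next
    Wf, Wi, Wc, Wo, Wy = model['Wf'], model['Wi'], model['Wc'], model['Wo'], model['Wy']

    # Softmax + cross-entropy gradient w.r.t. the logits
    dy = prob.copy()
    dy[0, y_true] -= 1.

    # Fully connected layer gradient; note that we add dh_next to dh
    dWy = h.T @ dy
    dby = dy
    dh = dy @ Wy.T + dh_next

    # h = ho * tanh(c)
    dho = np.tanh(c) * dh
    dho = dsigmoid(ho) * dho
    dc = ho * dh * dtanh(np.tanh(c))
    dc = dc + dc_next               # note that we add dc_next to dc

    # c = hf * c_old + hi * hc
    dhf = dsigmoid(hf) * (c_old * dc)
    dhi = dsigmoid(hi) * (hc * dc)
    dhc = dtanh(hc) * (hi * dc)

    # Each gate is just a fully connected layer on X
    dWf, dbf, dXf = X.T @ dhf, dhf, dhf @ Wf.T
    dWi, dbi, dXi = X.T @ dhi, dhi, dhi @ Wi.T
    dWo, dbo, dXo = X.T @ dho, dho, dho @ Wo.T
    dWc, dbc, dXc = X.T @ dhc, dhc, dhc @ Wc.T

    # X was used by all four gates, so its gradient must be accumulated
    dX = dXo + dXc + dXi + dXf

    # X = [h_old, x], so the gradient of h_old is a split of dX
    dh_next = dX[:, :H]
    # c_old appears in c = hf * c_old + hi * hc
    dc_next = hf * dc

    grad = dict(Wf=dWf, Wi=dWi, Wc=dWc, Wo=dWo, Wy=dWy,
                bf=dbf, bi=dbi, bc=dbc, bo=dbo, by=dby)
    return grad, (dh_next, dc_next)
```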
A bit long, isn't it? However, it's actually easy enough to derive the LSTM gradients if you understand how to take partial derivatives and apply the chain rule, albeit with some tricky things going on. For background on this, I would recommend CS231n.
Things that are tricky and not-so-obvious when deriving the LSTM gradients are:

- Adding `dh_next` to `dh`, because `h` branches in the forward propagation: it is used in `y = h @ Wy + by` and, in the next time step, concatenated with `x`. Hence the gradient is split and has to be added back here.
- Adding `dc_next` to `dc`, for the same reason as above.
- Computing `dX = dXo + dXc + dXi + dXf`. Similar reason as above: `X` is used in many places, so the gradient is split and needs to be accumulated back.
- Getting `dh_next`, which is the gradient of `h_old`. As `X = [h_old, x]`, `dh_next` is just the reverse of the concatenation: a split operation on `dX`.
With the forward and backward computation implementations in hand, we can stitch them together to get a full training step that can be fed to an optimization algorithm.
LSTM Training Step
This training step consists of three steps: forward computation, loss calculation, and backward computation.
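A sketch of such a training step, built on the forward and backward sketches above, could be the following; `X_seq` and `y_seq` are index-encoded inputs and targets for one sequence:

```python
def train_step(X_seq, y_seq, state, model):
    probs, caches = [], []
    loss = 0.

    # Forward pass over the whole sequence, keeping the caches for backprop
    for x, y_true in zip(X_seq, y_seq):
        prob, state, cache = lstm_forward(x, state, model)
        loss += -np.log(prob[0, y_true])   # cross-entropy loss at this timestep
        probs.append(prob)
        caches.append(cache)

    loss /= len(X_seq)                      # average the loss over the timesteps

    # Backward pass, iterating over the timesteps in reverse.
    # The gradients flowing in from "beyond the last timestep" are zero.
    d_next = (np.zeros_like(state[0]), np.zeros_like(state[1]))
    grads = {k: np.zeros_like(v) for k, v in model.items()}

    for prob, y_true, cache in reversed(list(zip(probs, y_seq, caches))):
        grad, d_next = lstm_backward(prob, y_true, d_next, cache, model)
        for k in grads:
            grads[k] += grad[k]             # accumulate gradients over timesteps

    for k in grads:
        grads[k] /= len(X_seq)              # scale to match the averaged loss

    return grads, loss, state
```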
In the full training step, we first do a full forward propagation over the training sequence and store the results, i.e. the softmax probabilities and the cache of each timestep, in a list, because we are going to use them in the backward step.
Next, at each timestep, we calculate the cross-entropy loss (because we're using softmax). We then accumulate the losses over all timesteps and average them.
Lastly, we do backpropagation based on the forward-step results. Notice that while we iterate over the data in the forward direction in the forward step, we go in the reverse direction here.
Also notice that `dh_next` and `dc_next` for the first timestep of the backward step are zero. Why? Because at the last timestep of the forward propagation, `h` and `c` are not used in any subsequent timestep, as there are no more timesteps! So the gradients of `h` and `c` at the last timestep are not split and can be derived directly, without `dh_next` and `dc_next`.
With this function in hand, we can plug it into any optimization algorithm like RMSProp, Adam, etc., with one modification: we have to take the state of the network into account, i.e. the state returned by the current training step needs to be passed on to the next one.
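As a minimal illustration, here is a plain SGD loop that carries the state across training steps; swapping in RMSProp or Adam would only change the parameter update. The helper `get_minibatch`, which yields consecutive (input, target) chunks of the text, is hypothetical:

```python
def sgd(model, train_X, train_y, alpha=1e-2, n_iters=10000, seq_len=25):
    state = (np.zeros((1, H)), np.zeros((1, H)))   # initial h and c

    for it in range(n_iters):
        # Hypothetical helper: consecutive (input, target) chunks of the text
        X_seq, y_seq = get_minibatch(train_X, train_y, it, seq_len)

        # The state returned here is passed on to the next training step
        grads, loss, state = train_step(X_seq, y_seq, state, model)

        # Vanilla SGD update; replace with RMSProp/Adam if desired
        for k in model:
            model[k] -= alpha * grads[k]

        if it % 100 == 0:
            print('iter {}, loss {:.4f}'.format(it, loss))

    return model
```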
And, that’s it. We can train our LSTM net now!
Test Result
Using Adam to optimize the network, here's the result when I fed it a copy-pasted text about Japan from Wikipedia. Each data point is a character in the text; the target is the next character.
After every 100 iterations, the network is sampled.
It works like this (a sketch follows the list):
- Do forward propagation and get the softmax distribution
- Sample a character from that distribution
- Feed the sampled character as the input of the next time step
- Repeat
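A sketch of that sampling loop, reusing `lstm_forward` from above, could look like this (the seed character index and sample length are arbitrary):

```python
def sample(model, seed_idx, state, n_chars=200):
    sampled = [seed_idx]
    x = seed_idx

    for _ in range(n_chars):
        # Forward propagation to get the softmax distribution
        prob, state, _ = lstm_forward(x, state, model)
        # Sample a character index from that distribution
        x = np.random.choice(range(D), p=prob.ravel())
        # The sampled character becomes the input of the next time step
        sampled.append(x)

    return sampled
```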
And here’s the snippet of the results:
Our network definitely learned something here!
Conclusion
Here, we looked at the general formulas for the LSTM and implemented the forward propagation step based on them, which was very straightforward to do.
Then, we derived the backward computation step. This step was also straightforward, but there were some tricky parts that we had to ponder, especially the recurrence in `h` and `c`.
We then stitched the forward and backward steps together to build the full training step, which can be used with any optimization algorithm.
Lastly, we ran the network on some test data and showed that it was learning, judging by the loss value and the sample text produced by the network.