Introduction to RNN and LSTM
Updated: Jun 7, 2022
This article dives deep into the working principles of the Recurrent Neural Network (RNN) and Long Short-Term Memory (LSTM).
“Humans don’t start their thinking from scratch every second. As you read this article, you understand each word based on your understanding of previous words. You don’t throw everything away and start thinking from scratch again. Your thoughts have persistence.” — Source
We have already seen in Introduction to Artificial Neural Networks(ANN) how ANN can be used for regression and classification tasks, and in Introduction to Convolutional Neural Networks(CNN) how CNN can be used for image recognition, segmentation or object detection and computer-vision related tasks.
But what if we have sequential data?
The problem with convolutional neural networks(CNN or ConvNet)
You might be wondering by now: we already have networks like convolutional ones that perform very well, so why do we need another type of network? There is a very specific use case where RNNs are required. To explain RNNs, you first need to understand something called a sequence, so let’s talk about sequences first.
A sequence is a stream of data (finite or infinite) whose elements are interdependent. Examples include time-series data, meaningful strings of text, and conversations. In a conversation, a single sentence means something on its own, but the entire flow of the conversation often means something quite different. Likewise, in time-series data such as stock market data, a single tick gives only the current price, whereas a full day’s data shows the movement and lets us decide whether to buy or sell.
CNNs generally don’t perform well when the input data is interdependent in a sequential pattern. A CNN carries no information over from one input to the next, so every output depends only on its own input. A CNN takes an input and produces an output based on the trained model; if you run 100 different inputs, none of the outputs is biased by a previous output. But imagine a scenario like sentence generation or text translation. Each generated word depends on the words generated before it (in certain cases it depends on the words coming after as well, but we will discuss that later).
So you need some bias based on your previous output. This is where RNNs shine. RNNs carry a sense of memory about what happened earlier in the sequence, which helps the system gain context. They were among the first architectures able to remember previous inputs through an internal state when given a long stream of sequential data.
What are Recurrent Neural Networks(or RNN)?
A Recurrent Neural Network remembers the past, and its decisions are influenced by what it has learned from the past. Note: Basic feedforward networks “remember” things too, but they remember things they learned during training. For example, an image classifier learns what a “1” looks like during training and then uses that knowledge to classify things in production.
RNNs learn in the same way during training, but in addition they remember things learned from prior input(s) while generating output(s). This memory is part of the network. RNNs can take one or more input vectors and produce one or more output vectors, and the output(s) are influenced not just by the weights applied to the inputs, as in a regular NN, but also by a “hidden” state vector that represents the context based on prior input(s)/output(s). So the same input could produce a different output depending on the previous inputs in the series.
RNNs are called recurrent because they perform the same task for every element of a sequence, with the output depending on the previous computations. As you already know, they have a “memory” that captures information about what has been calculated so far.
For example, if you want to predict the next word in a sentence, you had better know which words came before it.
A typical RNN allows previous outputs to be used as inputs while having hidden states.
A Recurrent Neural Network with input X and output Y with multiple recurrent steps and a hidden unit.
This may look intimidating at first sight, but once unfolded into a full network it looks a lot simpler.
The below diagram shows an RNN being unrolled (or unfolded). By unrolling we simply mean that we write out the network for the complete sequence. For example, if the sequence we care about is a sentence of 5 words, the network would be unrolled into a 5-layer neural network, one layer for each word.
This chain-like nature reveals that recurrent neural networks are closely related to sequences and lists. They’re the natural neural network architecture to use for such data.
Types of RNN
RNN models are mostly used in the fields of natural language processing (NLP) and speech recognition. The main reason recurrent nets are more exciting is that they allow us to operate over sequences of vectors: sequences in the input, the output, or in the most general case, both. A few examples may make this more concrete:
One to one: This is also called a plain neural network. It maps a fixed-size input to a fixed-size output, both independent of any previous information or output.
Example: Image classification.
One to many: It takes a fixed-size input and gives a sequence of data as output.
Example: Image captioning, which takes an image as input and outputs a sentence of words.
Many to one: It takes a sequence of information as input and produces a fixed-size output.
Example: Sentiment analysis, where a sentence is classified as expressing positive or negative sentiment.
Many to many: It takes a sequence of information as input, processes it recurrently, and outputs a sequence of data.
Example: Machine translation, where the RNN reads a sentence in English and then outputs the sentence in French.
Many to many (synced): Synced sequence input and output.
Example: Video classification, where we wish to label every frame of the video.
Notice that in every case there are no pre-specified constraints on the lengths of the sequences, because the recurrent transformation (green) is fixed and can be applied as many times as we like.
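To make the difference between these input/output patterns a bit more tangible, here is a minimal NumPy sketch. It assumes a plain tanh RNN with randomly initialized weights; the names rnn_states, W_xh, W_hh and W_hy and all the sizes are chosen only for illustration. The same recurrent pass is reused, and only the way we read the outputs changes.

```python
import numpy as np

def rnn_states(xs, W_xh, W_hh):
    """Run a tanh RNN over a sequence and return the hidden state at every step."""
    h = np.zeros(W_hh.shape[0])          # start from an all-zero state (assumption)
    states = []
    for x in xs:                         # same weights are reused at every step
        h = np.tanh(W_xh @ x + W_hh @ h)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
input_size, hidden_size, output_size = 3, 4, 2   # illustrative sizes
W_xh = rng.normal(scale=0.5, size=(hidden_size, input_size))
W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
W_hy = rng.normal(scale=0.5, size=(output_size, hidden_size))

xs = rng.normal(size=(5, input_size))    # a sequence of 5 input vectors
states = rnn_states(xs, W_xh, W_hh)      # one hidden state per time step

# Many to one: read a single output from the last hidden state only
many_to_one = W_hy @ states[-1]

# Many to many (synced): read one output per time step
many_to_many = states @ W_hy.T

print(many_to_one.shape, many_to_many.shape)
```

Because the loop simply runs once per element, the same function works for a sequence of 3 inputs or 300 without any change to the weights.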
How Do RNNs Work?
Consider an unfolded RNN:
The formula for the current state can be written as:
h_t = f(h_{t-1}, x_t)
Here, h_t is the new state, h_{t-1} is the previous state, and x_t is the current input. We now have the state of the previous input instead of the input itself, because the recurrent neuron has already applied its transformations to the previous input. Each successive input is called a time step.
In this case, we have four inputs to be given to the network. In a recurrence formula, the same function and the same weights are applied to the network at each time step.
Taking the simplest form of a recurrent neural network, let’s say that the activation function is tanh, the weight at the recurrent neuron is W_hh, and the weight at the input neuron is W_xh. We can then write the equation for the state at time t as:
h_t = tanh(W_hh * h_{t-1} + W_xh * x_t)
The recurrent neuron, in this case, considers only the immediately previous state. For longer sequences, the equation can involve multiple such states. Once the final state is calculated, we can go on to produce the output.
Now, once the current state is calculated, we can calculate the output as:
y_t = W_hy * h_t
where W_hy is the weight at the output layer.
Let me summarize the steps in a recurrent neuron for you:
1) A single time step of the input, x_t, is supplied to the network.
2) We then calculate its current state h_t using a combination of the current input and the previous state.
3) The current h_t becomes h_{t-1} for the next time step.
4) We can go through as many time steps as the problem demands, combining the information from all the previous states.
5) Once all the time steps are completed, the final current state is used to calculate the output y_t.
6) The output is then compared to the actual (target) output, and the error is generated.
7) The error is then backpropagated through the network to update the weights (we shall go into the details of backpropagation in further sections), and the network is trained.
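The forward part of the steps above can be sketched in a few lines of NumPy. This is only an illustration under some assumptions: the weights are random, the initial state is all zeros, W_hy is a name assumed here for the output weights, and the target vector and squared-error loss are made up for the example.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh):
    """One time step: combine the current input with the previous state."""
    return np.tanh(W_hh @ h_prev + W_xh @ x_t)

def rnn_forward(xs, W_xh, W_hh, W_hy):
    """Feed the sequence one time step at a time, then read the output from the final state."""
    h = np.zeros(W_hh.shape[0])            # initial state (assumed to be zeros)
    for x_t in xs:                         # supply one time step of input at a time
        h = rnn_step(x_t, h, W_xh, W_hh)   # the new state becomes "previous" for the next step
    y = W_hy @ h                           # output computed from the final state
    return y, h

rng = np.random.default_rng(0)
input_size, hidden_size, output_size = 3, 4, 2    # illustrative sizes
W_xh = rng.normal(scale=0.5, size=(hidden_size, input_size))
W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
W_hy = rng.normal(scale=0.5, size=(output_size, hidden_size))

xs = rng.normal(size=(4, input_size))      # four inputs, one per time step
y, h_final = rnn_forward(xs, W_xh, W_hh, W_hy)

target = np.array([1.0, 0.0])              # hypothetical target output
error = 0.5 * np.sum((y - target) ** 2)    # squared error (an assumed choice of loss)
print(y, error)
```

Calling rnn_step with the same x_t but a different h_prev gives a different result, which is exactly the "memory" effect described earlier. The weight update itself is left out here, since it needs backpropagation.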
Backpropagation Through Time(BPTT)
Training an RNN is similar to training a traditional neural network. We also use the backpropagation algorithm, but with a little twist. Because the parameters are shared by all time steps in the network, the gradient at each output depends not only on the calculations of the current time step but also on those of the previous time steps.
For example, to calculate the gradient at t = 4, we would need to backpropagate 3 steps and sum up the gradients. This is called Backpropagation Through Time (BPTT). Vanilla RNNs trained with BPTT have difficulty learning long-term dependencies (i.e. dependencies between steps that are far apart) due to what is called the vanishing/exploding gradient problem. There exists some machinery to deal with these problems, and certain types of RNNs (like LSTMs) were specifically designed to get around them.
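Here is a small sketch of BPTT on a deliberately tiny model: a scalar RNN with a single recurrent weight w and a single input weight u, and a squared-error loss on the final state only. All of these choices, including the names and the numbers, are assumptions made to keep the example readable. The gradient with respect to w is accumulated by walking backwards through the time steps and summing each step's contribution, and is then compared against a numerical estimate.

```python
import math

def forward(w, u, xs):
    """Scalar RNN: h_t = tanh(w * h_{t-1} + u * x_t), starting from h_0 = 0."""
    hs = [0.0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def loss(w, u, xs, target):
    """Squared error on the final state only (an assumed, simple loss)."""
    return 0.5 * (forward(w, u, xs)[-1] - target) ** 2

def bptt_grad_w(w, u, xs, target):
    """Gradient of the loss w.r.t. w, summed over all time steps."""
    hs = forward(w, u, xs)
    dh = hs[-1] - target                 # dE/dh at the last time step
    grad_w = 0.0
    for t in range(len(xs), 0, -1):      # walk backwards through time
        da = dh * (1.0 - hs[t] ** 2)     # back through tanh at step t
        grad_w += da * hs[t - 1]         # this step's contribution (w is shared)
        dh = da * w                      # pass the error on to the previous state
    return grad_w

xs = [0.5, -0.1, 0.3, 0.8]               # four time steps
w, u, target = 0.7, 1.2, 0.2             # illustrative values

analytic = bptt_grad_w(w, u, xs, target)
eps = 1e-6
numeric = (loss(w + eps, u, xs, target) - loss(w - eps, u, xs, target)) / (2 * eps)
print(analytic, numeric)
```

The two printed values should agree closely, which is a handy sanity check whenever gradients are written by hand.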
The Problem of Long-Term Dependencies
Unfortunately, if you implement the above steps, you won’t be so delighted with the results. That is because the simplest RNN model has two major drawbacks, called the vanishing gradient and exploding gradient problems, which prevent it from being accurate.
If the sequence is long enough, an RNN will have a hard time carrying information from earlier time steps to later ones. So if you are trying to process a paragraph of text to make predictions, RNNs may leave out important information from the beginning.
For example, RNNs are fine when dealing with short-term dependencies, as in a short statement that the Nile is the longest river.
Such a sequence has nothing to do with the wider context of the statement. The RNN need not remember what was said before it or what it meant; all it needs to know is that, amongst all rivers, the Nile is the longest.
Now consider instead a sentence that mentions both a man and a pizza, with a description of purple hair appearing several words away from the man.
In this case, the description of purple hair applies to the man and not the pizza. So this is a long-term dependency.
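If we backpropagate the error in this case, we need to apply the chain rule. The contribution of the first time step to the gradient of the error at the third time step is:
∂E/∂W = ∂E/∂y_3 * ∂y_3/∂h_3 * ∂h_3/∂h_2 * ∂h_2/∂h_1 * ∂h_1/∂W
The chain gains one more factor for every additional time step in between, and that is what makes this a long dependency.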
Because these factors are multiplied together, if the individual gradients are small (less than 1 in magnitude), their product shrinks exponentially fast with the number of time steps. The gradients shrink until they vanish, which makes it impossible for the model to learn from the early time steps; such states no longer help the network learn anything. This is known as the vanishing gradient problem.
On the other hand, if the individual gradients are large (greater than 1 in magnitude), their product grows exponentially and eventually blows up and crashes the model. In this case, the RNN assigns unreasonably high importance to the weights without much reason. This is called the exploding gradient problem.
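A quick back-of-the-envelope sketch shows how fast both effects kick in. It assumes, purely for illustration, that every time step contributes the same gradient factor to the chain-rule product (real networks have a different factor at each step, but the trend is the same).

```python
def backprop_factor(per_step_gradient, num_steps):
    """Product of identical per-step gradient factors across num_steps time steps."""
    return per_step_gradient ** num_steps

vanishing = backprop_factor(0.5, 50)   # about 8.9e-16: the signal is effectively gone
exploding = backprop_factor(1.5, 50)   # about 6.4e8: the signal has blown up

print(vanishing, exploding)
```

Fifty time steps is not a long sequence for text or time-series data, yet the gradient reaching the first step is already either negligible or enormous.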
When gradients explode, they can become NaN because of numerical overflow, or we might see irregular oscillations in the training cost when we plot the learning curve. A fix for this is gradient clipping, which places a predefined threshold on the gradients to prevent them from getting too large. When the gradient is rescaled by its norm, clipping doesn’t change the direction of the gradient; it only changes its length.
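Gradient clipping by norm fits in a few lines. The sketch below uses plain Python lists, and the function name clip_by_norm and the threshold value are assumptions for the example. Deep learning libraries ship their own clipping utilities, which you would normally use instead.

```python
import math

def clip_by_norm(grad, threshold):
    """Rescale grad so its L2 norm is at most threshold, keeping its direction."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= threshold:
        return list(grad)                 # small gradients are left untouched
    scale = threshold / norm
    return [g * scale for g in grad]      # same direction, shorter length

exploded = [30.0, -40.0]                  # L2 norm is 50
clipped = clip_by_norm(exploded, threshold=5.0)
print(clipped)                            # approximately [3.0, -4.0]
```

The ratio between the components is unchanged after clipping, so the update still points the same way; it just takes a smaller step.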
The vanishing gradient problem is far more threatening than the exploding gradient problem, in which the gradients become very large because one or more gradient values become very high. The reason is that the exploding gradient problem can be easily solved by clipping the gradients at a predefined threshold value.
Fortunately, there are ways to handle the vanishing gradient problem as well, by using architectures like the LSTM (Long Short-Term Memory) and the GRU (Gated Recurrent Unit).
Long Short-Term Memory(LSTM)
LSTM is an improved version of the regular RNN, designed to make it easy to capture long-term dependencies in sequence data. In a regular RNN, the hidden state activation is influenced mostly by the local activations nearest to it, which corresponds to a “short-term memory”, while the network weights are influenced by the computations that take place over entire long sequences, which corresponds to a “long-term memory”. The RNN was therefore redesigned to have an activation state that can also act like weights and preserve information over long distances, hence the name “Long Short-Term Memory”.
LSTMs are explicitly designed to avoid the long-term dependency problem. Remembering information for long periods is practically their default behavior, not something they struggle to learn!
The Core Idea Behind LSTMs
The key to LSTMs is the cell state, the horizontal line running through the top of the diagram.
The cell state is kind of like a conveyor belt. It runs straight down the entire chain, with only some minor linear interactions. It’s very easy for information to just flow along with it unchanged.
The black line represents cell state
The LSTM does have the ability to remove or add information to the cell state, carefully regulated by structures called gates.
Gates are a way to optionally let information through. They are composed out of a sigmoid neural net layer and a pointwise multiplication operation.
The sigmoid layer outputs numbers between zero and one, describing how much of each component should be let through. A value of zero means “let nothing through,” while a value of one means “let everything through!”
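A gate is easy to try out by hand. In this minimal sketch the gate's pre-activation values (here called gate_logits, a name assumed for the example) are passed through a sigmoid and then multiplied pointwise with the values they guard.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def apply_gate(gate_logits, values):
    """Squash the logits to (0, 1) and multiply pointwise with the values."""
    gate = [sigmoid(z) for z in gate_logits]
    return [g * v for g, v in zip(gate, values)]

values = [2.0, -3.0, 0.5]
# A large positive logit opens the gate, a large negative one closes it,
# and a logit of 0 lets exactly half of the value through.
gated = apply_gate([10.0, -10.0, 0.0], values)
print(gated)
```

The first value passes almost unchanged, the second is almost completely blocked, and the third is halved.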
If you look at the architecture of LSTM you can observe that it has three gates, to protect and control the cell state:
1) Forget Gate.
2) Input Gate.
3) Output Gate.
Gates in an LSTM use the sigmoid activation function, i.e. they output a value between 0 and 1, and in most cases the value is close to either 0 or 1. We use the sigmoid function for gates because we want a gate to give only positive values and to give us a clear-cut answer on whether to keep a particular feature or discard it.
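Note: Notations used in the equations:
x_t: the input at time step t
h_{t-1}, h_t: the hidden state (output) of the previous and the current time step
C_{t-1}, C_t: the cell state of the previous and the current time step
f_t, i_t, o_t: the outputs of the forget, input, and output gates
C̃_t: the candidate values for the cell state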
The forget gate decides how much of the past information the LSTM should remember (that is, which information to omit at a particular time step). This decision is made by a sigmoid activation function. It looks at the previous hidden state h_{t-1} and the current input x_t, and outputs a number between 0 (omit this information) and 1 (keep this information) for each number in the cell state C_{t-1}.
h_{t-1} is the hidden state from the previous cell, or the output of the previous cell, and x_t is the input at that particular time step. The given inputs are multiplied by the weight matrices and a bias is added. Following this, the sigmoid function is applied to this value. The sigmoid function outputs a vector, with values ranging from 0 to 1, corresponding to each number in the cell state. The sigmoid function is responsible for deciding which values to keep and which to discard. If a ‘0’ is output for a particular value in the cell state, it means that the forget gate wants the cell state to forget that piece of information completely. Similarly, a ‘1’ means that the forget gate wants to remember that entire piece of information. This vector output from the sigmoid function is multiplied element-wise by the cell state.
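The forget gate described above can be sketched in NumPy as follows. The names W_f and b_f for the gate's weight matrix and bias are assumptions for this example, as is the choice to concatenate h_{t-1} and x_t into one vector before multiplying by a single weight matrix (a common way of writing it). The weights are random, so the numbers themselves carry no meaning.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forget_gate(h_prev, x_t, W_f, b_f):
    """f_t = sigmoid(W_f [h_{t-1}, x_t] + b_f): one value in (0, 1) per cell-state entry."""
    return sigmoid(W_f @ np.concatenate([h_prev, x_t]) + b_f)

rng = np.random.default_rng(0)
hidden_size, input_size = 3, 2                      # illustrative sizes
W_f = rng.normal(size=(hidden_size, hidden_size + input_size))
b_f = np.zeros(hidden_size)

h_prev = rng.normal(size=hidden_size)               # previous hidden state
x_t = rng.normal(size=input_size)                   # current input
C_prev = np.array([1.0, -2.0, 0.5])                 # previous cell state

f_t = forget_gate(h_prev, x_t, W_f, b_f)
C_after_forget = f_t * C_prev                       # element-wise: how much of each entry survives
print(f_t, C_after_forget)
```

Every entry of the cell state is scaled down by its own gate value, so the network can forget one piece of information while keeping another.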
The input gate decides how much new information is added to the current cell state.
The input gate is responsible for the addition of information to the cell state. This addition of information is a three-step process, as seen from the diagram above:
1) Regulating what values need to be added to the cell state by using a sigmoid function. This is very similar to the forget gate and acts as a filter for all the information from h_{t-1} and x_t.
2) Creating a vector containing all possible values that can be added (as perceived from h_{t-1} and x_t) to the cell state. This is done using the tanh function, which outputs values from -1 to +1.
3) Multiplying the value of the regulatory filter (the sigmoid gate) by the created vector (the tanh output) and then adding this useful information to the cell state via an addition operation.
It’s now time to update the old cell state, C_{t-1}, into the new cell state C_t. The previous steps already decided what to do; we just need to actually do it.
We multiply the old state by f_t, forgetting the things we decided to forget earlier. Then we add i_t * C̃_t. These are the new candidate values, scaled by how much we decided to update each state value. Put together, the update is:
C_t = f_t * C_{t-1} + i_t * C̃_t
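A tiny worked example makes the cell state update concrete. The gate and candidate values below are hand-picked for illustration (in a real LSTM they come out of the sigmoid and tanh layers), so that each entry shows a different behavior.

```python
import numpy as np

def update_cell_state(C_prev, f_t, i_t, C_tilde):
    """New cell state: keep f_t of the old state, add i_t of the candidate values."""
    return f_t * C_prev + i_t * C_tilde

C_prev  = np.array([2.0, -1.0, 0.5])   # old cell state
f_t     = np.array([1.0,  0.0, 0.5])   # keep all / forget all / keep half
i_t     = np.array([0.0,  1.0, 0.5])   # add nothing / add all / add half
C_tilde = np.array([0.9, -0.4, 0.2])   # candidate values (tanh output, between -1 and 1)

C_t = update_cell_state(C_prev, f_t, i_t, C_tilde)
print(C_t)   # approximately [2.0, -0.4, 0.35]
```

The first entry is carried over untouched, the second is replaced entirely by its candidate, and the third is a blend of old and new.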
The output gate decides which part of the current cell state makes it to the output.
The functioning of an output gate can again be broken down into three steps:
1) Creating a vector by applying the tanh function to the cell state, thereby scaling the values to the range -1 to +1.
2) Making a filter using the values of h_{t-1} and x_t, such that it can regulate the values that need to be output from the vector created above. This filter again employs a sigmoid function.
3) Multiplying the value of this regulatory filter by the vector created in step 1, and sending the result out as the output and also to the hidden state of the next cell.
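Putting the three gates together, a single LSTM time step might look like the NumPy sketch below. It is a minimal illustration and not a reference implementation: the parameter names (W_f, W_i, W_c, W_o and the matching biases), the small random initialization, and the use of one concatenated vector [h_{t-1}, x_t] per gate are all assumptions made for the example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, params):
    """One LSTM time step; returns the new hidden state and the new cell state."""
    z = np.concatenate([h_prev, x_t])                        # previous output + current input
    f_t = sigmoid(params["W_f"] @ z + params["b_f"])         # forget gate
    i_t = sigmoid(params["W_i"] @ z + params["b_i"])         # input gate
    C_tilde = np.tanh(params["W_c"] @ z + params["b_c"])     # candidate values
    C_t = f_t * C_prev + i_t * C_tilde                       # cell state update
    o_t = sigmoid(params["W_o"] @ z + params["b_o"])         # output gate
    h_t = o_t * np.tanh(C_t)                                 # filtered view of the cell state
    return h_t, C_t

def init_params(input_size, hidden_size, rng):
    """Small random weights and zero biases for the four layers (illustrative only)."""
    params = {}
    for name in ("f", "i", "c", "o"):
        params["W_" + name] = rng.normal(scale=0.1, size=(hidden_size, hidden_size + input_size))
        params["b_" + name] = np.zeros(hidden_size)
    return params

rng = np.random.default_rng(0)
input_size, hidden_size = 2, 3
params = init_params(input_size, hidden_size, rng)

h, C = np.zeros(hidden_size), np.zeros(hidden_size)          # initial states (assumed zeros)
for x_t in rng.normal(size=(5, input_size)):                 # a sequence of 5 inputs
    h, C = lstm_step(x_t, h, C, params)
print(h, C)
```

Note that the cell state C is only ever touched by an element-wise multiplication and an addition, which is the "conveyor belt" behavior that lets information travel across many time steps.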
LSTMs are a very promising solution to sequence and time-series related problems. However, the one disadvantage I find with them is the difficulty of training them. A lot of time and system resources go into training even a simple model. In future articles, we shall see how we can actually implement RNNs/LSTMs for NLP and time-series forecasting problems.
That’s all for this article. I hope you have enjoyed reading it, and I’ll be glad if the article is of any help. Feel free to share your comments, thoughts, or feedback in the comment section.