Long short-term memory (LSTM) is a type of RNN specifically designed to process sequential data. Compared with standard RNNs, LSTM networks can maintain useful long-term dependencies for making predictions at the current and future time steps.
The complete code for this chapter can be found in .
LSTM
Standard RNNs suffer from the problem of vanishing gradients. Therefore, if the sequence is very long, a standard RNN cannot effectively learn from the early input data. That is to say, the long-term memory capacity of a standard RNN is quite weak. LSTM was developed to solve this long-term memory problem.
For more details about vanishing gradients, please refer to the following article. In addition, if you are not familiar with RNN, please refer to the following article first.
The figure below is an LSTM cell. It is much more complicated than the standard RNN. In addition to the hidden state $a^{\langle t \rangle}$, LSTM also has a memory cell state $c^{\langle t \rangle}$. In addition, while the activation function in the standard RNN is only a $\tanh$, the LSTM also needs to calculate the forget gate $f^{\langle t \rangle}$, input gate $i^{\langle t \rangle}$, candidate cell state $\tilde{c}^{\langle t \rangle}$, and output gate $o^{\langle t \rangle}$.
The following are the meanings of these gates and states.
- Hidden state $a^{\langle t \rangle}$: This is the state of the current time step and is used to make predictions.
- Memory cell state $c^{\langle t \rangle}$: This is the long-term memory of the LSTM.
- Forget gate $f^{\langle t \rangle}$: This gate determines which information from the previous cell state should be discarded.
- Input gate $i^{\langle t \rangle}$: This gate determines which new information from the current input should be added to the current cell state.
- Candidate cell state $\tilde{c}^{\langle t \rangle}$: By processing the current input, new information is generated that may be useful to the network's memory.
- Output gate $o^{\langle t \rangle}$: This gate determines the next hidden state.
Forward Propagation
The following figure shows the LSTM forward propagation. Each gate and the candidate cell state has its corresponding parameters and activation function. It is worth noting that we stack $a^{\langle t-1 \rangle}$ and $x^{\langle t \rangle}$ vertically.
The formulas in the LSTM cell are as follows, where $\left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right]$ denotes the vertical stacking mentioned above:

$$
\begin{aligned}
f^{\langle t \rangle} &= \sigma\!\left(W_f \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right] + b_f\right) \\
i^{\langle t \rangle} &= \sigma\!\left(W_i \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right] + b_i\right) \\
\tilde{c}^{\langle t \rangle} &= \tanh\!\left(W_c \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right] + b_c\right) \\
o^{\langle t \rangle} &= \sigma\!\left(W_o \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right] + b_o\right) \\
c^{\langle t \rangle} &= f^{\langle t \rangle} \odot c^{\langle t-1 \rangle} + i^{\langle t \rangle} \odot \tilde{c}^{\langle t \rangle} \\
a^{\langle t \rangle} &= o^{\langle t \rangle} \odot \tanh\!\left(c^{\langle t \rangle}\right) \\
\hat{y}^{\langle t \rangle} &= \mathrm{softmax}\!\left(W_y a^{\langle t \rangle} + b_y\right)
\end{aligned}
$$
The dimensions of the LSTM input $X$ and true labels $Y$ are as follows, where $m$ is the batch size and $T_x$ is the sequence length:

$$
X \in \mathbb{R}^{n_x \times m \times T_x}, \qquad Y \in \mathbb{R}^{n_y \times m \times T_x}
$$
In the LSTM cell, the dimensions of each variable are as follows:

$$
\begin{aligned}
&x^{\langle t \rangle} \in \mathbb{R}^{n_x \times m}, \qquad a^{\langle t \rangle}, c^{\langle t \rangle}, f^{\langle t \rangle}, i^{\langle t \rangle}, \tilde{c}^{\langle t \rangle}, o^{\langle t \rangle} \in \mathbb{R}^{n_a \times m}, \qquad \hat{y}^{\langle t \rangle} \in \mathbb{R}^{n_y \times m} \\
&W_f, W_i, W_c, W_o \in \mathbb{R}^{n_a \times (n_a + n_x)}, \qquad b_f, b_i, b_c, b_o \in \mathbb{R}^{n_a \times 1} \\
&W_y \in \mathbb{R}^{n_y \times n_a}, \qquad b_y \in \mathbb{R}^{n_y \times 1}
\end{aligned}
$$
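The implementations below assume NumPy and element-wise sigmoid, tanh, and softmax helpers. The following definitions are a minimal sketch added here for self-containedness; they are not shown in the original excerpt:

```python
import numpy as np

def sigmoid(x):
    # Element-wise logistic function.
    return 1 / (1 + np.exp(-x))

def tanh(x):
    # Element-wise hyperbolic tangent.
    return np.tanh(x)

def softmax(x):
    # Column-wise softmax, shifted by the column max for numerical stability.
    e = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e / np.sum(e, axis=0, keepdims=True)
```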
The following is the implementation of LSTM forward propagation.
```python
class LSTM:
    def cell_forward(self, xt, at_prev, ct_prev, parameters):
        """
        Implement a single forward step of the LSTM-cell.

        Parameters
        ----------
        xt: (ndarray (n_x, m)) - input data at time step "t"
        at_prev: (ndarray (n_a, m)) - hidden state at time step "t-1"
        ct_prev: (ndarray (n_a, m)) - memory state at time step "t-1"
        parameters: (dict) - dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the input gate
            bi: (ndarray (n_a, 1)) - bias of the input gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden-state to the output

        Returns
        -------
        at: (ndarray (n_a, m)) - hidden state at time step "t"
        ct: (ndarray (n_a, m)) - memory state at time step "t"
        y_hat_t: (ndarray (n_y, m)) - prediction at time step "t"
        cache: (tuple) - values needed for the backward pass, contains
            (at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt)
        """

        Wf, bf = parameters["Wf"], parameters["bf"]  # forget gate weights and biases
        Wi, bi = parameters["Wi"], parameters["bi"]  # input gate weights and biases
        Wc, bc = parameters["Wc"], parameters["bc"]  # candidate value weights and biases
        Wo, bo = parameters["Wo"], parameters["bo"]  # output gate weights and biases
        Wy, by = parameters["Wy"], parameters["by"]  # prediction weights and biases

        # Stack the previous hidden state and the current input vertically.
        concat = np.concatenate((at_prev, xt), axis=0)

        ft = sigmoid(Wf @ concat + bf)  # forget gate
        it = sigmoid(Wi @ concat + bi)  # input gate
        cct = tanh(Wc @ concat + bc)    # candidate value
        ct = ft * ct_prev + it * cct    # memory state
        ot = sigmoid(Wo @ concat + bo)  # output gate
        at = ot * tanh(ct)              # hidden state

        zyt = Wy @ at + by
        y_hat_t = softmax(zyt)

        cache = (at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt)
        return at, ct, y_hat_t, cache

    def forward(self, X, a0, parameters):
        """
        Implement the forward pass of the LSTM network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data for each time step
        a0: (ndarray (n_a, m)) - initial hidden state
        parameters: (dict) - dictionary containing Wf, bf, Wi, bi, Wc, bc, Wo, bo, Wy, by
            (see cell_forward)

        Returns
        -------
        A: (ndarray (n_a, m, T_x)) - hidden states for each time step
        C: (ndarray (n_a, m, T_x)) - memory states for each time step
        Y_hat: (ndarray (n_y, m, T_x)) - predictions for each time step
        caches: (list) - values needed for the backward pass
        """

        caches = []

        Wy = parameters["Wy"]
        n_x, m, T_x = X.shape
        n_y, n_a = Wy.shape

        A = np.zeros((n_a, m, T_x))
        C = np.zeros((n_a, m, T_x))
        Y_hat = np.zeros((n_y, m, T_x))

        at_prev = a0
        ct_prev = np.zeros((n_a, m))
        for t in range(T_x):
            at_prev, ct_prev, y_hat_t, cache = self.cell_forward(X[:, :, t], at_prev, ct_prev, parameters)
            A[:, :, t] = at_prev
            C[:, :, t] = ct_prev
            Y_hat[:, :, t] = y_hat_t
            caches.append(cache)

        return A, C, Y_hat, caches
```
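To make the dimensions concrete, the following sketch initializes a set of small random parameters and runs the forward pass. The init_parameters helper is a hypothetical initializer added for illustration only; it is not part of the chapter's code. Since the class as shown above defines no constructor, LSTM() is instantiated without arguments:

```python
np.random.seed(0)
n_x, n_a, n_y, m, T_x = 3, 5, 4, 2, 6

def init_parameters(n_a, n_x, n_y):
    # Hypothetical initializer: small random weights, zero biases.
    p = {}
    for gate in ("f", "i", "c", "o"):
        p["W" + gate] = np.random.randn(n_a, n_a + n_x) * 0.01
        p["b" + gate] = np.zeros((n_a, 1))
    p["Wy"] = np.random.randn(n_y, n_a) * 0.01
    p["by"] = np.zeros((n_y, 1))
    return p

parameters = init_parameters(n_a, n_x, n_y)
X = np.random.randn(n_x, m, T_x)
a0 = np.zeros((n_a, m))

A, C, Y_hat, caches = LSTM().forward(X, a0, parameters)
print(A.shape, C.shape, Y_hat.shape)  # (5, 2, 6) (5, 2, 6) (4, 2, 6)
```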
Loss Function
In this article, we use softmax as the output activation, so we use cross-entropy loss as the loss function. For the formula and implementation of cross-entropy loss, please refer to the following article.
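For reference, here is a minimal sketch of the compute_loss method used later in optimize. The loss is summed (rather than averaged over the batch) to match the backward pass below, which uses $dz_y^{\langle t \rangle} = \hat{y}^{\langle t \rangle} - y^{\langle t \rangle}$ without a $1/m$ factor; the small epsilon guard against $\log(0)$ is an assumption of this sketch:

```python
class LSTM:
    def compute_loss(self, Y_hat, Y):
        # Cross-entropy summed over all time steps and batch examples.
        return -np.sum(Y * np.log(Y_hat + 1e-12))
```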
Backward Propagation
The backpropagation of LSTM is a bit complicated. We have to calculate the partial derivatives of the loss with respect to each parameter. In particular, when calculating $da^{\langle t \rangle}$ and $dc^{\langle t \rangle}$ (the gradients of the loss with respect to $a^{\langle t \rangle}$ and $c^{\langle t \rangle}$), we also need to take into account the gradients flowing back from the following time step; this is backpropagation through time (BPTT).
The following calculates the partial derivatives in the output layer, where $da^{\langle t \rangle}_{\text{next}}$ denotes the hidden state gradient passed back from time step $t+1$, and the bias gradient sums over the $m$ examples in the batch:

$$
\begin{aligned}
dz_y^{\langle t \rangle} &= \hat{y}^{\langle t \rangle} - y^{\langle t \rangle} \\
dW_y^{\langle t \rangle} &= dz_y^{\langle t \rangle} \left(a^{\langle t \rangle}\right)^\top \\
db_y^{\langle t \rangle} &= \sum_{i=1}^{m} dz_y^{\langle t \rangle (i)} \\
da^{\langle t \rangle} &= W_y^\top dz_y^{\langle t \rangle} + da^{\langle t \rangle}_{\text{next}}
\end{aligned}
$$
The following calculates the partial derivatives of the forget gate, input gate, candidate cell state, and output gate. These are gradients with respect to the pre-activation values, where $dc^{\langle t \rangle}_{\text{next}}$ is the memory state gradient passed back from time step $t+1$:

$$
\begin{aligned}
do^{\langle t \rangle} &= da^{\langle t \rangle} \odot \tanh\!\left(c^{\langle t \rangle}\right) \odot o^{\langle t \rangle} \odot \left(1 - o^{\langle t \rangle}\right) \\
d\tilde{c}^{\langle t \rangle} &= \left(dc^{\langle t \rangle}_{\text{next}} + da^{\langle t \rangle} \odot o^{\langle t \rangle} \odot \left(1 - \tanh^2\!\left(c^{\langle t \rangle}\right)\right)\right) \odot i^{\langle t \rangle} \odot \left(1 - \left(\tilde{c}^{\langle t \rangle}\right)^2\right) \\
di^{\langle t \rangle} &= \left(dc^{\langle t \rangle}_{\text{next}} + da^{\langle t \rangle} \odot o^{\langle t \rangle} \odot \left(1 - \tanh^2\!\left(c^{\langle t \rangle}\right)\right)\right) \odot \tilde{c}^{\langle t \rangle} \odot i^{\langle t \rangle} \odot \left(1 - i^{\langle t \rangle}\right) \\
df^{\langle t \rangle} &= \left(dc^{\langle t \rangle}_{\text{next}} + da^{\langle t \rangle} \odot o^{\langle t \rangle} \odot \left(1 - \tanh^2\!\left(c^{\langle t \rangle}\right)\right)\right) \odot c^{\langle t-1 \rangle} \odot f^{\langle t \rangle} \odot \left(1 - f^{\langle t \rangle}\right)
\end{aligned}
$$
The following calculates the partial derivatives of all parameters:

$$
\begin{aligned}
dW_o^{\langle t \rangle} &= do^{\langle t \rangle} \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right]^\top, \qquad db_o^{\langle t \rangle} = \sum_{i=1}^{m} do^{\langle t \rangle (i)} \\
dW_c^{\langle t \rangle} &= d\tilde{c}^{\langle t \rangle} \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right]^\top, \qquad db_c^{\langle t \rangle} = \sum_{i=1}^{m} d\tilde{c}^{\langle t \rangle (i)} \\
dW_i^{\langle t \rangle} &= di^{\langle t \rangle} \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right]^\top, \qquad db_i^{\langle t \rangle} = \sum_{i=1}^{m} di^{\langle t \rangle (i)} \\
dW_f^{\langle t \rangle} &= df^{\langle t \rangle} \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right]^\top, \qquad db_f^{\langle t \rangle} = \sum_{i=1}^{m} df^{\langle t \rangle (i)}
\end{aligned}
$$
The following calculates the remaining partial derivatives, where $W_{*,a}$ and $W_{*,x}$ denote the columns of $W_*$ that multiply $a^{\langle t-1 \rangle}$ and $x^{\langle t \rangle}$ respectively:

$$
\begin{aligned}
da^{\langle t-1 \rangle} &= W_{f,a}^\top df^{\langle t \rangle} + W_{i,a}^\top di^{\langle t \rangle} + W_{c,a}^\top d\tilde{c}^{\langle t \rangle} + W_{o,a}^\top do^{\langle t \rangle} \\
dc^{\langle t-1 \rangle} &= dc^{\langle t \rangle}_{\text{next}} \odot f^{\langle t \rangle} + da^{\langle t \rangle} \odot o^{\langle t \rangle} \odot \left(1 - \tanh^2\!\left(c^{\langle t \rangle}\right)\right) \odot f^{\langle t \rangle} \\
dx^{\langle t \rangle} &= W_{f,x}^\top df^{\langle t \rangle} + W_{i,x}^\top di^{\langle t \rangle} + W_{c,x}^\top d\tilde{c}^{\langle t \rangle} + W_{o,x}^\top do^{\langle t \rangle}
\end{aligned}
$$
The above shows how to calculate all partial derivatives at a single time step. Finally, we sum the parameter gradients over all time steps:

$$
dW_* = \sum_{t=1}^{T_x} dW_*^{\langle t \rangle}, \qquad db_* = \sum_{t=1}^{T_x} db_*^{\langle t \rangle}
$$
The following is the implementation of LSTM backward propagation.
```python
class LSTM:
    def cell_backward(self, y, dat, dct, cache, parameters):
        """
        Implement the backward pass for the LSTM-cell.

        Parameters
        ----------
        y: (ndarray (n_y, m)) - true labels for time step "t"
        dat: (ndarray (n_a, m)) - hidden state gradient for time step "t"
        dct: (ndarray (n_a, m)) - memory state gradient for time step "t"
        cache: (tuple) - values from the forward pass at time step "t"
        parameters: (dict) - dictionary containing Wf, bf, Wi, bi, Wc, bc, Wo, bo, Wy, by
            (see cell_forward)

        Returns
        -------
        gradients: (dict) - dictionary containing the following gradients:
            dWf: (ndarray (n_a, n_a + n_x)) - gradient of the forget gate weight
            dbf: (ndarray (n_a, 1)) - gradient of the forget gate bias
            dWi: (ndarray (n_a, n_a + n_x)) - gradient of the input gate weight
            dbi: (ndarray (n_a, 1)) - gradient of the input gate bias
            dWc: (ndarray (n_a, n_a + n_x)) - gradient of the candidate value weight
            dbc: (ndarray (n_a, 1)) - gradient of the candidate value bias
            dWo: (ndarray (n_a, n_a + n_x)) - gradient of the output gate weight
            dbo: (ndarray (n_a, 1)) - gradient of the output gate bias
            dWy: (ndarray (n_y, n_a)) - gradient of the prediction weight
            dby: (ndarray (n_y, 1)) - gradient of the prediction bias
            dat_prev: (ndarray (n_a, m)) - gradient of the hidden state for time step "t-1"
            dct_prev: (ndarray (n_a, m)) - gradient of the memory state for time step "t-1"
            dxt: (ndarray (n_x, m)) - gradient of the input data for time step "t"
        """

        at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt = cache
        n_a, m = at.shape

        # Output layer gradients.
        dzy = y_hat_t - y
        dWy = dzy @ at.T
        dby = np.sum(dzy, axis=1, keepdims=True)
        dat = parameters["Wy"].T @ dzy + dat  # add the gradient flowing back from time step "t+1"

        # Pre-activation gradients of the gates and the candidate value.
        dot = dat * tanh(ct) * ot * (1 - ot)
        dcct = (dct + dat * ot * (1 - tanh(ct) ** 2)) * it * (1 - cct ** 2)
        dit = (dct + dat * ot * (1 - tanh(ct) ** 2)) * cct * it * (1 - it)
        dft = (dct + dat * ot * (1 - tanh(ct) ** 2)) * ct_prev * ft * (1 - ft)

        # Parameter gradients.
        concat = np.concatenate((at_prev, xt), axis=0)
        dWo = dot @ concat.T
        dbo = np.sum(dot, axis=1, keepdims=True)
        dWc = dcct @ concat.T
        dbc = np.sum(dcct, axis=1, keepdims=True)
        dWi = dit @ concat.T
        dbi = np.sum(dit, axis=1, keepdims=True)
        dWf = dft @ concat.T
        dbf = np.sum(dft, axis=1, keepdims=True)

        # Gradients for the previous hidden state, previous memory state, and input.
        dat_prev = (
            parameters["Wo"][:, :n_a].T @ dot
            + parameters["Wc"][:, :n_a].T @ dcct
            + parameters["Wi"][:, :n_a].T @ dit
            + parameters["Wf"][:, :n_a].T @ dft
        )
        dct_prev = dct * ft + ot * (1 - tanh(ct) ** 2) * ft * dat
        dxt = (
            parameters["Wo"][:, n_a:].T @ dot
            + parameters["Wc"][:, n_a:].T @ dcct
            + parameters["Wi"][:, n_a:].T @ dit
            + parameters["Wf"][:, n_a:].T @ dft
        )

        gradients = {
            "dWo": dWo, "dbo": dbo,
            "dWc": dWc, "dbc": dbc,
            "dWi": dWi, "dbi": dbi,
            "dWf": dWf, "dbf": dbf,
            "dWy": dWy, "dby": dby,
            "dct_prev": dct_prev, "dat_prev": dat_prev, "dxt": dxt,
        }
        return gradients

    def backward(self, X, Y, parameters, caches):
        """
        Implement the backward pass for the LSTM network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data for each time step
        Y: (ndarray (n_y, m, T_x)) - true labels for each time step
        parameters: (dict) - dictionary containing Wf, bf, Wi, bi, Wc, bc, Wo, bo, Wy, by
            (see cell_forward)
        caches: (list) - values needed for the backward pass

        Returns
        -------
        gradients: (dict) - dictionary containing dWf, dbf, dWi, dbi, dWc, dbc, dWo, dbo,
            dWy, dby (see cell_backward), summed over all time steps
        """

        n_x, m, T_x = X.shape
        a1, c1, a0, c0, f1, i1, cc1, o1, x1, y_hat_1, zy1 = caches[0]

        Wf, Wi, Wc, Wo, Wy = parameters["Wf"], parameters["Wi"], parameters["Wc"], parameters["Wo"], parameters["Wy"]
        bf, bi, bc, bo, by = parameters["bf"], parameters["bi"], parameters["bc"], parameters["bo"], parameters["by"]

        gradients = {
            "dWf": np.zeros_like(Wf), "dWi": np.zeros_like(Wi),
            "dWc": np.zeros_like(Wc), "dWo": np.zeros_like(Wo),
            "dbf": np.zeros_like(bf), "dbi": np.zeros_like(bi),
            "dbc": np.zeros_like(bc), "dbo": np.zeros_like(bo),
            "dWy": np.zeros_like(Wy), "dby": np.zeros_like(by),
        }

        dat = np.zeros_like(a0)
        dct = np.zeros_like(c0)
        for t in reversed(range(T_x)):
            grads = self.cell_backward(Y[:, :, t], dat, dct, caches[t], parameters)
            # Accumulate the parameter gradients over all time steps.
            gradients["dWf"] += grads["dWf"]
            gradients["dWi"] += grads["dWi"]
            gradients["dWc"] += grads["dWc"]
            gradients["dWo"] += grads["dWo"]
            gradients["dbf"] += grads["dbf"]
            gradients["dbi"] += grads["dbi"]
            gradients["dbc"] += grads["dbc"]
            gradients["dbo"] += grads["dbo"]
            gradients["dWy"] += grads["dWy"]
            gradients["dby"] += grads["dby"]
            # Pass the state gradients back to time step "t-1".
            dat = grads["dat_prev"]
            dct = grads["dct_prev"]

        return gradients
```
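A quick way to sanity-check these derivations is a finite-difference comparison. The sketch below reuses the hypothetical init_parameters helper from the forward-pass example and the compute_loss sketch above, and assumes all methods are collected into a single LSTM class. It perturbs one entry of Wy and compares the numerical gradient with dWy:

```python
np.random.seed(1)
n_x, n_a, n_y, m, T_x = 3, 5, 4, 2, 6
parameters = init_parameters(n_a, n_x, n_y)
X = np.random.randn(n_x, m, T_x)
Y = np.zeros((n_y, m, T_x))
Y[0, :, :] = 1.0  # dummy one-hot labels
a0 = np.zeros((n_a, m))

lstm = LSTM()
A, C, Y_hat, caches = lstm.forward(X, a0, parameters)
gradients = lstm.backward(X, Y, parameters, caches)

# Central finite difference on a single entry of Wy.
eps = 1e-5
parameters["Wy"][0, 0] += eps
_, _, Y_hat_plus, _ = lstm.forward(X, a0, parameters)
parameters["Wy"][0, 0] -= 2 * eps
_, _, Y_hat_minus, _ = lstm.forward(X, a0, parameters)
parameters["Wy"][0, 0] += eps  # restore the original value

numeric = (lstm.compute_loss(Y_hat_plus, Y) - lstm.compute_loss(Y_hat_minus, Y)) / (2 * eps)
print(numeric, gradients["dWy"][0, 0])  # the two values should closely agree
```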
Putting All Together
The following code implements a complete training step. First, we pass the training data through forward propagation, compute the loss, and then run backward propagation to obtain the gradients. To prevent exploding gradients, we clip the gradients and then use them to update the parameters. This completes one training iteration.
```python
class LSTM:
    def optimize(self, X, Y, a_prev, parameters, learning_rate, clip_value):
        """
        Implements the forward and backward propagation of the LSTM.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data for each time step
        Y: (ndarray (n_y, m, T_x)) - true labels for each time step
        a_prev: (ndarray (n_a, m)) - initial hidden state
        parameters: (dict) - dictionary containing Wf, bf, Wi, bi, Wc, bc, Wo, bo, Wy, by
            (see cell_forward)
        learning_rate: (float) - learning rate for the optimization algorithm
        clip_value: (float) - maximum value for the gradients

        Returns
        -------
        at: (ndarray (n_a, m)) - hidden state for the last time step
        loss: (float) - the cross-entropy
        """

        A, C, Y_hat, caches = self.forward(X, a_prev, parameters)
        loss = self.compute_loss(Y_hat, Y)
        gradients = self.backward(X, Y, parameters, caches)
        gradients = self.clip(gradients, clip_value)
        self.update_parameters(parameters, gradients, learning_rate)
        at = A[:, :, -1]
        return at, loss
```
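The clip and update_parameters methods called in optimize belong to the chapter's full code. The following is a minimal sketch of what they might look like, assuming element-wise clipping and plain gradient descent:

```python
class LSTM:
    def clip(self, gradients, clip_value):
        # Clip every parameter gradient element-wise into [-clip_value, clip_value].
        for key in ("dWf", "dbf", "dWi", "dbi", "dWc", "dbc", "dWo", "dbo", "dWy", "dby"):
            np.clip(gradients[key], -clip_value, clip_value, out=gradients[key])
        return gradients

    def update_parameters(self, parameters, gradients, learning_rate):
        # Plain gradient-descent step on every parameter.
        for name in ("Wf", "bf", "Wi", "bi", "Wc", "bc", "Wo", "bo", "Wy", "by"):
            parameters[name] -= learning_rate * gradients["d" + name]
```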
Example
Next, we design the LSTM as a character-level language model. The training material is a passage from Shakespeare. The model is trained one character at a time, so the sequence length equals the number of characters in the input, and each character is one-hot encoded. Please refer to the following article for details of this part, because this article and the following article use the same example.
An example of using this LSTM is as follows:
```python
if __name__ == "__main__":
    with open("shakespeare.txt", "r") as file:
        text = file.read()

    chars = sorted(list(set(text)))
    vocab_size = len(chars)
    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}

    lstm = LSTM(64, vocab_size, vocab_size)
    losses = lstm.train(text, char_to_idx, num_iterations=100, learning_rate=0.01, clip_value=5)

    generated_text = lstm.sample("T", char_to_idx, idx_to_char, num_chars=100)
    print(generated_text)
```
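The train and sample methods are part of the chapter's full code. As a rough illustration of what sampling does, the following standalone sketch (a hypothetical sample_chars helper, not the chapter's actual method) generates text by running cell_forward one step at a time and feeding each sampled character back in as the next input:

```python
def sample_chars(lstm, parameters, start_char, char_to_idx, idx_to_char, n_a, num_chars=100):
    # Generate text one character at a time from a trained LSTM.
    vocab_size = len(char_to_idx)
    at = np.zeros((n_a, 1))
    ct = np.zeros((n_a, 1))
    idx = char_to_idx[start_char]
    chars = [start_char]
    for _ in range(num_chars):
        xt = np.zeros((vocab_size, 1))
        xt[idx, 0] = 1.0  # one-hot encode the previous character
        at, ct, y_hat_t, _ = lstm.cell_forward(xt, at, ct, parameters)
        idx = np.random.choice(vocab_size, p=y_hat_t[:, 0])  # sample from the softmax distribution
        chars.append(idx_to_char[idx])
    return "".join(chars)

# e.g. sample_chars(lstm, parameters, "T", char_to_idx, idx_to_char, n_a=64)
```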
Conclusion
Standard RNNs face the problem of vanishing gradients when processing long sequences. LSTM was developed to address this problem. It has been used in speech recognition and machine translation, with quite good results.
References
- S. Hochreiter and J. Schmidhuber. 1997. Long Short-term Memory. Neural Computation, 9(8): 1735-1780.
- Andrew Ng, Deep Learning Specialization, Coursera.