Long Short-Term Memory (LSTM)

Photo by Sinziana Susa on Unsplash
Long short-term memory (LSTM) is a type of RNN specifically designed to process sequential data. Compared to standard RNNs, LSTM networks are able to maintain useful long-term dependencies for making predictions in the current and future time steps.

The complete code for this chapter can be found in .

LSTM

Standard RNNs suffer from the problem of vanishing gradients. As a result, when the input sequence is very long, a standard RNN cannot effectively learn from the early parts of the sequence. In other words, the long-term memory capacity of a standard RNN is quite weak. LSTM was developed to solve this long-term memory problem.

For more details about vanishing gradients, please refer to the following article. In addition, if you are not familiar with RNNs, please read the following article first.

The figure below shows an LSTM cell. It is considerably more complex than a standard RNN cell. In addition to the hidden state a, an LSTM also maintains a memory cell state c. Moreover, while a standard RNN cell applies only a single tanh activation, an LSTM cell must compute the forget gate \Gamma_f, the input gate \Gamma_i, the candidate cell state \tilde{c}, and the output gate \Gamma_o.

LSTM.

The following are the meanings of these gates and states.

  • Hidden State a^{<t>}: This is the state at the current time step and is used to make predictions.
  • Memory Cell State c^{<t>}: This is the long-term memory of LSTM.
  • Forget Gate \Gamma_f^{<t>}: This gate determines which information from the previous cell state should be discarded.
  • Input Gate \Gamma_i^{<t>}: This gate determines which new information in the current input should be added to the current cell state.
  • Candidate Cell State \tilde{c}^{<t>}: By processing the current input, new information is generated that may be useful to the network’s memory.
  • Output Gate \Gamma_o^{<t>}: This gate determines the next hidden state.
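
The interaction between these gates is easiest to see in the cell state update c^{<t>}=\Gamma_f^{<t>}\odot c^{<t-1>}+\Gamma_i^{<t>}\odot \tilde{c}^{<t>}, which is derived in the next section. As a purely illustrative example, if for some unit the forget gate outputs 0.9 and the input gate outputs 0.1, then

c^{<t>}=0.9\cdot c^{<t-1>}+0.1\cdot\tilde{c}^{<t>}

so the cell keeps most of its old memory and writes only a small amount of new information; with the gate values reversed, the old memory is almost entirely overwritten.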

Forward Propagation

The following figure shows the LSTM forward propagation. Each gate and the candidate cell state has its own parameters W, b and activation function. It is worth noting that a^{<t-1>} and x^{<t>} are stacked vertically before being multiplied by the weight matrices.

LSTM Cell Forward.

The formula in the LSTM cell is as follows:

\Gamma_f^{<t>}=\sigma(W_f[a^{<t-1>},x^{<t>}]+b_f) \\\\ \Gamma_i^{<t>}=\sigma(W_i[a^{<t-1>},x^{<t>}]+b_i) \\\\ \tilde{c}^{<t>}=\tanh(W_c[a^{<t-1>},x^{<t>}]+b_c) \\\\ \Gamma_o^{<t>}=\sigma(W_o[a^{<t-1>},x^{<t>}]+b_o) \\\\ c^{<t>}=\Gamma_f^{<t>}\odot c^{<t-1>}+\Gamma_i^{<t>}\odot \tilde{c}^{<t>} \\\\ a^{<t>}=\Gamma_o^{<t>}\odot \tanh(c^{<t>}) \\\\ \hat{y}^{<t>}=\text{softmax}(W_y a^{<t>}+b_y)

The dimensions of the LSTM input X and the true labels Y are as follows:

X:(n_x,m,T_x)-\text{the inputs.} \\\\ Y:(n_y,m,T_y)-\text{the true labels.} \\\\ m:\text{the number of examples.} \\\\ n_x:\text{the number of units in }x^{(i)<t>}. \\\\ n_y:\text{the number of units in }y^{(i)<t>}. \\\\ n_a:\text{the number of units in the hidden state.} \\\\ x^{(i)}:\text{the input of the }i\text{-th example.} \\\\ T_x:\text{the input sequence length.} \\\\ T_y:\text{the output sequence length.}

In the LSTM cell, the dimensions of each variable are as follows:

  • a^{<t-1>}: (n_a, m), x^{<t>}: (n_x, m)
  • [a^{<t-1>}, x^{<t>}]: (n_a+n_x, m), \hat{y}^{<t>}: (n_y, m)
  • W_f: (n_a, n_a+n_x), b_f: (n_a, 1)
  • W_i: (n_a, n_a+n_x), b_i: (n_a, 1)
  • W_c: (n_a, n_a+n_x), b_c: (n_a, 1)
  • W_o: (n_a, n_a+n_x), b_o: (n_a, 1)
  • W_y: (n_y, n_a), b_y: (n_y, 1)
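
In code, a set of parameters with exactly these shapes might be created as in the sketch below. This helper is hypothetical (the article's own initialization is not shown in this excerpt), and the 0.01 scaling of the random weights and the zero biases are assumptions.

import numpy as np

def initialize_parameters(n_a, n_x, n_y, seed=0):
    """Hypothetical helper: build LSTM parameters with the shapes listed above."""
    rng = np.random.default_rng(seed)

    def w(rows, cols):
        return rng.standard_normal((rows, cols)) * 0.01  # small random weights (assumed scale)

    return {
        "Wf": w(n_a, n_a + n_x), "bf": np.zeros((n_a, 1)),  # forget gate
        "Wi": w(n_a, n_a + n_x), "bi": np.zeros((n_a, 1)),  # input gate
        "Wc": w(n_a, n_a + n_x), "bc": np.zeros((n_a, 1)),  # candidate value
        "Wo": w(n_a, n_a + n_x), "bo": np.zeros((n_a, 1)),  # output gate
        "Wy": w(n_y, n_a), "by": np.zeros((n_y, 1)),  # output layer
    }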

The following is the implementation of LSTM forward propagation; the sigmoid, tanh, and softmax helpers it relies on are defined at the top.

import numpy as np

# Activation helpers used by the LSTM class.
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

def softmax(x):
    ex = np.exp(x - np.max(x, axis=0, keepdims=True))  # subtract the max for numerical stability
    return ex / np.sum(ex, axis=0, keepdims=True)

class LSTM:
    def cell_forward(self, xt, at_prev, ct_prev, parameters):
        """
        Implement a single forward step of the LSTM-cell.

        Parameters
        ----------
        xt: (ndarray (n_x, m)) - input data at time step "t"
        at_prev: (ndarray (n_a, m)) - hidden state at time step "t-1"
        ct_prev: (ndarray (n_a, m)) - memory state at time step "t-1"
        parameters: (dict) - dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the input gate
            bi: (ndarray (n_a, 1)) - bias of the input gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden-state to the output

        Returns
        -------
        at: (ndarray (n_a, m)) - hidden state at time step "t"
        ct: (ndarray (n_a, m)) - memory state at time step "t"
        y_hat_t: (ndarray (n_y, m)) - prediction at time step "t"
        cache: (tuple) - values needed for the backward pass, contains (at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt)
        """

        Wf, bf = parameters["Wf"], parameters["bf"]  # forget gate weight and biases
        Wi, bi = parameters["Wi"], parameters["bi"]  # input gate weight and biases
        Wc, bc = parameters["Wc"], parameters["bc"]  # candidate value weight and biases
        Wo, bo = parameters["Wo"], parameters["bo"]  # output gate weight and biases
        Wy, by = parameters["Wy"], parameters["by"]  # prediction weight and biases

        concat = np.concatenate((at_prev, xt), axis=0)

        ft = sigmoid(Wf @ concat + bf)  # forget gate
        it = sigmoid(Wi @ concat + bi)  # update gate
        cct = tanh(Wc @ concat + bc)  # candidate value
        ct = ft * ct_prev + it * cct  # memory state
        ot = sigmoid(Wo @ concat + bo)  # output gate
        at = ot * tanh(ct)  # hidden state

        zyt = Wy @ at + by
        y_hat_t = softmax(zyt)
        cache = (at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt)
        return at, ct, y_hat_t, cache

    def forward(self, X, a0, parameters):
        """
        Implement the forward pass of the LSTM network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) input data for each time step
        a0: (ndarray (n_a, m)) initial hidden state
        parameters: (dict) dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) weight matrix of the update gate
            bi: (ndarray (n_a, 1)) bias of the update gate
            Wc: (ndarray (n_a, n_a + n_x)) weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) weight matrix of the output gate
            bo: (ndarray (n_a, 1)) bias of the output gate
            Wy: (ndarray (n_y, n_a)) weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) bias relating the hidden-state to the output

        Returns
        -------
        A: (ndarray (n_a, m, T_x)) - hidden states for each time step
        C: (ndarray (n_a, m, T_x)) - memory states for each time step
        Y_hat: (ndarray (n_y, m, T_x)) - predictions for each time step
        caches: (list) - values needed for the backward pass
        """

        caches = []

        Wy = parameters["Wy"]
        n_x, m, T_x = X.shape
        n_y, n_a = Wy.shape

        A = np.zeros((n_a, m, T_x))
        C = np.zeros((n_a, m, T_x))
        Y_hat = np.zeros((n_y, m, T_x))

        at_prev = a0
        ct_prev = np.zeros((n_a, m))

        for t in range(T_x):
            at_prev, ct_prev, y_hat_t, cache = self.cell_forward(X[:, :, t], at_prev, ct_prev, parameters)
            A[:, :, t] = at_prev
            C[:, :, t] = ct_prev
            Y_hat[:, :, t] = y_hat_t
            caches.append(cache)

        return A, C, Y_hat, caches
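
As a quick sanity check (not part of the original article), the forward pass can be run on random data to verify the output shapes. The snippet below assumes the hypothetical initialize_parameters helper sketched earlier and assumes the constructor signature LSTM(n_a, n_x, n_y) used in the final example.

np.random.seed(1)
n_x, n_a, n_y, m, T_x = 3, 5, 3, 4, 7

parameters = initialize_parameters(n_a, n_x, n_y)
X = np.random.randn(n_x, m, T_x)
a0 = np.zeros((n_a, m))

lstm = LSTM(n_a, n_x, n_y)  # constructor arguments assumed; forward() only needs `parameters`
A, C, Y_hat, caches = lstm.forward(X, a0, parameters)
print(A.shape, C.shape, Y_hat.shape)        # (5, 4, 7) (5, 4, 7) (3, 4, 7)
print(np.allclose(Y_hat.sum(axis=0), 1.0))  # each softmax column sums to 1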

Loss Function

In this article, the output \hat{y}^{<t>} is produced by a softmax, so we use cross-entropy as the loss function. For the formula and implementation of cross-entropy loss, please refer to the following article.
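
The optimize method shown later calls self.compute_loss, which is not listed in this excerpt. A minimal sketch, assuming one-hot labels Y of shape (n_y, m, T_x) and a cross-entropy summed over examples and time steps (so that its gradient with respect to z_y is \hat{y}-y, as used in the backward pass), could look like this; the article's actual implementation may differ.

class LSTM:
    def compute_loss(self, Y_hat, Y):
        """Sketch: cross-entropy summed over all examples and time steps."""
        return -np.sum(Y * np.log(Y_hat + 1e-12))  # small epsilon (assumed) guards against log(0)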

Backward Propagation

The backpropagation of LSTM is a bit complicated. We have to calculate the partial derivatives with respect to each parameter. In particular, when calculating \frac{\partial\mathcal{L}}{\partial a^{<t>}} and \frac{\partial\mathcal{L}}{\partial c^{<t>}}, we also need to add the gradients flowing back from time step t+1, which is known as backpropagation through time (BPTT).

LSTM Cell Backward.

The following calculates the partial derivatives in the output layer.

\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}=\hat{y}^{<t>}-y^{<t>} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial W_y}=\frac{\partial z_y^{<t>}}{\partial W_y}\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}a^{<t>T} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial b_y}=\frac{\partial z_y^{<t>}}{\partial b_y}\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial a^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial a^{<t>}}+\frac{\partial \mathcal{L}^{<t+1>}}{\partial a^{<t>}}=W_y^T\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}+\frac{\partial \mathcal{L}^{<t+1>}}{\partial a^{<t>}}

The following is to calculate the partial derivatives of the forget gate, input gate, candidate cell state, and output gate.

\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial\Gamma_o^{<t>}}\frac{\partial\Gamma_o^{<t>}}{\partial\gamma_o^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot tanh(c^{<t>})\cdot\Gamma_o^{<t>}\cdot(1-\Gamma_o^{<t>})

\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\tilde{c}^{<t>}}\frac{\partial\tilde{c}^{<t>}}{\partial p\tilde{c}^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\tilde{c}^{<t>}}\frac{\partial\tilde{c}^{<t>}}{\partial p\tilde{c}^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}]\frac{\partial c^{<t>}}{\partial\tilde{c}^{<t>}}\frac{\partial\tilde{c}^{<t>}}{\partial p\tilde{c}^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot\Gamma_o^{<t>}\cdot(1-tanh^2(c^{<t>}))]\cdot\Gamma_i^{<t>}\cdot(1-(\tilde{c}^{<t>})^2)

\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_i^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\Gamma_i^{<t>}}\frac{\partial\Gamma_i^{<t>}}{\partial\gamma_i^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\Gamma_i^{<t>}}\frac{\partial\Gamma_i^{<t>}}{\partial\gamma_i^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_i^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}]\frac{\partial c^{<t>}}{\partial\Gamma_i^{<t>}}\frac{\partial\Gamma_i^{<t>}}{\partial\gamma_i^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_i^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\cdot\Gamma_o^{<t>}\cdot(1-tanh^2(c^{<t>}))]\cdot\tilde{c}^{<t>} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_i^{<t>}}=}\cdot\Gamma_i^{<t>}\cdot(1-\Gamma_i^{<t>})

\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_f^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\Gamma_f^{<t>}}\frac{\partial\Gamma_f^{<t>}}{\partial\gamma_f^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\Gamma_f^{<t>}}\frac{\partial\Gamma_f^{<t>}}{\partial\gamma_f^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_f^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}]\frac{\partial c^{<t>}}{\partial\Gamma_f^{<t>}}\frac{\partial\Gamma_f^{<t>}}{\partial\gamma_f^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_f^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\cdot\Gamma_o^{<t>}\cdot(1-tanh^2(c^{<t>}))]\cdot c^{<t-1>} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_f^{<t>}}=}\cdot\Gamma_f^{<t>}\cdot(1-\Gamma_f^{<t>})

The following calculates the partial derivatives with respect to all of the parameters W, b.

\frac{\partial\mathcal{L}^{<t>}}{\partial W_o}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o}\frac{\partial\gamma_o}{\partial W_o}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o} \begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial b_o}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o}\frac{\partial\gamma_o}{\partial b_o}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial W_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}}\frac{\partial p\tilde{c}}{\partial W_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}} \begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial b_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}}\frac{\partial p\tilde{c}}{\partial b_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial W_i}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i}\frac{\partial\gamma_i}{\partial W_i}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i} \begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial b_i}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i}\frac{\partial\gamma_i}{\partial b_i}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial W_f}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f}\frac{\partial\gamma_f}{\partial W_f}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f} \begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial b_f}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f}\frac{\partial\gamma_f}{\partial b_f}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f}

The following calculates the remaining partial derivatives. Each weight matrix is split into two column blocks: the first block (subscript 1) multiplies a^{<t-1>} and the second block (subscript 2) multiplies x^{<t>}.

\text{let } W_o=[W_{o1},W_{o2}] \\\\ \text{let } W_c=[W_{c1},W_{c2}] \\\\ \text{let } W_i=[W_{i1},W_{i2}] \\\\ \text{let } W_f=[W_{f1},W_{f2}] \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t-1>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial c^{<t-1>}}+\frac{\partial\mathcal{L^{<t>}}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial c^{<t-1>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t-1>}}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\cdot\Gamma_f^{<t>}+\frac{\partial\mathcal{L^{<t>}}}{\partial a^{<t>}}\cdot\Gamma_o^{<t>}\cdot(1-tanh^2(c^{<t>}))\cdot\Gamma_f^{<t>} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t-1>}}=W_{o1}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o^{<t>}}+W_{c1}^T\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}+W_{i1}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i^{<t>}}+W_{f1}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f^{<t>}} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial x^{<t>}}=W_{o2}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o^{<t>}}+W_{c2}^T\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}+W_{i2}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i^{<t>}}+W_{f2}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f^{<t>}}

The above shows how to calculate all the partial derivatives at a single time step. Finally, we sum the partial derivatives over all time steps.

\displaystyle \frac{\partial\mathcal{L}}{\partial W_y}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_y},\frac{\partial\mathcal{L}}{\partial b_y}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial b_y} \\\\ \frac{\partial\mathcal{L}}{\partial W_o}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_o},\frac{\partial\mathcal{L}}{\partial b_o}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_o} \\\\ \frac{\partial\mathcal{L}}{\partial W_c}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_c},\frac{\partial\mathcal{L}}{\partial b_c}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_c} \\\\ \frac{\partial\mathcal{L}}{\partial W_i}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_i},\frac{\partial\mathcal{L}}{\partial b_i}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_i} \\\\ \frac{\partial\mathcal{L}}{\partial W_f}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_f},\frac{\partial\mathcal{L}}{\partial b_f}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_f} \\\\ \frac{\partial\mathcal{L}}{\partial a}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}} \\\\ \frac{\partial\mathcal{L}}{\partial c}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}

The following is the implementation of LSTM backward propagation.

class LSTM:
    def cell_backward(self, y, dat, dct, cache, parameters):
        """
        Implement the backward pass for the LSTM-cell.

        Parameters
        ----------
        y: (ndarray (n_y, m)) - true labels for time step "t"
        dat: (ndarray (n_a, m)) - hidden state gradient for time step "t"
        dct: (ndarray (n_a, m)) - memory state gradient for time step "t"
        cache: (tuple) - values from the forward pass at time step "t"
        parameters: (dict) dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the update gate
            bi: (ndarray (n_a, 1)) - bias of the update gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden-state to the output

        Returns
        -------
        gradients: (dict) - dictionary containing the following gradients:
            dWf: (ndarray (n_a, n_a + n_x)) gradient of the forget gate weight
            dbf: (ndarray (n_a, 1)) gradient of the forget gate bias
            dWi: (ndarray (n_a, n_a + n_x)) gradient of the update gate weight
            dbi: (ndarray (n_a, 1)) gradient of the update gate bias
            dWc: (ndarray (n_a, n_a + n_x)) gradient of the candidate value weight
            dbc: (ndarray (n_a, 1)) gradient of the candidate value bias
            dWo: (ndarray (n_a, n_a + n_x)) gradient of the output gate weight
            dbo: (ndarray (n_a, 1)) gradient of the output gate bias
            dWy: (ndarray (n_y, n_a)) gradient of the prediction weight
            dby: (ndarray (n_y, 1)) gradient of the prediction bias
            dat_prev: (ndarray (n_a, m)) gradient of the hidden state for time step "t-1"
            dct_prev: (ndarray (n_a, m)) gradient of the memory state for time step "t-1"
            dxt: (ndarray (n_x, m)) gradient of the input data for time step "t"
        """

        at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt = cache
        n_a, m = at.shape

        dzy = y_hat_t - y
        dWy = dzy @ at.T
        dby = np.sum(dzy, axis=1, keepdims=True)

        dat = parameters["Wy"].T @ dzy + dat

        dot = dat * tanh(ct) * ot * (1 - ot)
        dcct = (dct + dat * ot * (1 - tanh(ct) ** 2)) * it * (1 - cct ** 2)
        dit = (dct + dat * ot * (1 - tanh(ct) ** 2)) * cct * it * (1 - it)
        dft = (dct + dat * ot * (1 - tanh(ct) ** 2)) * ct_prev * ft * (1 - ft)

        concat = np.concatenate((at_prev, xt), axis=0)

        dWo = dot @ concat.T
        dbo = np.sum(dot, axis=1, keepdims=True)
        dWc = dcct @ concat.T
        dbc = np.sum(dcct, axis=1, keepdims=True)
        dWi = dit @ concat.T
        dbi = np.sum(dit, axis=1, keepdims=True)
        dWf = dft @ concat.T
        dbf = np.sum(dft, axis=1, keepdims=True)

        dat_prev = (
            parameters["Wo"][:, :n_a].T @ dot + parameters["Wc"][:, :n_a].T @ dcct +
            parameters["Wi"][:, :n_a].T @ dit + parameters["Wf"][:, :n_a].T @ dft
        )
        dct_prev = dct * ft + ot * (1 - tanh(ct) ** 2) * ft * dat
        dxt = (
            parameters["Wo"][:, n_a:].T @ dot + parameters["Wc"][:, n_a:].T @ dcct +
            parameters["Wi"][:, n_a:].T @ dit + parameters["Wf"][:, n_a:].T @ dft
        )

        gradients = {
            "dWo": dWo, "dbo": dbo, "dWc": dWc, "dbc": dbc, "dWi": dWi, "dbi": dbi, "dWf": dWf, "dbf": dbf,
            "dWy": dWy, "dby": dby, "dct_prev": dct_prev, "dat_prev": dat_prev, "dxt": dxt,
        }
        return gradients

    def backward(self, X, Y, parameters, caches):
        """
        Implement the backward pass for the LSTM network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) input data for each time step
        Y: (ndarray (n_y, m, T_x)) true labels for each time step
        parameters: (dict) dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) weight matrix of the update gate
            bi: (ndarray (n_a, 1)) bias of the update gate
            Wc: (ndarray (n_a, n_a + n_x)) weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) weight matrix of the output gate
            bo: (ndarray (n_a, 1)) bias of the output gate
            Wy: (ndarray (n_y, n_a)) weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) bias relating the hidden-state to the output
        caches: (list) values needed for the backward pass

        Returns
        -------
        gradients: (dict) - dictionary containing the following gradients:
            dWf: (ndarray (n_a, n_a + n_x)) gradient of the forget gate weight
            dbf: (ndarray (n_a, 1)) gradient of the forget gate bias
            dWi: (ndarray (n_a, n_a + n_x)) gradient of the update gate weight
            dbi: (ndarray (n_a, 1)) gradient of the update gate bias
            dWc: (ndarray (n_a, n_a + n_x)) gradient of the candidate value weight
            dbc: (ndarray (n_a, 1)) gradient of the candidate value bias
            dWo: (ndarray (n_a, n_a + n_x)) gradient of the output gate weight
            dbo: (ndarray (n_a, 1)) gradient of the output gate bias
            dWy: (ndarray (n_y, n_a)) gradient of the prediction weight
            dby: (ndarray (n_y, 1)) gradient of the prediction bias
        """

        n_x, m, T_x = X.shape
        a1, c1, a0, c0, f1, i1, cc1, o1, x1, y_hat_1, zy1 = caches[0]
        Wf, Wi, Wc, Wo, Wy = parameters["Wf"], parameters["Wi"], parameters["Wc"], parameters["Wo"], parameters["Wy"]
        bf, bi, bc, bo, by = parameters["bf"], parameters["bi"], parameters["bc"], parameters["bo"], parameters["by"]

        gradients = {
            "dWf": np.zeros_like(Wf), "dWi": np.zeros_like(Wi), "dWc": np.zeros_like(Wc), "dWo": np.zeros_like(Wo),
            "dbf": np.zeros_like(bf), "dbi": np.zeros_like(bi), "dbc": np.zeros_like(bc), "dbo": np.zeros_like(bo),
            "dWy": np.zeros_like(Wy), "dby": np.zeros_like(by),
        }

        dat = np.zeros_like(a0)
        dct = np.zeros_like(c0)
        for t in reversed(range(T_x)):
            grads = self.cell_backward(Y[:, :, t], dat, dct, caches[t], parameters)
            gradients["dWf"] += grads["dWf"]
            gradients["dWi"] += grads["dWi"]
            gradients["dWc"] += grads["dWc"]
            gradients["dWo"] += grads["dWo"]
            gradients["dbf"] += grads["dbf"]
            gradients["dbi"] += grads["dbi"]
            gradients["dbc"] += grads["dbc"]
            gradients["dbo"] += grads["dbo"]
            gradients["dWy"] += grads["dWy"]
            gradients["dby"] += grads["dby"]
            dat = grads["dat_prev"]
            dct = grads["dct_prev"]

        return gradients
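
To gain confidence in the derivation, the analytic gradients can be compared against centered finite differences of the loss. The sketch below is not from the original article: it reuses the hypothetical initialize_parameters helper and the assumed LSTM(n_a, n_x, n_y) constructor from the earlier examples, uses a summed cross-entropy (whose gradient matches the y_hat_t - y term in cell_backward), and checks a single entry of W_f.

def summed_cross_entropy(Y_hat, Y):
    # loss whose gradient with respect to z_y is (y_hat - y), matching cell_backward
    return -np.sum(Y * np.log(Y_hat + 1e-12))

np.random.seed(2)
n_x, n_a, n_y, m, T_x = 3, 5, 3, 4, 7
parameters = initialize_parameters(n_a, n_x, n_y)
X = np.random.randn(n_x, m, T_x)
Y = np.eye(n_y)[np.random.randint(0, n_y, size=(m, T_x))].transpose(2, 0, 1)  # one-hot, (n_y, m, T_x)
a0 = np.zeros((n_a, m))

lstm = LSTM(n_a, n_x, n_y)  # constructor arguments assumed from the final example
_, _, _, caches = lstm.forward(X, a0, parameters)
grads = lstm.backward(X, Y, parameters, caches)

eps, i, j = 1e-5, 0, 0  # perturb a single entry of Wf
parameters["Wf"][i, j] += eps
_, _, Y_hat_plus, _ = lstm.forward(X, a0, parameters)
parameters["Wf"][i, j] -= 2 * eps
_, _, Y_hat_minus, _ = lstm.forward(X, a0, parameters)
parameters["Wf"][i, j] += eps  # restore the original value

numeric = (summed_cross_entropy(Y_hat_plus, Y) - summed_cross_entropy(Y_hat_minus, Y)) / (2 * eps)
print(numeric, grads["dWf"][i, j])  # the two values should be very close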

Putting All Together

The following code implements a complete training iteration. First, we pass the training data through forward propagation, compute the loss, and then run backward propagation to obtain the gradients. To prevent exploding gradients, we clip the gradients and then use them to update the parameters. This completes one training iteration.

class LSTM:
    def optimize(self, X, Y, a_prev, parameters, learning_rate, clip_value):
        """
        Implements the forward and backward propagation of the LSTM.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data for each time step
        Y: (ndarray (n_y, m, T_x)) - true labels for each time step
        a_prev: (ndarray (n_a, m)) - initial hidden state
        parameters: (dict) - dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the update gate
            bi: (ndarray (n_a, 1)) - bias of the update gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden-state to the output
        learning_rate: (float) - learning rate for the optimization algorithm
        clip_value: (float) - maximum value for the gradients

        Returns
        -------
        at: (ndarray (n_a, m)) hidden state for the last time step
        loss: (float) - the cross-entropy
        """

        A, C, Y_hat, caches = self.forward(X, a_prev, parameters)
        loss = self.compute_loss(Y_hat, Y)
        gradients = self.backward(X, Y, parameters, caches)
        gradients = self.clip(gradients, clip_value)
        self.update_parameters(parameters, gradients, learning_rate)

        at = A[:, :, -1]
        return at, loss
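
The clip and update_parameters methods called above belong to the full implementation and are not listed in this excerpt. The sketches below show one plausible version, assuming element-wise clipping of every gradient into [-clip_value, clip_value] and plain gradient descent; the article's actual code may differ.

class LSTM:
    def clip(self, gradients, clip_value):
        """Sketch: clip each gradient element-wise into [-clip_value, clip_value]."""
        for key in ["dWf", "dbf", "dWi", "dbi", "dWc", "dbc", "dWo", "dbo", "dWy", "dby"]:
            np.clip(gradients[key], -clip_value, clip_value, out=gradients[key])
        return gradients

    def update_parameters(self, parameters, gradients, learning_rate):
        """Sketch: plain gradient descent on every parameter."""
        for key in ["Wf", "bf", "Wi", "bi", "Wc", "bc", "Wo", "bo", "Wy", "by"]:
            parameters[key] -= learning_rate * gradients["d" + key]
        return parameters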

Example

Next, we use the LSTM as a character-level language model. The training material is a passage from Shakespeare. The model is trained one character at a time, so the sequence length T_x is the length of the input text, and each character is encoded with one-hot encoding. Please refer to the following article for details of this part, because this article and the following article use the same example.
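
To make the encoding step concrete (the article's train method handles this internally, so the helper below is only a hypothetical sketch), a string can be turned into a one-hot input tensor of shape (n_x, m, T_x) with m = 1 as follows:

def text_to_one_hot(text, char_to_idx):
    """Hypothetical helper: encode a string as a one-hot tensor of shape (vocab_size, 1, len(text))."""
    vocab_size, T_x = len(char_to_idx), len(text)
    X = np.zeros((vocab_size, 1, T_x))  # a single example, so m = 1
    for t, ch in enumerate(text):
        X[char_to_idx[ch], 0, t] = 1.0
    return X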

An example of using this LSTM is as follows:

if __name__ == "__main__":
    with open("shakespeare.txt", "r") as file:
        text = file.read()

    chars = sorted(list(set(text)))
    vocab_size = len(chars)

    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}

    lstm = LSTM(64, vocab_size, vocab_size)
    losses = lstm.train(text, char_to_idx, num_iterations=100, learning_rate=0.01, clip_value=5)

    generated_text = lstm.sample("T", char_to_idx, idx_to_char, num_chars=100)
    print(generated_text)

Conclusion

Standard RNNs face the problem of vanishing gradients when processing long sequence data. LSTM was developed to address this problem. It has been used in speech recognition and machine translation, with quite good results.

References

  • S. Hochreiter and J. Schmidhuber. 1997. Long Short-term Memory. Neural Computation, 9(8): 1735-1780.
  • Andrew Ng, Deep Learning Specialization, Coursera.