Long Short-Term Memory (LSTM)

Photo by Sinziana Susa on Unsplash
Long short-term memory (LSTM) is a type of RNN designed for processing sequential data. Compared with a standard RNN, an LSTM network can maintain useful long-term dependencies, so that it can make predictions at the current and future time steps.

The complete code is available for download.

LSTM

Standard RNNs suffer from the vanishing gradients problem. As a result, when a sequence is long, a standard RNN cannot effectively learn from inputs that appeared early in the sequence; in other words, its long-term memory is quite weak. LSTM was developed to solve this long-term memory problem.

For more details on vanishing gradients, see the following article. Also, if you are not yet familiar with RNNs, please read the following article first.

The figure below shows an LSTM cell. It is considerably more complex than a standard RNN cell. In addition to the hidden state a, the LSTM also keeps a memory cell state c. Moreover, while a standard RNN has only a single tanh activation, the LSTM computes a forget gate \Gamma_f, an input gate \Gamma_i, a candidate cell state \tilde{c}, and an output gate \Gamma_o.

LSTM.

The meanings of these gates and states are as follows.

  • Hidden state a^{<t>}: the state of the current sequence, used to make predictions.
  • Memory cell state c^{<t>}: the LSTM's long-term memory.
  • Forget gate \Gamma_f^{<t>}: decides which information from the previous cell state should be discarded.
  • Input gate \Gamma_i^{<t>}: decides which new information from the current input should be added to the current cell state.
  • Candidate cell state \tilde{c}^{<t>}: new information, produced from the current input, that may be useful for the network's memory.
  • Output gate \Gamma_o^{<t>}: decides the next hidden state.

Forward Propagation

The figure below shows the LSTM forward propagation. Each gate and the candidate cell state has its own parameters W, b and activation function. Note that a^{<t-1>} and x^{<t>} are stacked vertically.

LSTM Cell Forward.

The equations inside the LSTM cell are as follows:

\gamma_f^{<t>}=W_f[a^{<t-1>},x^{<t>}]+b_f \\\\ \Gamma_f^{<t>}=\sigma(\gamma_f^{<t>}) \\\\ \gamma_i^{<t>}=W_i[a^{<t-1>},x^{<t>}]+b_i \\\\ \Gamma_i^{<t>}=\sigma(\gamma_i^{<t>}) \\\\ p\tilde{c}^{<t>}=W_c[a^{<t-1>},x^{<t>}]+b_c \\\\ \tilde{c}^{<t>}=\tanh(p\tilde{c}^{<t>}) \\\\ \gamma_o^{<t>}=W_o[a^{<t-1>},x^{<t>}]+b_o \\\\ \Gamma_o^{<t>}=\sigma(\gamma_o^{<t>}) \\\\ c^{<t>}=\Gamma_f^{<t>}\odot c^{<t-1>}+\Gamma_i^{<t>}\odot \tilde{c}^{<t>} \\\\ a^{<t>}=\Gamma_o^{<t>}\odot \tanh(c^{<t>}) \\\\ z_y^{<t>}=W_y a^{<t>}+b_y \\\\ \hat{y}^{<t>}=\text{softmax}(z_y^{<t>})

The dimensions of the LSTM inputs X and the true labels Y are as follows:

X:(n_x,m,T_x)-\text{the inputs.} \\\\ Y:(n_y,m,T_y)-\text{the true labels.} \\\\ m:\text{the number of examples.} \\\\ n_x:\text{the number of units in }x^{(i)<t>}. \\\\ n_y:\text{the number of units in }y^{(i)<t>}. \\\\ n_a:\text{the number of units in the hidden state.} \\\\ x^{(i)}:\text{the input of the }i\text{-th example.} \\\\ T_x:\text{the input sequence length.} \\\\ T_y:\text{the output sequence length.}

Inside the LSTM cell, the variables have the following dimensions:

  • a^{<t-1>}: (n_a, m), x^{<t>}: (n_x, m)
  • [a^{<t-1>}, x^{<t>}]: (n_a + n_x, m), \hat{y}^{<t>}: (n_y, m)
  • W_f: (n_a, n_a + n_x), b_f: (n_a, 1)
  • W_i: (n_a, n_a + n_x), b_i: (n_a, 1)
  • W_c: (n_a, n_a + n_x), b_c: (n_a, 1)
  • W_o: (n_a, n_a + n_x), b_o: (n_a, 1)
  • W_y: (n_y, n_a), b_y: (n_y, 1)
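
The article does not show how the parameters dictionary is created. Below is a minimal sketch, not the article's exact scheme, that initializes the parameters with the shapes listed above using small random weights and zero biases; the helper name initialize_parameters and the scale value are assumptions for illustration.

import numpy as np

def initialize_parameters(n_a, n_x, n_y, scale=0.01):
    """Create the LSTM parameter dictionary with the shapes listed above (an assumed helper)."""
    rng = np.random.default_rng(0)
    parameters = {}
    for gate in ["f", "i", "c", "o"]:
        parameters["W" + gate] = rng.standard_normal((n_a, n_a + n_x)) * scale  # (n_a, n_a + n_x)
        parameters["b" + gate] = np.zeros((n_a, 1))  # (n_a, 1)
    parameters["Wy"] = rng.standard_normal((n_y, n_a)) * scale  # (n_y, n_a)
    parameters["by"] = np.zeros((n_y, 1))  # (n_y, 1)
    return parameters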

The following is the implementation of the LSTM forward propagation.

import numpy as np

# Minimal versions of the activation helpers assumed by the implementation below.
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

def softmax(x):
    # Column-wise softmax with max-subtraction for numerical stability.
    e = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e / np.sum(e, axis=0, keepdims=True)


class LSTM:
    def cell_forward(self, xt, at_prev, ct_prev, parameters):
        """
        Implement a single forward step of the LSTM-cell.

        Parameters
        ----------
        xt: (ndarray (n_x, m)) - input data at time step "t"
        at_prev: (ndarray (n_a, m)) - hidden state at time step "t-1"
        ct_prev: (ndarray (n_a, m)) - memory state at time step "t-1"
        parameters: (dict) - dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the input gate
            bi: (ndarray (n_a, 1)) - bias of the input gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden-state to the output

        Returns
        -------
        at: (ndarray (n_a, m)) - hidden state at time step "t"
        ct: (ndarray (n_a, m)) - memory state at time step "t"
        y_hat_t: (ndarray (n_y, m)) - prediction at time step "t"
        cache: (tuple) - values needed for the backward pass, contains (at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt)
        """

        Wf, bf = parameters["Wf"], parameters["bf"]  # forget gate weight and biases
        Wi, bi = parameters["Wi"], parameters["bi"]  # input gate weight and biases
        Wc, bc = parameters["Wc"], parameters["bc"]  # candidate value weight and biases
        Wo, bo = parameters["Wo"], parameters["bo"]  # output gate weight and biases
        Wy, by = parameters["Wy"], parameters["by"]  # prediction weight and biases

        concat = np.concatenate((at_prev, xt), axis=0)

        ft = sigmoid(Wf @ concat + bf)  # forget gate
        it = sigmoid(Wi @ concat + bi)  # update gate
        cct = tanh(Wc @ concat + bc)  # candidate value
        ct = ft * ct_prev + it * cct  # memory state
        ot = sigmoid(Wo @ concat + bo)  # output gate
        at = ot * tanh(ct)  # hidden state

        zyt = Wy @ at + by
        y_hat_t = softmax(zyt)
        cache = (at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt)
        return at, ct, y_hat_t, cache

    def forward(self, X, a0, parameters):
        """
        Implement the forward pass of the LSTM network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) input data for each time step
        a0: (ndarray (n_a, m)) initial hidden state
        parameters: (dict) dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) weight matrix of the update gate
            bi: (ndarray (n_a, 1)) bias of the update gate
            Wc: (ndarray (n_a, n_a + n_x)) weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) weight matrix of the output gate
            bo: (ndarray (n_a, 1)) bias of the output gate
            Wy: (ndarray (n_y, n_a)) weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) bias relating the hidden-state to the output

        Returns
        -------
        A: (ndarray (n_a, m, T_x)) - hidden states for each time step
        C: (ndarray (n_a, m, T_x)) - memory states for each time step
        Y_hat: (ndarray (n_y, m, T_x)) - predictions for each time step
        caches: (list) - values needed for the backward pass
        """

        caches = []

        Wy = parameters["Wy"]
        n_x, m, T_x = X.shape
        n_y, n_a = Wy.shape

        A = np.zeros((n_a, m, T_x))
        C = np.zeros((n_a, m, T_x))
        Y_hat = np.zeros((n_y, m, T_x))

        at_prev = a0
        ct_prev = np.zeros((n_a, m))

        for t in range(T_x):
            at_prev, ct_prev, y_hat_t, cache = self.cell_forward(X[:, :, t], at_prev, ct_prev, parameters)
            A[:, :, t] = at_prev
            C[:, :, t] = ct_prev
            Y_hat[:, :, t] = y_hat_t
            caches.append(cache)

        return A, C, Y_hat, caches

Loss Function

In this article, we use softmax to produce the output \hat{y}, so we use the cross-entropy loss as the loss function. For the formula and implementation of the cross-entropy loss, please refer to the following article.
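
The compute_loss method used later in optimize is not listed in this article. A minimal sketch, assuming Y is one-hot with shape (n_y, m, T_x) and Y_hat is the softmax output of the forward pass, might look like this (the referenced article may normalize differently):

import numpy as np

class LSTM:
    def compute_loss(self, Y_hat, Y):
        """
        Cross-entropy loss summed over time steps and examples (a sketch of the assumed helper).

        Y_hat: (ndarray (n_y, m, T_x)) - softmax predictions for each time step
        Y: (ndarray (n_y, m, T_x)) - one-hot true labels for each time step
        """
        # A small epsilon avoids log(0) when a predicted probability underflows.
        return -np.sum(Y * np.log(Y_hat + 1e-12))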

Backward Propagation

The backpropagation of an LSTM is somewhat involved. We must compute the partial derivative with respect to every parameter. In particular, when computing \frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}} and \frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}, we must also account for the gradients flowing back from the next time step; this is what backpropagation through time (BPTT) refers to.

LSTM Cell Backward.

The following derives the partial derivatives in the output layer.

\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}=\hat{y}^{<t>}-y^{<t>} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial W_y}=\frac{\partial z_y^{<t>}}{\partial W_y}\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}a^{<t>T} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial b_y}=\frac{\partial z_y^{<t>}}{\partial b_y}\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial a^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial a^{<t>}}+\frac{\partial \mathcal{L}^{<t+1>}}{\partial a^{<t>}}=W_y^T\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}+\frac{\partial \mathcal{L}^{<t+1>}}{\partial a^{<t>}}

The following derives the partial derivatives for the forget gate, input gate, candidate cell state, and output gate.

\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial\Gamma_o^{<t>}}\frac{\partial\Gamma_o^{<t>}}{\partial\gamma_o^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot \tanh(c^{<t>})\cdot\Gamma_o^{<t>}\cdot(1-\Gamma_o^{<t>})

\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\tilde{c}^{<t>}}\frac{\partial\tilde{c}^{<t>}}{\partial p\tilde{c}^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\tilde{c}^{<t>}}\frac{\partial\tilde{c}^{<t>}}{\partial p\tilde{c}^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}]\frac{\partial c^{<t>}}{\partial\tilde{c}^{<t>}}\frac{\partial\tilde{c}^{<t>}}{\partial p\tilde{c}^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot\Gamma_o^{<t>}\cdot(1-\tanh^2(c^{<t>}))]\cdot\Gamma_i^{<t>}\cdot(1-(\tilde{c}^{<t>})^2)

\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_i^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\Gamma_i^{<t>}}\frac{\partial\Gamma_i^{<t>}}{\partial\gamma_i^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\Gamma_i^{<t>}}\frac{\partial\Gamma_i^{<t>}}{\partial\gamma_i^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_i^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}]\frac{\partial c^{<t>}}{\partial\Gamma_i^{<t>}}\frac{\partial\Gamma_i^{<t>}}{\partial\gamma_i^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_i^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot\Gamma_o^{<t>}\cdot(1-\tanh^2(c^{<t>}))]\cdot\tilde{c}^{<t>} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_i^{<t>}}=}\cdot\Gamma_i^{<t>}\cdot(1-\Gamma_i^{<t>})

\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_f^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\Gamma_f^{<t>}}\frac{\partial\Gamma_f^{<t>}}{\partial\gamma_f^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial\Gamma_f^{<t>}}\frac{\partial\Gamma_f^{<t>}}{\partial\gamma_f^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_f^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}]\frac{\partial c^{<t>}}{\partial\Gamma_f^{<t>}}\frac{\partial\Gamma_f^{<t>}}{\partial\gamma_f^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_f^{<t>}}}=[\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot\Gamma_o^{<t>}\cdot(1-\tanh^2(c^{<t>}))]\cdot c^{<t-1>} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_f^{<t>}}=}\cdot\Gamma_f^{<t>}\cdot(1-\Gamma_f^{<t>})

The following derives the partial derivatives of all the parameters W and b.

\frac{\partial\mathcal{L}^{<t>}}{\partial W_o}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o}\frac{\partial\gamma_o}{\partial W_o}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o} \begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial b_o}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o}\frac{\partial\gamma_o}{\partial b_o}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial W_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}}\frac{\partial p\tilde{c}}{\partial W_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}} \begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial b_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}}\frac{\partial p\tilde{c}}{\partial b_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial W_i}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i}\frac{\partial\gamma_i}{\partial W_i}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i} \begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial b_i}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i}\frac{\partial\gamma_i}{\partial b_i}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial W_f}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f}\frac{\partial\gamma_f}{\partial W_f}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f} \begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial b_f}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f}\frac{\partial\gamma_f}{\partial b_f}=\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f}

The following derives the remaining partial derivatives.

\text{let } W_o=[W_{o1},W_{o2}],\ W_c=[W_{c1},W_{c2}],\ W_i=[W_{i1},W_{i2}],\ W_f=[W_{f1},W_{f2}], \\\\ \text{where the first block multiplies } a^{<t-1>} \text{ and the second block multiplies } x^{<t>}. \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t-1>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial c^{<t-1>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial c^{<t>}}\frac{\partial c^{<t>}}{\partial c^{<t-1>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t-1>}}}=\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}\cdot\Gamma_f^{<t>}+\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot\Gamma_o^{<t>}\cdot(1-\tanh^2(c^{<t>}))\cdot\Gamma_f^{<t>} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t-1>}}=W_{o1}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o^{<t>}}+W_{c1}^T\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}+W_{i1}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i^{<t>}}+W_{f1}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f^{<t>}} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial x^{<t>}}=W_{o2}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_o^{<t>}}+W_{c2}^T\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}+W_{i2}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_i^{<t>}}+W_{f2}^T\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_f^{<t>}}

The above shows how all the partial derivatives are computed at each time step. Finally, we sum the partial derivatives over all time steps.

\displaystyle \frac{\partial\mathcal{L}}{\partial W_y}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_y},\frac{\partial\mathcal{L}}{\partial b_y}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial b_y} \\\\ \frac{\partial\mathcal{L}}{\partial W_o}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_o},\frac{\partial\mathcal{L}}{\partial b_o}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_o} \\\\ \frac{\partial\mathcal{L}}{\partial W_c}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_c},\frac{\partial\mathcal{L}}{\partial b_c}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_c} \\\\ \frac{\partial\mathcal{L}}{\partial W_i}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_i},\frac{\partial\mathcal{L}}{\partial b_i}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_i} \\\\ \frac{\partial\mathcal{L}}{\partial W_f}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_f},\frac{\partial\mathcal{L}}{\partial b_f}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_f} \\\\ \frac{\partial\mathcal{L}}{\partial a}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}} \\\\ \frac{\partial\mathcal{L}}{\partial c}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial c^{<t>}}

The following is the implementation of the LSTM backward propagation.

class LSTM:
    def cell_backward(self, y, dat, dct, cache, parameters):
        """
        Implement the backward pass for the LSTM-cell.

        Parameters
        ----------
        y: (ndarray (n_y, m)) - true labels for time step "t"
        dat: (ndarray (n_a, m)) - hidden state gradient for time step "t"
        dct: (ndarray (n_a, m)) - memory state gradient for time step "t"
        cache: (tuple) - values from the forward pass at time step "t"
        parameters: (dict) dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the update gate
            bi: (ndarray (n_a, 1)) - bias of the update gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden-state to the output

        Returns
        -------
        gradients: (dict) - dictionary containing the following gradients:
            dWf: (ndarray (n_a, n_a + n_x)) gradient of the forget gate weight
            dbf: (ndarray (n_a, 1)) gradient of the forget gate bias
            dWi: (ndarray (n_a, n_a + n_x)) gradient of the update gate weight
            dbi: (ndarray (n_a, 1)) gradient of the update gate bias
            dWc: (ndarray (n_a, n_a + n_x)) gradient of the candidate value weight
            dbc: (ndarray (n_a, 1)) gradient of the candidate value bias
            dWo: (ndarray (n_a, n_a + n_x)) gradient of the output gate weight
            dbo: (ndarray (n_a, 1)) gradient of the output gate bias
            dWy: (ndarray (n_y, n_a)) gradient of the prediction weight
            dby: (ndarray (n_y, 1)) gradient of the prediction bias
            dat_prev: (ndarray (n_a, m)) gradient of the hidden state for time step "t-1"
            dct_prev: (ndarray (n_a, m)) gradient of the memory state for time step "t-1"
            dxt: (ndarray (n_x, m)) gradient of the input data for time step "t"
        """

        at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt = cache
        n_a, m = at.shape

        dzy = y_hat_t - y
        dWy = dzy @ at.T
        dby = np.sum(dzy, axis=1, keepdims=True)

        dat = parameters["Wy"].T @ dzy + dat

        dot = dat * tanh(ct) * ot * (1 - ot)
        dcct = (dct + dat * ot * (1 - tanh(ct) ** 2)) * it * (1 - cct ** 2)
        dit = (dct + dat * ot * (1 - tanh(ct) ** 2)) * cct * it * (1 - it)
        dft = (dct + dat * ot * (1 - tanh(ct) ** 2)) * ct_prev * ft * (1 - ft)

        concat = np.concatenate((at_prev, xt), axis=0)

        dWo = dot @ concat.T
        dbo = np.sum(dot, axis=1, keepdims=True)
        dWc = dcct @ concat.T
        dbc = np.sum(dcct, axis=1, keepdims=True)
        dWi = dit @ concat.T
        dbi = np.sum(dit, axis=1, keepdims=True)
        dWf = dft @ concat.T
        dbf = np.sum(dft, axis=1, keepdims=True)

        dat_prev = (
            parameters["Wo"][:, :n_a].T @ dot + parameters["Wc"][:, :n_a].T @ dcct +
            parameters["Wi"][:, :n_a].T @ dit + parameters["Wf"][:, :n_a].T @ dft
        )
        dct_prev = dct * ft + ot * (1 - tanh(ct) ** 2) * ft * dat
        dxt = (
            parameters["Wo"][:, n_a:].T @ dot + parameters["Wc"][:, n_a:].T @ dcct +
            parameters["Wi"][:, n_a:].T @ dit + parameters["Wf"][:, n_a:].T @ dft
        )

        gradients = {
            "dWo": dWo, "dbo": dbo, "dWc": dWc, "dbc": dbc, "dWi": dWi, "dbi": dbi, "dWf": dWf, "dbf": dbf,
            "dWy": dWy, "dby": dby, "dct_prev": dct_prev, "dat_prev": dat_prev, "dxt": dxt,
        }
        return gradients

    def backward(self, X, Y, parameters, caches):
        """
        Implement the backward pass for the LSTM network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) input data for each time step
        Y: (ndarray (n_y, m, T_x)) true labels for each time step
        parameters: (dict) dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) weight matrix of the update gate
            bi: (ndarray (n_a, 1)) bias of the update gate
            Wc: (ndarray (n_a, n_a + n_x)) weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) weight matrix of the output gate
            bo: (ndarray (n_a, 1)) bias of the output gate
            Wy: (ndarray (n_y, n_a)) weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) bias relating the hidden-state to the output
        caches: (list) values needed for the backward pass

        Returns
        -------
        gradients: (dict) - dictionary containing the following gradients:
            dWf: (ndarray (n_a, n_a + n_x)) gradient of the forget gate weight
            dbf: (ndarray (n_a, 1)) gradient of the forget gate bias
            dWi: (ndarray (n_a, n_a + n_x)) gradient of the update gate weight
            dbi: (ndarray (n_a, 1)) gradient of the update gate bias
            dWc: (ndarray (n_a, n_a + n_x)) gradient of the candidate value weight
            dbc: (ndarray (n_a, 1)) gradient of the candidate value bias
            dWo: (ndarray (n_a, n_a + n_x)) gradient of the output gate weight
            dbo: (ndarray (n_a, 1)) gradient of the output gate bias
            dWy: (ndarray (n_y, n_a)) gradient of the prediction weight
            dby: (ndarray (n_y, 1)) gradient of the prediction bias
        """

        n_x, m, T_x = X.shape
        a1, c1, a0, c0, f1, i1, cc1, o1, x1, y_hat_1, zy1 = caches[0]
        Wf, Wi, Wc, Wo, Wy = parameters["Wf"], parameters["Wi"], parameters["Wc"], parameters["Wo"], parameters["Wy"]
        bf, bi, bc, bo, by = parameters["bf"], parameters["bi"], parameters["bc"], parameters["bo"], parameters["by"]

        gradients = {
            "dWf": np.zeros_like(Wf), "dWi": np.zeros_like(Wi), "dWc": np.zeros_like(Wc), "dWo": np.zeros_like(Wo),
            "dbf": np.zeros_like(bf), "dbi": np.zeros_like(bi), "dbc": np.zeros_like(bc), "dbo": np.zeros_like(bo),
            "dWy": np.zeros_like(Wy), "dby": np.zeros_like(by),
        }

        dat = np.zeros_like(a0)
        dct = np.zeros_like(c0)
        for t in reversed(range(T_x)):
            grads = self.cell_backward(Y[:, :, t], dat, dct, caches[t], parameters)
            gradients["dWf"] += grads["dWf"]
            gradients["dWi"] += grads["dWi"]
            gradients["dWc"] += grads["dWc"]
            gradients["dWo"] += grads["dWo"]
            gradients["dbf"] += grads["dbf"]
            gradients["dbi"] += grads["dbi"]
            gradients["dbc"] += grads["dbc"]
            gradients["dbo"] += grads["dbo"]
            gradients["dWy"] += grads["dWy"]
            gradients["dby"] += grads["dby"]
            dat = grads["dat_prev"]
            dct = grads["dct_prev"]

        return gradients
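
To verify the backward pass above, a finite-difference gradient check can be run on a few parameter entries. The sketch below is not part of the original article; it assumes that lstm, parameters, X, Y, and a0 are already set up and that compute_loss matches the cross-entropy used here.

def numeric_grad(lstm, parameters, X, Y, a0, name, i, j, eps=1e-5):
    """Estimate d loss / d parameters[name][i, j] with central differences (an assumed helper)."""
    original = parameters[name][i, j]

    parameters[name][i, j] = original + eps
    _, _, Y_hat_plus, _ = lstm.forward(X, a0, parameters)
    loss_plus = lstm.compute_loss(Y_hat_plus, Y)

    parameters[name][i, j] = original - eps
    _, _, Y_hat_minus, _ = lstm.forward(X, a0, parameters)
    loss_minus = lstm.compute_loss(Y_hat_minus, Y)

    parameters[name][i, j] = original  # restore the original value
    return (loss_plus - loss_minus) / (2 * eps)

# Compare against the analytic gradient, e.g.:
# _, _, _, caches = lstm.forward(X, a0, parameters)
# grads = lstm.backward(X, Y, parameters, caches)
# print(numeric_grad(lstm, parameters, "Wf", 0, 0), grads["dWf"][0, 0])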

Putting It All Together

The following code implements one complete training iteration. First, we pass the training data through forward propagation and compute the loss, then run backward propagation to obtain the gradients. To prevent exploding gradients, we clip the gradients before using them to update the parameters. That completes one training iteration.

class LSTM:
    def optimize(self, X, Y, a_prev, parameters, learning_rate, clip_value):
        """
        Implements the forward and backward propagation of the LSTM.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data for each time step
        Y: (ndarray (n_y, m, T_x)) - true labels for each time step
        a_prev: (ndarray (n_a, m)) - initial hidden state
        parameters: (dict) - dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the update gate
            bi: (ndarray (n_a, 1)) - bias of the update gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden-state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden-state to the output
        learning_rate: (float) - learning rate for the optimization algorithm
        clip_value: (float) - maximum value for the gradients

        Returns
        -------
        at: (ndarray (n_a, m)) hidden state for the last time step
        loss: (float) - the cross-entropy
        """

        A, C, Y_hat, caches = self.forward(X, a_prev, parameters)
        loss = self.compute_loss(Y_hat, Y)
        gradients = self.backward(X, Y, parameters, caches)
        gradients = self.clip(gradients, clip_value)
        self.update_parameters(parameters, gradients, learning_rate)

        at = A[:, :, -1]
        return at, loss
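
The clip and update_parameters methods called in optimize are not listed in this article. A minimal sketch, assuming element-wise clipping to [-clip_value, clip_value] and plain gradient descent, might look like this:

import numpy as np

class LSTM:
    def clip(self, gradients, clip_value):
        """Clip every gradient element into [-clip_value, clip_value] to limit exploding gradients (assumed helper)."""
        for key in ["dWf", "dbf", "dWi", "dbi", "dWc", "dbc", "dWo", "dbo", "dWy", "dby"]:
            np.clip(gradients[key], -clip_value, clip_value, out=gradients[key])
        return gradients

    def update_parameters(self, parameters, gradients, learning_rate):
        """Apply one plain gradient-descent step to every parameter (assumed helper)."""
        for key in ["Wf", "bf", "Wi", "bi", "Wc", "bc", "Wo", "bo", "Wy", "by"]:
            parameters[key] -= learning_rate * gradients["d" + key]
        return parameters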

Example

Next, we use the LSTM as a character-level language model. The training data is a passage from Shakespeare. The model is trained one character at a time, so the sequence length T_x is the length of the input character sequence, and each character is one-hot encoded. For the details of this setup, please refer to the following article, which uses the same example as this one.
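
The encoding details are left to the referenced article; as a rough illustration, a snippet of text can be turned into one-hot inputs X and next-character labels Y as follows. The helper name text_to_one_hot and the shift-by-one labeling are assumptions for illustration, not necessarily the referenced article's exact setup.

import numpy as np

def text_to_one_hot(text, char_to_idx):
    """Encode a string as one-hot inputs X and next-character labels Y, each of shape (vocab_size, 1, T) (assumed helper)."""
    vocab_size = len(char_to_idx)
    T = len(text) - 1  # the last character has no "next" character to predict
    X = np.zeros((vocab_size, 1, T))
    Y = np.zeros((vocab_size, 1, T))
    for t in range(T):
        X[char_to_idx[text[t]], 0, t] = 1.0
        Y[char_to_idx[text[t + 1]], 0, t] = 1.0
    return X, Y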

An example of using this LSTM is shown below:

if __name__ == "__main__":
    with open("shakespeare.txt", "r") as file:
        text = file.read()

    chars = sorted(list(set(text)))
    vocab_size = len(chars)

    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}

    lstm = LSTM(64, vocab_size, vocab_size)
    losses = lstm.train(text, char_to_idx, num_iterations=100, learning_rate=0.01, clip_value=5)

    generated_text = lstm.sample("T", char_to_idx, idx_to_char, num_chars=100)
    print(generated_text)

Conclusion

Standard RNNs face the vanishing gradients problem when processing long sequences. LSTM was developed in response to this problem, and it has been applied to speech recognition and machine translation with considerable success.
