Long short-term memory (LSTM) is a type of RNN designed for sequential data. Compared with a standard RNN, an LSTM network can retain useful long-term dependencies and use them to make predictions at the current and future time steps.
LSTM
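Standard RNNs suffer from vanishing gradients. As a result, when a sequence is long, a standard RNN cannot effectively learn from inputs that appeared early in the sequence; in other words, its long-term memory is weak. LSTM was developed to solve this long-term memory problem.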
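To make the problem concrete, here is a toy one-unit sketch (not from the original article; the weights `w`, `u`, the forget-gate value `f`, and the length `T` are arbitrary assumptions). In an RNN with $a_t = \tanh(w a_{t-1} + u x_t)$, the gradient $\partial a_T / \partial a_0$ is a product of $T$ factors $w(1 - a_t^2)$, each no larger than $|w|$. Along the LSTM's direct cell path $c_t = f c_{t-1} + \dots$, the per-step factor is just the forget gate $f$, which can stay close to 1.

```python
import numpy as np

# Toy, one-unit illustration; w, u, f and T are arbitrary choices.
T = 50
rng = np.random.default_rng(0)
xs = rng.normal(size=T)

# Vanilla RNN: a_t = tanh(w * a_{t-1} + u * x_t)
w, u = 0.9, 0.5
a = 0.0
grad_rnn = 1.0  # running product d a_t / d a_0
for x in xs:
    a = np.tanh(w * a + u * x)
    grad_rnn *= w * (1 - a ** 2)  # chain rule through one step: d a_t / d a_{t-1}

# LSTM direct cell path only: d c_t / d c_{t-1} = f (paths through the gates are ignored)
f = 0.99
grad_lstm_cell = f ** T

print(f"RNN: {grad_rnn:.2e}, LSTM cell path: {grad_lstm_cell:.2e}")
```

The RNN gradient is bounded by $0.9^{50} \approx 5 \times 10^{-3}$, while the cell-path factor $0.99^{50}$ is still about 0.6.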
For more details on vanishing gradients, see the following article. If you are not yet familiar with RNNs, please read the following article first as well.
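The figure below shows an LSTM cell, which is considerably more complex than a standard RNN cell. In addition to the hidden state $a^{\langle t \rangle}$, an LSTM maintains a memory cell state $c^{\langle t \rangle}$. Moreover, a standard RNN computes its hidden state with a single $\tanh$ activation, whereas an LSTM computes four quantities, each with its own activation: the forget gate $\Gamma_f^{\langle t \rangle}$, the input gate $\Gamma_i^{\langle t \rangle}$, the candidate cell state $\tilde{c}^{\langle t \rangle}$, and the output gate $\Gamma_o^{\langle t \rangle}$.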
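The meaning of each gate and state is as follows.
- Hidden State $a^{\langle t \rangle}$: the state at the current step of the sequence, used to make predictions.
- Memory Cell State $c^{\langle t \rangle}$: the LSTM's long-term memory.
- Forget Gate $\Gamma_f^{\langle t \rangle}$: decides which information from the previous cell state $c^{\langle t-1 \rangle}$ should be discarded.
- Input Gate $\Gamma_i^{\langle t \rangle}$: decides which new information from the current input should be added to the current cell state.
- Candidate Cell State $\tilde{c}^{\langle t \rangle}$: new information, computed from the current input, that may be useful to store in the network's memory.
- Output Gate $\Gamma_o^{\langle t \rangle}$: decides the next hidden state, i.e. which parts of $\tanh(c^{\langle t \rangle})$ are exposed as $a^{\langle t \rangle}$.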
Forward Propagation
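The figure below shows LSTM forward propagation. Each gate and the candidate cell state has its own parameters $W$ and $b$ and its own activation function. Note that we stack $a^{\langle t-1 \rangle}$ and $x^{\langle t \rangle}$ vertically into a single vector $[a^{\langle t-1 \rangle}; x^{\langle t \rangle}]$.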
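Stacking $a^{\langle t-1 \rangle}$ on top of $x^{\langle t \rangle}$ lets each gate use one weight matrix of shape $(n_a, n_a + n_x)$ instead of two separate ones. The sketch below, with small arbitrary sizes, checks numerically that $W [a; x]$ equals $W_{:, :n_a}\, a + W_{:, n_a:}\, x$; this split is also why the backward pass can slice the weights with `[:, :n_a]` and `[:, n_a:]`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_a, n_x, m = 3, 2, 4  # arbitrary small sizes for the check

W = rng.normal(size=(n_a, n_a + n_x))  # one gate's weight matrix
a_prev = rng.normal(size=(n_a, m))     # a<t-1>
x_t = rng.normal(size=(n_x, m))        # x<t>

# Vertical stacking: shape (n_a + n_x, m)
concat = np.concatenate((a_prev, x_t), axis=0)

stacked = W @ concat                                 # one matrix product
split = W[:, :n_a] @ a_prev + W[:, n_a:] @ x_t       # two separate products

print(np.allclose(stacked, split))
```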
The equations of the LSTM cell are as follows:
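Written in the notation above ($\sigma$ is the sigmoid, $\odot$ is element-wise multiplication), the cell computes the following. These equations are read off from `cell_forward` in the implementation below, where `ft`, `it`, `cct`, `ct`, `ot`, `at` and `y_hat_t` correspond to $\Gamma_f$, $\Gamma_i$, $\tilde{c}$, $c$, $\Gamma_o$, $a$ and $\hat{y}$.

```latex
\begin{aligned}
\Gamma_f^{\langle t \rangle} &= \sigma\left(W_f \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right] + b_f\right) \\
\Gamma_i^{\langle t \rangle} &= \sigma\left(W_i \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right] + b_i\right) \\
\tilde{c}^{\langle t \rangle} &= \tanh\left(W_c \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right] + b_c\right) \\
c^{\langle t \rangle} &= \Gamma_f^{\langle t \rangle} \odot c^{\langle t-1 \rangle} + \Gamma_i^{\langle t \rangle} \odot \tilde{c}^{\langle t \rangle} \\
\Gamma_o^{\langle t \rangle} &= \sigma\left(W_o \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right] + b_o\right) \\
a^{\langle t \rangle} &= \Gamma_o^{\langle t \rangle} \odot \tanh\left(c^{\langle t \rangle}\right) \\
\hat{y}^{\langle t \rangle} &= \operatorname{softmax}\left(W_y\, a^{\langle t \rangle} + b_y\right)
\end{aligned}
```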
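The dimensions of the LSTM input $X$ and the true labels $Y$ are:
- $X$: $(n_x, m, T_x)$
- $Y$: $(n_y, m, T_x)$

Here $n_x$ is the input size, $n_y$ the output size, $m$ the number of examples in a batch, and $T_x$ the number of time steps.

Inside the LSTM cell, the variables have the following dimensions:
- $x^{\langle t \rangle}$: $(n_x, m)$
- $a^{\langle t-1 \rangle}$, $a^{\langle t \rangle}$, $c^{\langle t-1 \rangle}$, $c^{\langle t \rangle}$, $\Gamma_f^{\langle t \rangle}$, $\Gamma_i^{\langle t \rangle}$, $\tilde{c}^{\langle t \rangle}$, $\Gamma_o^{\langle t \rangle}$: $(n_a, m)$
- $W_f$, $W_i$, $W_c$, $W_o$: $(n_a, n_a + n_x)$
- $b_f$, $b_i$, $b_c$, $b_o$: $(n_a, 1)$
- $W_y$: $(n_y, n_a)$; $b_y$: $(n_y, 1)$
- $\hat{y}^{\langle t \rangle}$: $(n_y, m)$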
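These parameter shapes can be captured in an initialization helper. The article does not show how the `LSTM` constructor initializes its parameters, so `initialize_parameters`, the 0.01 scale, and the zero biases below are assumptions for illustration only.

```python
import numpy as np


def initialize_parameters(n_a, n_x, n_y, seed=0):
    """Hypothetical helper: small random weights and zero biases,
    with the shapes listed in the docstrings of cell_forward."""
    rng = np.random.default_rng(seed)
    params = {}
    for gate in ("f", "i", "c", "o"):
        # Each gate acts on the stacked vector [a<t-1>; x<t>]
        params["W" + gate] = rng.standard_normal((n_a, n_a + n_x)) * 0.01
        params["b" + gate] = np.zeros((n_a, 1))
    # Output layer: hidden state -> vocabulary scores
    params["Wy"] = rng.standard_normal((n_y, n_a)) * 0.01
    params["by"] = np.zeros((n_y, 1))
    return params
```

Some implementations initialize the forget-gate bias `bf` to 1 instead of 0 so that the cell tends to keep its memory early in training.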
The following is an implementation of LSTM forward propagation.
```python
import numpy as np


# Activation helpers used by the LSTM cell
def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def tanh(z):
    return np.tanh(z)


def softmax(z):
    # Column-wise softmax over axis 0 (the vocabulary axis), shifted for numerical stability
    e = np.exp(z - np.max(z, axis=0, keepdims=True))
    return e / np.sum(e, axis=0, keepdims=True)


class LSTM:
    def cell_forward(self, xt, at_prev, ct_prev, parameters):
        """
        Implement a single forward step of the LSTM cell.

        Parameters
        ----------
        xt: (ndarray (n_x, m)) - input data at time step "t"
        at_prev: (ndarray (n_a, m)) - hidden state at time step "t-1"
        ct_prev: (ndarray (n_a, m)) - memory state at time step "t-1"
        parameters: (dict) - dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the input gate
            bi: (ndarray (n_a, 1)) - bias of the input gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden state to the output

        Returns
        -------
        at: (ndarray (n_a, m)) - hidden state at time step "t"
        ct: (ndarray (n_a, m)) - memory state at time step "t"
        y_hat_t: (ndarray (n_y, m)) - prediction at time step "t"
        cache: (tuple) - values needed for the backward pass, contains
            (at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt)
        """
        Wf, bf = parameters["Wf"], parameters["bf"]  # forget gate weights and biases
        Wi, bi = parameters["Wi"], parameters["bi"]  # input gate weights and biases
        Wc, bc = parameters["Wc"], parameters["bc"]  # candidate value weights and biases
        Wo, bo = parameters["Wo"], parameters["bo"]  # output gate weights and biases
        Wy, by = parameters["Wy"], parameters["by"]  # prediction weights and biases

        # Stack a<t-1> on top of x<t>: shape (n_a + n_x, m)
        concat = np.concatenate((at_prev, xt), axis=0)

        ft = sigmoid(Wf @ concat + bf)  # forget gate
        it = sigmoid(Wi @ concat + bi)  # input gate
        cct = tanh(Wc @ concat + bc)    # candidate cell state
        ct = ft * ct_prev + it * cct    # memory cell state
        ot = sigmoid(Wo @ concat + bo)  # output gate
        at = ot * tanh(ct)              # hidden state

        zyt = Wy @ at + by
        y_hat_t = softmax(zyt)          # prediction

        cache = (at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt)
        return at, ct, y_hat_t, cache

    def forward(self, X, a0, parameters):
        """
        Implement the forward pass of the LSTM network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) input data for each time step
        a0: (ndarray (n_a, m)) initial hidden state
        parameters: (dict) dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) weight matrix of the input gate
            bi: (ndarray (n_a, 1)) bias of the input gate
            Wc: (ndarray (n_a, n_a + n_x)) weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) weight matrix of the output gate
            bo: (ndarray (n_a, 1)) bias of the output gate
            Wy: (ndarray (n_y, n_a)) weight matrix relating the hidden state to the output
            by: (ndarray (n_y, 1)) bias relating the hidden state to the output

        Returns
        -------
        A: (ndarray (n_a, m, T_x)) - hidden states for each time step
        C: (ndarray (n_a, m, T_x)) - memory states for each time step
        Y_hat: (ndarray (n_y, m, T_x)) - predictions for each time step
        caches: (list) - values needed for the backward pass
        """
        caches = []
        Wy = parameters["Wy"]
        n_x, m, T_x = X.shape
        n_y, n_a = Wy.shape

        A = np.zeros((n_a, m, T_x))
        C = np.zeros((n_a, m, T_x))
        Y_hat = np.zeros((n_y, m, T_x))

        at_prev = a0
        ct_prev = np.zeros((n_a, m))  # the memory cell starts at zero

        for t in range(T_x):
            at_prev, ct_prev, y_hat_t, cache = self.cell_forward(X[:, :, t], at_prev, ct_prev, parameters)
            A[:, :, t] = at_prev
            C[:, :, t] = ct_prev
            Y_hat[:, :, t] = y_hat_t
            caches.append(cache)

        return A, C, Y_hat, caches
```
Loss Function
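In this article, we use a softmax function to output $\hat{y}^{\langle t \rangle}$, so we use the cross-entropy loss as the loss function. For the formula and implementation of the cross-entropy loss, see the following article.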
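The training code calls `self.compute_loss(Y_hat, Y)`, which is not shown here. A minimal sketch, assuming the loss is the cross-entropy $J = -\sum_{t} \sum_{k} y_k^{\langle t \rangle} \log \hat{y}_k^{\langle t \rangle}$ summed over time steps and over the batch (not averaged). Summing is consistent with the output-layer gradient `dzy = y_hat_t - y` in `cell_backward`, which does not divide by $m$. The `eps` guard against `log(0)` is an added assumption.

```python
import numpy as np


def compute_loss(Y_hat, Y, eps=1e-12):
    """Cross-entropy summed over the vocabulary, the batch and all time steps.

    Y_hat, Y: ndarrays of shape (n_y, m, T_x); Y is one-hot along axis 0.
    """
    return float(-np.sum(Y * np.log(Y_hat + eps)))
```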
Backward Propagation
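Backpropagation in an LSTM is somewhat involved: we must compute the partial derivative of the loss with respect to every parameter. In particular, the gradients with respect to $a^{\langle t \rangle}$ and $c^{\langle t \rangle}$ must also include the gradients flowing back from time step $t+1$, because both states feed into the next step. This is backpropagation through time (BPTT).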
First, the partial derivatives in the output layer.
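Reading the formulas off `cell_backward` in the implementation below, with $da^{\langle t \rangle}_{\text{next}}$ denoting the hidden-state gradient arriving from time step $t+1$ (the `dat` argument) and the bias gradient summed over the $m$ examples in the batch:

```latex
\begin{aligned}
dz_y^{\langle t \rangle} &= \hat{y}^{\langle t \rangle} - y^{\langle t \rangle} \\
dW_y^{\langle t \rangle} &= dz_y^{\langle t \rangle} \left(a^{\langle t \rangle}\right)^{T} \\
db_y^{\langle t \rangle} &= \sum_{j=1}^{m} \left(dz_y^{\langle t \rangle}\right)_{:,j} \\
da^{\langle t \rangle} &= W_y^{T}\, dz_y^{\langle t \rangle} + da^{\langle t \rangle}_{\text{next}}
\end{aligned}
```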
Next, the partial derivatives for the forget gate, input gate, candidate cell state, and output gate.
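As in `cell_backward` (variables `dot`, `dcct`, `dit`, `dft`), the gate gradients below are taken with respect to the pre-activations, i.e. the arguments of $\sigma$ or $\tanh$. Here $da^{\langle t \rangle}$ is the full hidden-state gradient (output layer plus the contribution from $t+1$), and $dc^{\langle t \rangle}_{\text{next}}$ is the cell-state gradient arriving from $t+1$ (the `dct` argument). The first line collects the total gradient reaching $c^{\langle t \rangle}$, which enters both through $t+1$ and through $a^{\langle t \rangle} = \Gamma_o \odot \tanh(c^{\langle t \rangle})$.

```latex
\begin{aligned}
dc^{\langle t \rangle} &= dc^{\langle t \rangle}_{\text{next}} + da^{\langle t \rangle} \odot \Gamma_o^{\langle t \rangle} \odot \left(1 - \tanh^2\left(c^{\langle t \rangle}\right)\right) \\
d\Gamma_o^{\langle t \rangle} &= da^{\langle t \rangle} \odot \tanh\left(c^{\langle t \rangle}\right) \odot \Gamma_o^{\langle t \rangle} \odot \left(1 - \Gamma_o^{\langle t \rangle}\right) \\
d\tilde{c}^{\langle t \rangle} &= dc^{\langle t \rangle} \odot \Gamma_i^{\langle t \rangle} \odot \left(1 - \left(\tilde{c}^{\langle t \rangle}\right)^2\right) \\
d\Gamma_i^{\langle t \rangle} &= dc^{\langle t \rangle} \odot \tilde{c}^{\langle t \rangle} \odot \Gamma_i^{\langle t \rangle} \odot \left(1 - \Gamma_i^{\langle t \rangle}\right) \\
d\Gamma_f^{\langle t \rangle} &= dc^{\langle t \rangle} \odot c^{\langle t-1 \rangle} \odot \Gamma_f^{\langle t \rangle} \odot \left(1 - \Gamma_f^{\langle t \rangle}\right)
\end{aligned}
```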
Next, the partial derivatives for all parameters $W$ and $b$.
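Each gate's weight gradient is its pre-activation gradient times the transposed stacked input, and its bias gradient is that gradient summed over the batch (`dWf = dft @ concat.T`, `dbf = np.sum(dft, axis=1, keepdims=True)`, and likewise for the other gates in `cell_backward`):

```latex
\begin{aligned}
dW_f^{\langle t \rangle} &= d\Gamma_f^{\langle t \rangle} \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right]^{T}, &
db_f^{\langle t \rangle} &= \sum_{j=1}^{m} \left(d\Gamma_f^{\langle t \rangle}\right)_{:,j} \\
dW_i^{\langle t \rangle} &= d\Gamma_i^{\langle t \rangle} \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right]^{T}, &
db_i^{\langle t \rangle} &= \sum_{j=1}^{m} \left(d\Gamma_i^{\langle t \rangle}\right)_{:,j} \\
dW_c^{\langle t \rangle} &= d\tilde{c}^{\langle t \rangle} \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right]^{T}, &
db_c^{\langle t \rangle} &= \sum_{j=1}^{m} \left(d\tilde{c}^{\langle t \rangle}\right)_{:,j} \\
dW_o^{\langle t \rangle} &= d\Gamma_o^{\langle t \rangle} \left[a^{\langle t-1 \rangle}; x^{\langle t \rangle}\right]^{T}, &
db_o^{\langle t \rangle} &= \sum_{j=1}^{m} \left(d\Gamma_o^{\langle t \rangle}\right)_{:,j}
\end{aligned}
```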
Finally, the remaining partial derivatives.
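Writing $W_g^{(a)}$ for the first $n_a$ columns of a gate's weight matrix (`W[:, :n_a]` in the code) and $W_g^{(x)}$ for the last $n_x$ columns (`W[:, n_a:]`), the gradients passed back to time step $t-1$ and to the input are as follows; $da^{\langle t-1 \rangle}$ and $dc^{\langle t-1 \rangle}$ become the "next" gradients when step $t-1$ is processed, and $dc^{\langle t \rangle}$ is the total cell-state gradient defined for the gates.

```latex
\begin{aligned}
da^{\langle t-1 \rangle} &= \left(W_f^{(a)}\right)^{T} d\Gamma_f^{\langle t \rangle} + \left(W_i^{(a)}\right)^{T} d\Gamma_i^{\langle t \rangle} + \left(W_c^{(a)}\right)^{T} d\tilde{c}^{\langle t \rangle} + \left(W_o^{(a)}\right)^{T} d\Gamma_o^{\langle t \rangle} \\
dc^{\langle t-1 \rangle} &= dc^{\langle t \rangle} \odot \Gamma_f^{\langle t \rangle} \\
dx^{\langle t \rangle} &= \left(W_f^{(x)}\right)^{T} d\Gamma_f^{\langle t \rangle} + \left(W_i^{(x)}\right)^{T} d\Gamma_i^{\langle t \rangle} + \left(W_c^{(x)}\right)^{T} d\tilde{c}^{\langle t \rangle} + \left(W_o^{(x)}\right)^{T} d\Gamma_o^{\langle t \rangle}
\end{aligned}
```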
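This is how all partial derivatives are computed at a single time step. Finally, we sum each parameter's partial derivatives over all time steps, e.g. $dW_f = \sum_{t=1}^{T_x} dW_f^{\langle t \rangle}$.
The following is an implementation of LSTM backward propagation.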
```python
import numpy as np


class LSTM:
    def cell_backward(self, y, dat, dct, cache, parameters):
        """
        Implement the backward pass for the LSTM cell.

        Parameters
        ----------
        y: (ndarray (n_y, m)) - true labels for time step "t"
        dat: (ndarray (n_a, m)) - hidden state gradient for time step "t" (from time step "t+1")
        dct: (ndarray (n_a, m)) - memory state gradient for time step "t" (from time step "t+1")
        cache: (tuple) - values from the forward pass at time step "t"
        parameters: (dict) dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the input gate
            bi: (ndarray (n_a, 1)) - bias of the input gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden state to the output

        Returns
        -------
        gradients: (dict) - dictionary containing the following gradients:
            dWf: (ndarray (n_a, n_a + n_x)) gradient of the forget gate weight
            dbf: (ndarray (n_a, 1)) gradient of the forget gate bias
            dWi: (ndarray (n_a, n_a + n_x)) gradient of the input gate weight
            dbi: (ndarray (n_a, 1)) gradient of the input gate bias
            dWc: (ndarray (n_a, n_a + n_x)) gradient of the candidate value weight
            dbc: (ndarray (n_a, 1)) gradient of the candidate value bias
            dWo: (ndarray (n_a, n_a + n_x)) gradient of the output gate weight
            dbo: (ndarray (n_a, 1)) gradient of the output gate bias
            dWy: (ndarray (n_y, n_a)) gradient of the prediction weight
            dby: (ndarray (n_y, 1)) gradient of the prediction bias
            dat_prev: (ndarray (n_a, m)) gradient of the hidden state for time step "t-1"
            dct_prev: (ndarray (n_a, m)) gradient of the memory state for time step "t-1"
            dxt: (ndarray (n_x, m)) gradient of the input data for time step "t"
        """
        at, ct, at_prev, ct_prev, ft, it, cct, ot, xt, y_hat_t, zyt = cache
        n_a, m = at.shape

        # Output layer: softmax + cross-entropy gives dL/dz_y = y_hat - y
        dzy = y_hat_t - y
        dWy = dzy @ at.T
        dby = np.sum(dzy, axis=1, keepdims=True)

        # Hidden-state gradient: from the output layer plus from time step t+1
        dat = parameters["Wy"].T @ dzy + dat

        # Total cell-state gradient: from time step t+1 plus through at = ot * tanh(ct)
        tanh_ct = np.tanh(ct)
        dct_total = dct + dat * ot * (1 - tanh_ct ** 2)

        # Gradients w.r.t. the pre-activations of each gate
        dot = dat * tanh_ct * ot * (1 - ot)
        dcct = dct_total * it * (1 - cct ** 2)
        dit = dct_total * cct * it * (1 - it)
        dft = dct_total * ct_prev * ft * (1 - ft)

        # Parameter gradients
        concat = np.concatenate((at_prev, xt), axis=0)
        dWo = dot @ concat.T
        dbo = np.sum(dot, axis=1, keepdims=True)
        dWc = dcct @ concat.T
        dbc = np.sum(dcct, axis=1, keepdims=True)
        dWi = dit @ concat.T
        dbi = np.sum(dit, axis=1, keepdims=True)
        dWf = dft @ concat.T
        dbf = np.sum(dft, axis=1, keepdims=True)

        # Gradients passed to time step t-1 (first n_a columns) and to the input (last n_x columns)
        dat_prev = (
            parameters["Wo"][:, :n_a].T @ dot
            + parameters["Wc"][:, :n_a].T @ dcct
            + parameters["Wi"][:, :n_a].T @ dit
            + parameters["Wf"][:, :n_a].T @ dft
        )
        dct_prev = dct_total * ft
        dxt = (
            parameters["Wo"][:, n_a:].T @ dot
            + parameters["Wc"][:, n_a:].T @ dcct
            + parameters["Wi"][:, n_a:].T @ dit
            + parameters["Wf"][:, n_a:].T @ dft
        )

        gradients = {
            "dWo": dWo, "dbo": dbo, "dWc": dWc, "dbc": dbc,
            "dWi": dWi, "dbi": dbi, "dWf": dWf, "dbf": dbf,
            "dWy": dWy, "dby": dby,
            "dct_prev": dct_prev, "dat_prev": dat_prev, "dxt": dxt,
        }
        return gradients

    def backward(self, X, Y, parameters, caches):
        """
        Implement the backward pass for the LSTM network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) input data for each time step
        Y: (ndarray (n_y, m, T_x)) true labels for each time step
        parameters: (dict) dictionary containing Wf, bf, Wi, bi, Wc, bc, Wo, bo, Wy, by
            (same shapes as in cell_backward)
        caches: (list) values needed for the backward pass

        Returns
        -------
        gradients: (dict) - dictionary containing the following gradients, summed over all time steps:
            dWf: (ndarray (n_a, n_a + n_x)) gradient of the forget gate weight
            dbf: (ndarray (n_a, 1)) gradient of the forget gate bias
            dWi: (ndarray (n_a, n_a + n_x)) gradient of the input gate weight
            dbi: (ndarray (n_a, 1)) gradient of the input gate bias
            dWc: (ndarray (n_a, n_a + n_x)) gradient of the candidate value weight
            dbc: (ndarray (n_a, 1)) gradient of the candidate value bias
            dWo: (ndarray (n_a, n_a + n_x)) gradient of the output gate weight
            dbo: (ndarray (n_a, 1)) gradient of the output gate bias
            dWy: (ndarray (n_y, n_a)) gradient of the prediction weight
            dby: (ndarray (n_y, 1)) gradient of the prediction bias
        """
        n_x, m, T_x = X.shape
        # The first cache holds a0 and c0, whose shapes size the initial state gradients
        a1, c1, a0, c0, f1, i1, cc1, o1, x1, y_hat_1, zy1 = caches[0]

        # Accumulators for the parameter gradients
        gradients = {
            "d" + name: np.zeros_like(parameters[name])
            for name in ("Wf", "Wi", "Wc", "Wo", "bf", "bi", "bc", "bo", "Wy", "by")
        }

        # No gradient flows into the last time step from the future
        dat = np.zeros_like(a0)
        dct = np.zeros_like(c0)

        for t in reversed(range(T_x)):
            grads = self.cell_backward(Y[:, :, t], dat, dct, caches[t], parameters)
            for key in gradients:
                gradients[key] += grads[key]
            # Pass the state gradients back to time step t-1
            dat = grads["dat_prev"]
            dct = grads["dct_prev"]

        return gradients
```
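A numerical gradient check is a useful way to confirm formulas like these. The standalone sketch below (not part of the article's class; `step`, `analytic_grads` and `numeric_grad` are hypothetical helpers) re-implements a single time step with the incoming `dat` and `dct` set to zero, computes the weight gradients with the same formulas as `cell_backward`, and compares them with central finite differences.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(z):
    e = np.exp(z - z.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)


def step(p, a_prev, c_prev, x, y):
    """One LSTM step plus cross-entropy loss; returns the loss and intermediates."""
    concat = np.concatenate((a_prev, x), axis=0)
    f = sigmoid(p["Wf"] @ concat + p["bf"])
    i = sigmoid(p["Wi"] @ concat + p["bi"])
    cc = np.tanh(p["Wc"] @ concat + p["bc"])
    o = sigmoid(p["Wo"] @ concat + p["bo"])
    c = f * c_prev + i * cc
    a = o * np.tanh(c)
    y_hat = softmax(p["Wy"] @ a + p["by"])
    return -np.sum(y * np.log(y_hat)), (concat, f, i, cc, o, c, a, y_hat)


def analytic_grads(p, a_prev, c_prev, x, y):
    """Same formulas as cell_backward, with zero gradients arriving from t+1."""
    _, (concat, f, i, cc, o, c, a, y_hat) = step(p, a_prev, c_prev, x, y)
    dzy = y_hat - y
    da = p["Wy"].T @ dzy
    dc = da * o * (1 - np.tanh(c) ** 2)  # total gradient reaching c<t>
    return {
        "Wo": (da * np.tanh(c) * o * (1 - o)) @ concat.T,
        "Wc": (dc * i * (1 - cc ** 2)) @ concat.T,
        "Wi": (dc * cc * i * (1 - i)) @ concat.T,
        "Wf": (dc * c_prev * f * (1 - f)) @ concat.T,
        "Wy": dzy @ a.T,
    }


def numeric_grad(p, name, a_prev, c_prev, x, y, h=1e-5):
    """Central finite differences, perturbing one entry of p[name] at a time."""
    g = np.zeros_like(p[name])
    for idx in np.ndindex(*p[name].shape):
        old = p[name][idx]
        p[name][idx] = old + h
        loss_plus, _ = step(p, a_prev, c_prev, x, y)
        p[name][idx] = old - h
        loss_minus, _ = step(p, a_prev, c_prev, x, y)
        p[name][idx] = old  # restore the original value
        g[idx] = (loss_plus - loss_minus) / (2 * h)
    return g


rng = np.random.default_rng(1)
n_a, n_x, n_y, m = 3, 2, 4, 2  # arbitrary small sizes
p = {k: rng.normal(size=(n_a, n_a + n_x)) for k in ("Wf", "Wi", "Wc", "Wo")}
p.update({k: rng.normal(size=(n_a, 1)) for k in ("bf", "bi", "bc", "bo")})
p["Wy"], p["by"] = rng.normal(size=(n_y, n_a)), rng.normal(size=(n_y, 1))
a_prev, c_prev = rng.normal(size=(n_a, m)), rng.normal(size=(n_a, m))
x = rng.normal(size=(n_x, m))
y = np.eye(n_y)[:, rng.integers(0, n_y, size=m)]  # one-hot label columns

analytic = analytic_grads(p, a_prev, c_prev, x, y)
max_err = {
    name: np.max(np.abs(analytic[name] - numeric_grad(p, name, a_prev, c_prev, x, y)))
    for name in analytic
}
print(max_err)
```

Errors on the order of $10^{-8}$ or smaller indicate that the analytic formulas match the loss.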
Putting It All Together
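The code below implements one complete training step. We first pass the training data through forward propagation and compute the loss, then run backward propagation to obtain the gradients. To prevent exploding gradients, we clip the gradients before using them to update the parameters. That completes one training step.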
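`optimize` relies on `self.clip` and `self.update_parameters`, which are not shown in the article. A minimal sketch of both, assuming element-wise clipping to `[-clip_value, clip_value]` (rather than clipping by the global norm) and plain gradient descent; the standalone function names mirror the method calls but are otherwise hypothetical.

```python
import numpy as np


def clip(gradients, clip_value):
    """Clip every gradient element-wise to [-clip_value, clip_value], in place."""
    for grad in gradients.values():
        np.clip(grad, -clip_value, clip_value, out=grad)
    return gradients


def update_parameters(parameters, gradients, learning_rate):
    """Plain gradient descent: W <- W - learning_rate * dW for every parameter."""
    for name in parameters:
        parameters[name] -= learning_rate * gradients["d" + name]
```

The keys line up because `backward` returns one `"d" + name` entry for each of `Wf`, `Wi`, `Wc`, `Wo`, `bf`, `bi`, `bc`, `bo`, `Wy`, `by`.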
```python
class LSTM:
    def optimize(self, X, Y, a_prev, parameters, learning_rate, clip_value):
        """
        Implements the forward and backward propagation of the LSTM.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data for each time step
        Y: (ndarray (n_y, m, T_x)) - true labels for each time step
        a_prev: (ndarray (n_a, m)) - initial hidden state
        parameters: (dict) - dictionary containing:
            Wf: (ndarray (n_a, n_a + n_x)) - weight matrix of the forget gate
            bf: (ndarray (n_a, 1)) - bias of the forget gate
            Wi: (ndarray (n_a, n_a + n_x)) - weight matrix of the input gate
            bi: (ndarray (n_a, 1)) - bias of the input gate
            Wc: (ndarray (n_a, n_a + n_x)) - weight matrix of the candidate value
            bc: (ndarray (n_a, 1)) - bias of the candidate value
            Wo: (ndarray (n_a, n_a + n_x)) - weight matrix of the output gate
            bo: (ndarray (n_a, 1)) - bias of the output gate
            Wy: (ndarray (n_y, n_a)) - weight matrix relating the hidden state to the output
            by: (ndarray (n_y, 1)) - bias relating the hidden state to the output
        learning_rate: (float) - learning rate for the optimization algorithm
        clip_value: (float) - maximum value for the gradients

        Returns
        -------
        at: (ndarray (n_a, m)) hidden state for the last time step
        loss: (float) - the cross-entropy
        """
        A, C, Y_hat, caches = self.forward(X, a_prev, parameters)        # forward pass
        loss = self.compute_loss(Y_hat, Y)                                # cross-entropy loss
        gradients = self.backward(X, Y, parameters, caches)               # BPTT
        gradients = self.clip(gradients, clip_value)                      # guard against exploding gradients
        self.update_parameters(parameters, gradients, learning_rate)      # gradient step
        at = A[:, :, -1]
        return at, loss
```
Example
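Next, we build the LSTM as a character-level language model, trained on a passage of Shakespeare. It processes one character per time step, so the sequence length equals the number of input characters, and each character is one-hot encoded. For details, see the following article, which uses the same example.
An example of using this LSTM: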
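The data preparation is covered in the referenced article rather than here. A minimal sketch, assuming the usual next-character setup in which the target at step $t$ is the character at position $t+1$, and a batch size of $m = 1$; `encode_text` is a hypothetical helper, not part of the `LSTM` class. It produces arrays with the $(n_x, m, T_x)$ layout that `forward` expects, with $n_x = n_y$ equal to the vocabulary size.

```python
import numpy as np


def encode_text(text, char_to_idx):
    """One-hot encode text into X, Y of shape (vocab_size, 1, len(text) - 1).

    Y[:, :, t] is the character that follows X[:, :, t] (next-character prediction).
    """
    vocab_size = len(char_to_idx)
    T_x = len(text) - 1
    X = np.zeros((vocab_size, 1, T_x))
    Y = np.zeros((vocab_size, 1, T_x))
    for t in range(T_x):
        X[char_to_idx[text[t]], 0, t] = 1.0      # current character
        Y[char_to_idx[text[t + 1]], 0, t] = 1.0  # character to predict
    return X, Y
```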
```python
if __name__ == "__main__":
    with open("shakespeare.txt", "r") as file:
        text = file.read()

    # Build the character vocabulary and index lookups
    chars = sorted(set(text))
    vocab_size = len(chars)
    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}

    lstm = LSTM(64, vocab_size, vocab_size)
    losses = lstm.train(text, char_to_idx, num_iterations=100, learning_rate=0.01, clip_value=5)

    generated_text = lstm.sample("T", char_to_idx, idx_to_char, num_chars=100)
    print(generated_text)
```
Conclusion
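Standard RNNs run into vanishing gradients when processing long sequences, and LSTM was developed in response to this problem. It has been applied to speech recognition and machine translation with considerable success.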
References
- Andrew Ng, Deep Learning Specialization, Coursera.