Gated Recurrent Unit (GRU)

Photo by Armand Khoury on Unsplash
The gated recurrent unit (GRU) is a type of RNN designed for sequential data. Like the long short-term memory (LSTM) network, it was designed to address the long-term dependency problem of standard RNNs.

The complete code is available for download.

GRU

A standard RNN suffers from the vanishing gradient problem, so when a sequence is long, it cannot effectively learn from the early inputs. In other words, the long-term memory of a standard RNN is rather weak. For more details on vanishing gradients, please refer to the following article. In addition, if you are not yet familiar with RNNs, please read the following article first.

Compared with the LSTM, the GRU has a simpler structure and is more computationally efficient. For details on the LSTM, please refer to the following article.

The figure below shows a GRU cell. It is more complex than a standard RNN cell, but simpler than an LSTM cell. In addition to the hidden state a, the GRU computes a reset gate \Gamma_r, an update gate \Gamma_u, and a candidate hidden state \tilde{c}.

GRU.

Their respective roles are as follows.

  • Reset gate \Gamma_r: decides how much of the previous information to forget.
  • Update gate \Gamma_u: decides how much of the previous information to keep (a small worked example follows this list).
  • Candidate hidden state \tilde{c}: a new candidate state produced from the current input and the gated portion of the previous state.
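
As a quick illustration of how the update gate blends old and new information, consider a single unit with made-up values \Gamma_u=0.1, \tilde{c}^{<t>}=0.8, and a^{<t-1>}=-0.5:

a^{<t>}=\Gamma_u\odot\tilde{c}^{<t>}+(1-\Gamma_u)\odot a^{<t-1>}=0.1\times0.8+0.9\times(-0.5)=-0.37

Because \Gamma_u is close to 0, the cell mostly carries the previous hidden state forward; this is how a GRU can preserve information across many time steps.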

Forward Propagation

The figure below shows the GRU forward propagation. Each gate and the candidate hidden state has its own parameters W, b and activation function. Note that we stack a^{<t-1>} and x^{<t>} vertically.

GRU Cell Forward.

The equations inside a GRU cell are as follows:

\gamma_r^{<t>}=W_r[a^{<t-1>},x^{<t>}]+b_r \\\\ \Gamma_r=\sigma(\gamma_r^{<t>}) \\\\ \gamma_u^{<t>}=W_u[a^{<t-1>},x^{<t>}]+b_u \\\\ \Gamma_u=\sigma(\gamma_u^{<t>}) \\\\ p\tilde{c}^{<t>}=W_c[(\Gamma_r\odot a^{<t-1>}),x^{<t>}]+b_c \\\\ \tilde{c}^{<t>}=tanh(p\tilde{c}^{<t>}) \\\\ a^{<t>}=\Gamma_u\odot\tilde{c}^{<t>}+(1-\Gamma_u)\odot a^{<t-1>} \\\\ z_y^{<t>}=W_ya^{<t>}+b_y \\\\ \hat{y}^{<t>}=softmax(z_y^{<t>})

The dimensions of the GRU's input X and the true labels Y are as follows:

X:(n_x,m,T_x)-\text{the inputs.} \\\\ Y:(n_y,m,T_y)-\text{the true labels.} \\\\ m:\text{the number of examples.} \\\\ n_x:\text{the number of units in }x^{(i)<t>}. \\\\ n_y:\text{the number of units in }y^{(i)<t>}. \\\\ n_a:\text{the number of units in hidden state.} \\\\ x^{(i)}:\text{the input of the }i\text{-th example.} \\\\ T_x:\text{the input sequence length.} \\\\ T_y:\text{the output sequence length.}

The dimensions of the variables inside a GRU cell are as follows (a parameter-initialization sketch based on these shapes follows the list):

a^{<t-1>}:(n_a,m),\quad x^{<t>}:(n_x,m)
[a^{<t-1>},x^{<t>}]:(n_a+n_x,m),\quad \hat{y}^{<t>}:(n_y,m)
W_r:(n_a,n_a+n_x),\quad b_r:(n_a,1)
W_c:(n_a,n_a+n_x),\quad b_c:(n_a,1)
W_u:(n_a,n_a+n_x),\quad b_u:(n_a,1)
W_y:(n_y,n_a),\quad b_y:(n_y,1)
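
Given these shapes, the parameters can be initialized as in the sketch below. The helper name init_parameters and the 0.01 scaling are my own choices for illustration, not part of the original code.

import numpy as np

def init_parameters(n_a, n_x, n_y, seed=0):
    """Hypothetical helper: small random weights and zero biases with the shapes listed above."""
    rng = np.random.default_rng(seed)
    return {
        "Wr": rng.standard_normal((n_a, n_a + n_x)) * 0.01, "br": np.zeros((n_a, 1)),
        "Wu": rng.standard_normal((n_a, n_a + n_x)) * 0.01, "bu": np.zeros((n_a, 1)),
        "Wc": rng.standard_normal((n_a, n_a + n_x)) * 0.01, "bc": np.zeros((n_a, 1)),
        "Wy": rng.standard_normal((n_y, n_a)) * 0.01,       "by": np.zeros((n_y, 1)),
    }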

The following is an implementation of the GRU forward propagation.
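
The implementation below relies on NumPy and on sigmoid, tanh, and softmax helpers that are not shown in this post. A minimal sketch of those helpers, assuming softmax is applied column-wise over the unit axis, could be:

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

def softmax(x):
    # Numerically stable softmax over axis 0 (the unit axis).
    e = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e / np.sum(e, axis=0, keepdims=True)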

class GRU:
    def cell_forward(self, xt, at_prev, parameters):
        """
        Implements a single forward step for the GRU-cell.

        Parameters
        ----------
        xt: (ndarray (n_x, m)) - input data for the current timestep
        at_prev: (ndarray (n_a, m)) - hidden state from the previous timestep
        parameters: (dict) - dictionary containing the weights and biases of the GRU network
            Wu: (ndarray (n_a, n_a + n_x)) - weights of the update gate
            bu: (ndarray (n_a, 1)) - biases of the update gate
            Wr: (ndarray (n_a, n_a + n_x)) - weights of the reset gate
            br: (ndarray (n_a, 1)) - biases of the reset gate
            Wc: (ndarray (n_a, n_a + n_x)) - weights of the candidate value
            bc: (ndarray (n_a, 1)) - biases of the candidate value
            Wy: (ndarray (n_y, n_a)) - weights of the output layer
            by: (ndarray (n_y, 1)) - biases of the output layer

        Returns
        -------
        at: (ndarray (n_a, m)) - hidden state for the current timestep
        y_hat_t: (ndarray (n_y, m)) - prediction for the current timestep
        cache: (tuple) - values needed for the backward pass
        """

        Wu, bu = parameters["Wu"], parameters["bu"]  # update gate weights and biases
        Wr, br = parameters["Wr"], parameters["br"]  # reset gate weights and biases
        Wc, bc = parameters["Wc"], parameters["bc"]  # candidate value weights and biases
        Wy, by = parameters["Wy"], parameters["by"]  # prediction weights and biases

        concat = np.concatenate((at_prev, xt), axis=0)

        ut = sigmoid(Wu @ concat + bu)  # update gate
        rt = sigmoid(Wr @ concat + br)  # reset gate
        cct = tanh(Wc @ np.concatenate((rt * at_prev, xt), axis=0) + bc)  # candidate value
        at = ut * cct + (1 - ut) * at_prev  # hidden state

        zyt = Wy @ at + by
        y_hat_t = softmax(zyt)
        cache = (at, at_prev, ut, rt, cct, xt, y_hat_t, zyt)
        return at, y_hat_t, cache

    def forward(self, X, a0, parameters):
        """
        Implements the forward pass of the GRU network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data for each time step
        a0: (ndarray (n_a, m)) - initial hidden state
        parameters: (dict) - dictionary containing the weights and biases of the GRU network
            Wu: (ndarray (n_a, n_a + n_x)) - weights of the update gate
            bu: (ndarray (n_a, 1)) - biases of the update gate
            Wr: (ndarray (n_a, n_a + n_x)) - weights of the reset gate
            br: (ndarray (n_a, 1)) - biases of the reset gate
            Wc: (ndarray (n_a, n_a + n_x)) - weights of the candidate value
            bc: (ndarray (n_a, 1)) - biases of the candidate value
            Wy: (ndarray (n_y, n_a)) - weights of the output layer
            by: (ndarray (n_y, 1)) - biases of the output layer

        Returns
        -------
        A: (ndarray (n_a, m, T_x)) - hidden states for each timestep
        Y_hat: (ndarray (n_y, m, T_x)) - predictions for each timestep
        caches: (list) - values needed for the backward pass
        """

        caches = []

        Wy = parameters["Wy"]
        n_x, m, T_x = X.shape
        n_y, n_a = Wy.shape

        A = np.zeros((n_a, m, T_x))
        Y_hat = np.zeros((n_y, m, T_x))

        at_prev = a0

        for t in range(T_x):
            at_prev, y_hat_t, cache = self.cell_forward(X[:, :, t], at_prev, parameters)
            A[:, :, t] = at_prev
            Y_hat[:, :, t] = y_hat_t
            caches.append(cache)

        return A, Y_hat, caches
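
As a quick smoke test of the forward pass, the sketch below feeds random one-hot inputs through the network. The sizes are arbitrary and init_parameters is the hypothetical helper sketched earlier; instantiating GRU without constructor arguments works for the methods listed in this post, while the full class in the downloadable code takes layer sizes, as in the final example.

n_x, n_a, n_y, m, T_x = 27, 16, 27, 4, 10
parameters = init_parameters(n_a, n_x, n_y)

rng = np.random.default_rng(1)
X = np.zeros((n_x, m, T_x))
for i in range(m):
    for t in range(T_x):
        X[rng.integers(0, n_x), i, t] = 1.0  # one random one-hot character per step
a0 = np.zeros((n_a, m))

gru = GRU()
A, Y_hat, caches = gru.forward(X, a0, parameters)
print(A.shape, Y_hat.shape)  # (16, 4, 10) (27, 4, 10)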

Loss Function

In this article, we use softmax to produce the output \hat{y}, so we use the cross-entropy loss as the loss function. For the formula and implementation of the cross-entropy loss, please refer to the following article.
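
The optimize method shown later calls self.compute_loss, which is not listed in this post. A minimal sketch consistent with the softmax output, summing the cross-entropy over units, examples, and time steps, could be:

class GRU:
    def compute_loss(self, Y_hat, Y):
        """Cross-entropy between predictions Y_hat and one-hot labels Y, summed over all timesteps."""
        # A small epsilon guards against log(0).
        return -np.sum(Y * np.log(Y_hat + 1e-12))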

Backward Propagation

The backpropagation of a GRU is somewhat involved. We must compute the partial derivative with respect to every parameter. In particular, when computing \frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}, we also have to account for the gradient flowing back from the following time step; this is what backpropagation through time (BPTT) refers to.

GRU Cell Backward.

We first compute the partial derivatives in the output layer.

\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}=\hat{y}^{<t>}-y^{<t>} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial W_y}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}\frac{\partial z_y^{<t>}}{\partial W_y}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}a^{<t>T} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial b_y}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}\frac{\partial z_y^{<t>}}{\partial b_y}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}} \\\\ \frac{\partial \mathcal{L}}{\partial a^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial a^{<t>}}+\frac{\partial \mathcal{L}^{<t+1>}}{\partial a^{<t>}}=W_y^T\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}+\frac{\partial \mathcal{L}^{<t+1>}}{\partial a^{<t>}}

Next, we compute the partial derivatives for the reset gate, the update gate, and the candidate hidden state.

\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial \tilde{c}^{<t>}}\frac{\partial \tilde{c}^{<t>}}{\partial p\tilde{c}^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot\Gamma_u^{<t>}\cdot(1-(\tilde{c}^{<t>})^2) \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_u^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial \Gamma_u^{<t>}}\frac{\partial \Gamma_u^{<t>}}{\partial \gamma_u^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot(\tilde{c}^{<t>}-a^{<t-1>})\cdot\Gamma_u^{<t>}\cdot(1-\Gamma_u^{<t>}) \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_r^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial \tilde{c}^{<t>}}\frac{\partial \tilde{c}^{<t>}}{\partial p\tilde{c}^{<t>}}\frac{\partial p\tilde{c}^{<t>}}{\partial \Gamma_r^{<t>}}\frac{\partial \Gamma_r^{<t>}}{\partial \gamma_r^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial\gamma_r^{<t>}}}=(W_c^T\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}})\cdot a^{<t-1>}\cdot\Gamma_r^{<t>}\cdot(1-\Gamma_r^{<t>})

Next, we compute the partial derivatives with respect to all the parameters W and b.

\frac{\partial\mathcal{L}^{<t>}}{\partial W_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}\begin{bmatrix} \Gamma_r\cdot a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T,\frac{\partial\mathcal{L}^{<t>}}{\partial b_c}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial W_r}=\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_r^{<t>}}\begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T,\frac{\partial\mathcal{L}^{<t>}}{\partial b_r}=\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_r^{<t>}} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial W_u}=\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_u^{<t>}}\begin{bmatrix} a^{<t-1>} \\ x^{<t>} \end{bmatrix}^T,\frac{\partial\mathcal{L}^{<t>}}{\partial b_u}=\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_u^{<t>}}

Finally, we compute the remaining partial derivatives.

\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t-1>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\frac{\partial a^{<t>}}{\partial a^{<t-1>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}\frac{\partial p\tilde{c}^{<t>}}{\partial a^{<t-1>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t-1>}}}+\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_r^{<t>}}\frac{\partial \gamma_r^{<t>}}{\partial a^{<t-1>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_u^{<t>}}\frac{\partial \gamma_u^{<t>}}{\partial a^{<t-1>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t-1>}}}=\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}\cdot(1-\Gamma_u^{<t>})+W_c^T\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}\Gamma_r^{<t>} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t-1>}}}+W_r^T\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_r^{<t>}}+W_u^T\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_u^{<t>}} \\\\ \frac{\partial\mathcal{L}^{<t>}}{\partial x^{<t>}}=\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}\frac{\partial p\tilde{c}^{<t>}}{\partial x^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_r^{<t>}}\frac{\partial \gamma_r^{<t>}}{\partial x^{<t>}}+\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_u^{<t>}}\frac{\partial \gamma_u^{<t>}}{\partial x^{<t>}} \\\\ \hphantom{\frac{\partial\mathcal{L}^{<t>}}{\partial x^{<t>}}}=W_c^T\frac{\partial\mathcal{L}^{<t>}}{\partial p\tilde{c}^{<t>}}+W_r^T\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_r^{<t>}}+W_u^T\frac{\partial\mathcal{L}^{<t>}}{\partial \gamma_u^{<t>}}

The above shows how to compute all the partial derivatives at each time step. At the end, we sum these partial derivatives over all time steps.

\displaystyle \frac{\partial\mathcal{L}}{\partial W_y}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_y},\frac{\partial\mathcal{L}}{\partial b_y}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial b_y} \\\\ \frac{\partial\mathcal{L}}{\partial W_r}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_r},\frac{\partial\mathcal{L}}{\partial b_r}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_r} \\\\ \frac{\partial\mathcal{L}}{\partial W_u}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_u},\frac{\partial\mathcal{L}}{\partial b_u}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_u} \\\\ \frac{\partial\mathcal{L}}{\partial W_c}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial W_c},\frac{\partial\mathcal{L}}{\partial b_c}=\sum_{t=1}^{T_y}\frac{\partial \mathcal{L}^{<t>}}{\partial b_c} \\\\ \frac{\partial\mathcal{L}}{\partial a}=\sum_{t=1}^{T_y}\frac{\partial\mathcal{L}^{<t>}}{\partial a^{<t>}}

The following is an implementation of the GRU backward propagation.

class GRU:
    def cell_backward(self, y, dat, cache, parameters):
        """
        Implements a single backward step for the GRU-cell.

        Parameters
        ----------
        y: (ndarray (n_y, m)) - true labels for the current timestep
        dat: (ndarray (n_a, m)) - gradient of the hidden state for the current timestep
        cache: (tuple) - values needed for the backward pass
        parameters: (dict) - dictionary containing the weights and biases of the GRU network
            Wu: (ndarray (n_a, n_a + n_x)) - weights of the update gate
            bu: (ndarray (n_a, 1)) - biases of the update gate
            Wr: (ndarray (n_a, n_a + n_x)) - weights of the reset gate
            br: (ndarray (n_a, 1)) - biases of the reset gate
            Wc: (ndarray (n_a, n_a + n_x)) - weights of the candidate value
            bc: (ndarray (n_a, 1)) - biases of the candidate value
            Wy: (ndarray (n_y, n_a)) - weights of the output layer
            by: (ndarray (n_y, 1)) - biases of the output layer

        Returns
        -------
        gradients: (dict) - dictionary containing the gradients of the weights and biases of the GRU network
            dWu: (ndarray (n_a, n_a + n_x)) - gradients of the weights of the update gate
            dbu: (ndarray (n_a, 1)) - gradients of the biases of the update gate
            dWr: (ndarray (n_a, n_a + n_x)) - gradients of the weights of the reset gate
            dbr: (ndarray (n_a, 1)) - gradients of the biases of the reset gate
            dWc: (ndarray (n_a, n_a + n_x)) - gradients of the weights of the candidate value
            dbc: (ndarray (n_a, 1)) - gradients of the biases of the candidate value
            dWy: (ndarray (n_y, n_a)) - gradients of the weights of the output layer
            dby: (ndarray (n_y, 1)) - gradients of the biases of the output
        """

        at, at_prev, ut, rt, cct, xt, y_hat_t, zyt = cache
        n_a, m = at.shape

        dzy = y_hat_t - y
        dWy = dzy @ at.T
        dby = np.sum(dzy, axis=1, keepdims=True)

        dat = parameters["Wy"].T @ dzy + dat

        dcct = dat * ut * (1 - cct ** 2)  # gradient of the candidate pre-activation
        dut = dat * (cct - at_prev) * ut * (1 - ut)
        dat_prev = dat * (1 - ut)

        dcct_ra_x = parameters["Wc"].T @ dcct
        dcct_r_at_prev = dcct_ra_x[:n_a, :]
        dcct_xt = dcct_ra_x[n_a:, :]
        drt = (dcct_r_at_prev * at_prev) * rt * (1 - rt)

        concat = np.concatenate((at_prev, xt), axis=0)

        dWc = dcct @ np.concatenate((rt * at_prev, xt), axis=0).T
        dbc = np.sum(dcct, axis=1, keepdims=True)
        dWr = drt @ concat.T
        dbr = np.sum(drt, axis=1, keepdims=True)
        dWu = dut @ concat.T
        dbu = np.sum(dut, axis=1, keepdims=True)

        dat_prev = (
            dat_prev + dcct_r_at_prev * rt + parameters["Wr"][:, :n_a].T @ drt + parameters["Wu"][:, :n_a].T @ dut
        )
        dxt = (
            dcct_xt + parameters["Wr"][:, n_a:].T @ drt + parameters["Wu"][:, n_a:].T @ dut
        )

        gradients = {
            "dWu": dWu, "dbu": dbu, "dWr": dWr, "dbr": dbr, "dWc": dWc, "dbc": dbc, "dWy": dWy, "dby": dby,
            "dat_prev": dat_prev, "dxt": dxt
        }
        return gradients

    def backward(self, X, Y, parameters, caches):
        """
        Implements the backward pass of the GRU network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data for each time step
        Y: (ndarray (n_y, m, T_x)) - true labels for each time step
        parameters: (dict) - dictionary containing the weights and biases of the GRU network
            Wu: (ndarray (n_a, n_a + n_x)) - weights of the update gate
            bu: (ndarray (n_a, 1)) - biases of the update gate
            Wr: (ndarray (n_a, n_a + n_x)) - weights of the reset gate
            br: (ndarray (n_a, 1)) - biases of the reset gate
            Wc: (ndarray (n_a, n_a + n_x)) - weights of the candidate value
            bc: (ndarray (n_a, 1)) - biases of the candidate value
            Wy: (ndarray (n_y, n_a)) - weights of the output layer
            by: (ndarray (n_y, 1)) - biases of the output layer
        caches: (list) - values needed for the backward pass

        Returns
        -------
        gradients: (dict) - dictionary containing the gradients of the weights and biases of the GRU network
            dWu: (ndarray (n_a, n_a + n_x)) - gradients of the weights of the update gate
            dbu: (ndarray (n_a, 1)) - gradients of the biases of the update gate
            dWr: (ndarray (n_a, n_a + n_x)) - gradients of the weights of the reset gate
            dbr: (ndarray (n_a, 1)) - gradients of the biases of the reset gate
            dWc: (ndarray (n_a, n_a + n_x)) - gradients of the weights of the candidate value
            dbc: (ndarray (n_a, 1)) - gradients of the biases of the candidate value
            dWy: (ndarray (n_y, n_a)) - gradients of the weights of the output layer
            dby: (ndarray (n_y, 1)) - gradients of the biases of the output layer
        """

        n_x, m, T_x = X.shape
        a1, a0, u0, r1, cc1, x1, y_hat_1, zyt1 = caches[0]
        Wu, Wr, Wc, Wy = parameters["Wu"], parameters["Wr"], parameters["Wc"], parameters["Wy"]
        bu, br, bc, by = parameters["bu"], parameters["br"], parameters["bc"], parameters["by"]

        gradients = {
            "dWu": np.zeros_like(Wu), "dbu": np.zeros_like(bu), "dWr": np.zeros_like(Wr), "dbr": np.zeros_like(br),
            "dWc": np.zeros_like(Wc), "dbc": np.zeros_like(bc), "dWy": np.zeros_like(Wy), "dby": np.zeros_like(by),
        }

        dat = np.zeros_like(a0)
        for t in reversed(range(T_x)):
            grads = self.cell_backward(Y[:, :, t], dat, caches[t], parameters)
            gradients["dWu"] += grads["dWu"]
            gradients["dbu"] += grads["dbu"]
            gradients["dWr"] += grads["dWr"]
            gradients["dbr"] += grads["dbr"]
            gradients["dWc"] += grads["dWc"]
            gradients["dbc"] += grads["dbc"]
            gradients["dWy"] += grads["dWy"]
            gradients["dby"] += grads["dby"]
            dat = grads["dat_prev"]

        return gradients
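
Continuing the earlier smoke test, the backward pass can be exercised as follows; Y here is just a random one-hot target tensor, used purely for illustration:

Y = np.zeros((n_y, m, T_x))
for i in range(m):
    for t in range(T_x):
        Y[rng.integers(0, n_y), i, t] = 1.0

gradients = gru.backward(X, Y, parameters, caches)
print(gradients["dWu"].shape)  # (16, 43), i.e. (n_a, n_a + n_x)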

Putting It All Together

The code below implements one complete training step. First, we pass the training data through forward propagation, compute the loss, and then run backward propagation to obtain the gradients. To prevent exploding gradients, we clip the gradients and then use them to update the parameters. This constitutes one full training step.

class GRU:
    def optimize(self, X, Y, a_prev, parameters, learning_rate, clip_value):
        """
        Implements the forward and backward pass of the GRU network.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data for each time step
        Y: (ndarray (n_y, m, T_x)) - true labels for each time step
        a_prev: (ndarray (n_a, m)) - initial hidden state
        parameters: (dict) - dictionary containing the weights and biases of the GRU network
            Wu: (ndarray (n_a, n_a + n_x)) - weights of the update gate
            bu: (ndarray (n_a, 1)) - biases of the update gate
            Wr: (ndarray (n_a, n_a + n_x)) - weights of the reset gate
            br: (ndarray (n_a, 1)) - biases of the reset gate
            Wc: (ndarray (n_a, n_a + n_x)) - weights of the candidate value
            bc: (ndarray (n_a, 1)) - biases of the candidate value
            Wy: (ndarray (n_y, n_a)) - weights of the output layer
            by: (ndarray (n_y, 1)) - biases of the output layer
        learning_rate: (float) - learning rate
        clip_value: (float) - maximum value to clip the gradients

        Returns
        -------
        at: (ndarray (n_a, m)) hidden state for the last time step
        loss: (float) - the cross-entropy
        """

        A, Y_hat, caches = self.forward(X, a_prev, parameters)
        loss = self.compute_loss(Y_hat, Y)
        gradients = self.backward(X, Y, parameters, caches)
        gradients = self.clip(gradients, clip_value)
        self.update_parameters(parameters, gradients, learning_rate)

        at = A[:, :, -1]
        return at, loss
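
optimize also calls self.clip and self.update_parameters, which are not listed in this post. Minimal sketches consistent with how they are used above, element-wise clipping into [-clip_value, clip_value] followed by plain gradient descent, could be:

class GRU:
    def clip(self, gradients, clip_value):
        """Clip every gradient element-wise into [-clip_value, clip_value] to limit exploding gradients."""
        for key in ("dWu", "dbu", "dWr", "dbr", "dWc", "dbc", "dWy", "dby"):
            np.clip(gradients[key], -clip_value, clip_value, out=gradients[key])
        return gradients

    def update_parameters(self, parameters, gradients, learning_rate):
        """Apply one plain gradient-descent step to every weight and bias, in place."""
        for key in ("Wu", "bu", "Wr", "br", "Wc", "bc", "Wy", "by"):
            parameters[key] -= learning_rate * gradients["d" + key]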

Example

Next, we use the GRU as a character-level language model. The training data is a passage from Shakespeare. The model is trained one character at a time, so the sequence length T_x is the length of the input character sequence, and each character is one-hot encoded, as illustrated in the sketch below. For the details of this part, please refer to the following article, since this post uses the same example.
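
As an illustration of the one-hot encoding step (the helper name one_hot_encode is my own; the downloadable code may organize this differently), a single character sequence can be turned into an input tensor of shape (n_x, 1, T_x) like this:

def one_hot_encode(sequence, char_to_idx):
    """Encode one character sequence as a (vocab_size, 1, T_x) one-hot tensor."""
    vocab_size, T_x = len(char_to_idx), len(sequence)
    X = np.zeros((vocab_size, 1, T_x))
    for t, ch in enumerate(sequence):
        X[char_to_idx[ch], 0, t] = 1.0
    return X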

An example of using this GRU is shown below:

if __name__ == "__main__":
    with open("shakespeare.txt", "r") as file:
        text = file.read()

    chars = sorted(list(set(text)))
    vocab_size = len(chars)

    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}

    gru = GRU(64, vocab_size, vocab_size)
    losses = gru.train(text, char_to_idx, num_iterations=100, learning_rate=0.01, clip_value=5)

    generated_text = gru.sample("T", char_to_idx, idx_to_char, num_chars=100)
    print(generated_text)

Conclusion

A GRU learns long-term dependencies better than a standard RNN, while having a simpler structure and lower computational cost than an LSTM. However, because of its simpler structure, an LSTM can perform better when learning long-term dependencies. We therefore have to choose between an LSTM and a GRU based on the use case and the length of the data.
