Bi-directional Recurrent Neural Networks (BRNNs)

Photo by Daniele Buso on Unsplash
Bi-directional recurrent neural networks (BRNNs) are an extension of standard RNNs designed to process sequential data in both forward and backward directions. Compared to traditional RNNs, BRNN architectures maintain more comprehensive context information, enabling them to capture dependencies across entire sequences for improved predictions in natural language processing and speech recognition tasks.

The complete code for this chapter can be found in .

BRNN

A traditional RNN only looks at the past, meaning it processes information in one direction. This one-directional flow can limit the model's ability to capture important context from the later part of a sequence; in named-entity recognition, for example, the words that follow a name often help disambiguate it. Bi-directional recurrent neural networks (BRNNs) solve this issue by processing the sequence in both forward and backward directions. In doing so, a BRNN can leverage context from the entire input sequence to make more accurate predictions.

The following figure shows the architecture of a BRNN, which consists of two independent RNNs:

  • Forward RNN: Processes data sequentially from t=1 to t=T.
  • Backward RNN: Processes data sequentially from t=T to t=1.

At each time step, the cell's hidden state is the combination of the forward and backward hidden states. This combined hidden state is then passed through the output layer to produce the prediction.

BRNN architecture.

As can be seen, the computation time and memory requirements of a BRNN are roughly twice those of a unidirectional RNN. Next, we introduce a bidirectional vanilla RNN; as the name suggests, this BRNN is composed of two vanilla RNNs. If you are not familiar with RNNs or the vanilla RNN, please refer to the following article first.
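
Before walking through the full implementation, the per-time-step combination can be sketched in a few lines of NumPy. The sizes and variable names below are purely illustrative and are not the ones used in the implementation later in this chapter.

import numpy as np

# Illustrative sizes: n_a hidden units per direction, n_y output units, m examples.
n_a, n_y, m = 4, 3, 2
rng = np.random.default_rng(0)

af_t = rng.standard_normal((n_a, m))  # forward hidden state at time step t
ab_t = rng.standard_normal((n_a, m))  # backward hidden state at time step t

# The cell's hidden state stacks both directions, so the output layer sees 2 * n_a units.
a_t = np.concatenate((af_t, ab_t), axis=0)                      # shape (2 * n_a, m)
Wy = rng.standard_normal((n_y, 2 * n_a))
by = np.zeros((n_y, 1))
z_t = Wy @ a_t + by                                             # shape (n_y, m)
y_hat_t = np.exp(z_t) / np.exp(z_t).sum(axis=0, keepdims=True)  # softmax over classes
print(a_t.shape, y_hat_t.shape)                                 # (8, 2) (3, 2)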

Forward Propagation

The following figure shows BRNN forward propagation. This BRNN contains two vanilla RNNs: one processes the sequence in the forward direction and the other in the backward direction. The hidden states of these two RNNs are stacked vertically to form the hidden state of the cell.

BRNN Cell Forward.

The formulas in a BRNN cell are as follows:

z_f^{<t>}=W_{fx}x^{<t>}+W_{fa}a_f^{<t-1>}+b_{fa} \\\\ a_f^{<t>}=tanh(z_f^{<t>}) \\\\ z_b^{<t>}=W_{bx}x^{<t>}+W_{ba}a_b^{<t+1>}+b_{ba} \\\\ a_b^{<t>}=tanh(z_b^{<t>}) \\\\ a^{<t>}=\begin{bmatrix} a_f^{<t>} \\ a_b^{<t>} \end{bmatrix} \\\\ z_y^{<t>}=W_ya^{<t>}+b_y \\\\ \hat{y}^{<t>}=softmax(z_y^{<t>})

The dimensions of the BRNN input X and true labels Y are as follows:

X:(n_x,m,T_x)-\text{the inputs.} \\\\ Y:(n_y,m,T_y)-\text{the true labels.} \\\\ m:\text{the number of examples.} \\\\ n_x:\text{the number of units in }x^{(i)<t>}. \\\\ n_y:\text{the number of units in }y^{(i)<t>}. \\\\ n_a:\text{the number of units in the hidden state.} \\\\ x^{(i)}:\text{the input of the }i\text{-th example.} \\\\ T_x:\text{the input sequence length.} \\\\ T_y:\text{the output sequence length.}

In a BRNN cell, the dimensions of each variable are as follows:

W_{fx}: (n_a,n_x), \; x^{<t>}: (n_x,m), \; W_{fa}: (n_a,n_a), \; a_f^{<t-1>}: (n_a,m), \; b_{fa}: (n_a,1) \\\\ W_{bx}: (n_a,n_x), \; x^{<t>}: (n_x,m), \; W_{ba}: (n_a,n_a), \; a_b^{<t+1>}: (n_a,m), \; b_{ba}: (n_a,1) \\\\ W_y: (n_y,n_a*2), \; a^{<t>}: (n_a*2,m), \; b_y: (n_y,1), \; \hat{y}^{<t>}: (n_y,m)
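
The constructor of the BRNN class is not listed in this chapter. As a rough sketch only, and assuming the attribute names used later (parameters, af0, abl) as well as the call BRNN(64, vocab_size, 2) from the example at the end of this chapter, the parameters could be initialized with the shapes above as follows.

import numpy as np

class BRNN:
    def __init__(self, n_a, n_x, n_y):
        """
        A minimal sketch of the constructor; the complete code may differ.

        n_a: the number of units in each hidden state
        n_x: the number of units in the input x
        n_y: the number of units in the output y
        """
        rng = np.random.default_rng(0)
        self.parameters = {
            "Wfx": rng.standard_normal((n_a, n_x)) * 0.01,
            "Wfa": rng.standard_normal((n_a, n_a)) * 0.01,
            "Wbx": rng.standard_normal((n_a, n_x)) * 0.01,
            "Wba": rng.standard_normal((n_a, n_a)) * 0.01,
            "Wy": rng.standard_normal((n_y, n_a * 2)) * 0.01,
            "bfa": np.zeros((n_a, 1)),
            "bba": np.zeros((n_a, 1)),
            "by": np.zeros((n_y, 1)),
        }
        # Initial hidden states for the forward (before t=1) and backward (after t=T_x)
        # directions; m=1 because the palindrome example below feeds one word at a time.
        self.af0 = np.zeros((n_a, 1))
        self.abl = np.zeros((n_a, 1))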

The following is the implementation of forward propagation of BRNN.

import numpy as np


def softmax(z):
    # Numerically stable softmax over the class dimension (axis 0).
    e = np.exp(z - np.max(z, axis=0, keepdims=True))
    return e / np.sum(e, axis=0, keepdims=True)


class BRNN:
    def cell_forward_forward(self, xt, aft_prev, parameters):
        """
        Implement a single forward step for the BRNN-cell forward direction.

        Parameters
        ----------
        xt: (ndarray (n_x, m)) - input at timestep "t"
        aft_prev: (ndarray (n_a, m)) - hidden state at timestep "t-1" in the forward direction
        parameters: (dict) - the parameters
            "Wfx": (ndarray (n_a, n_x)) - weights for the forward input
            "Wfa": (ndarray (n_a, n_a)) - weights for the forward hidden state
            "bfa": (ndarray (n_a, 1)) - bias for the forward hidden state

        Returns
        -------
        aft: (ndarray (n_a, m)) - hidden state at timestep "t" in the forward direction
        zfxt: (ndarray (n_a, m)) - logit at timestep "t" in the forward direction
        """

        Wfx, Wfa, bfa = parameters["Wfx"], parameters["Wfa"], parameters["bfa"]
        zfxt = Wfx @ xt + Wfa @ aft_prev + bfa
        aft = np.tanh(zfxt)
        return aft, zfxt

    def cell_forward_backward(self, xt, abt_next, parameters):
        """
        Implement a single forward step for the BRNN-cell backward direction.

        Parameters
        ----------
        xt: (ndarray (n_x, m)) - input at timestep "t"
        abt_next: (ndarray (n_a, m)) - hidden state at timestep "t+1" in the backward direction
        parameters: (dict) - the parameters
            "Wbx": (ndarray (n_a, n_x)) - weights for the backward input
            "Wba": (ndarray (n_a, n_a)) - weights for the backward hidden state
            "bba": (ndarray (n_a, 1)) - bias for the backward hidden state

        Returns
        -------
        abt: (ndarray (n_a, m)) - hidden state at timestep "t" in the backward direction
        zbxt: (ndarray (n_a, m)) - logit at timestep "t" in the backward direction
        """

        Wbx, Wba, bba = parameters["Wbx"], parameters["Wba"], parameters["bba"]
        zbxt = Wbx @ xt + Wba @ abt_next + bba
        abt = np.tanh(zbxt)
        return abt, zbxt

    def forward(self, X, af0, abl, parameters):
        """
        Implement the forward propagation for the BRNN.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data
        af0: (ndarray (n_a, m)) - initial hidden state for the forward direction
        abl: (ndarray (n_a, m)) - initial hidden state for the backward direction
        parameters: (dict) - the parameters
            "Wfx": (ndarray (n_a, n_x)) - weights for the forward input
            "Wfa": (ndarray (n_a, n_a)) - weights for the forward hidden state
            "Wbx": (ndarray (n_a, n_x)) - weights for the backward input
            "Wba": (ndarray (n_a, n_a)) - weights for the backward hidden state
            "Wy": (ndarray (n_y, n_a * 2)) - weights for the output
            "bfa": (ndarray (n_a, 1)) - bias for the forward hidden state
            "bba": (ndarray (n_a, 1)) - bias for the backward hidden state
            "by": (ndarray (n_y, 1)) - bias for the output

        Returns
        -------
        AF: (ndarray (n_a, m, T_x)) - hidden states for the forward direction
        AB: (ndarray (n_a, m, T_x)) - hidden states for the backward direction
        Y_hat: (ndarray (n_y, m, T_x)) - predictions for each timestep
        caches: (list) - caches for each timestep
        """

        n_x, m, T_x = X.shape
        n_y, n_a_2 = parameters["Wy"].shape
        n_a = n_a_2 // 2

        fcaches = []
        AF = np.zeros((n_a, m, T_x))
        aft_prev = af0
        for t in range(T_x):
            xt = X[:, :, t]
            aft, zfxt = self.cell_forward_forward(xt, aft_prev, parameters)
            AF[:, :, t] = aft
            fcaches.append((aft, aft_prev, xt, zfxt))
            aft_prev = aft

        bcaches = []
        AB = np.zeros((n_a, m, T_x))
        abt_next = abl
        for t in reversed(range(T_x)):
            xt = X[:, :, t]
            abt, zbxt = self.cell_forward_backward(xt, abt_next, parameters)
            AB[:, :, t] = abt
            bcaches.insert(0, (abt, abt_next, xt, zbxt))
            abt_next = abt

        caches = []
        Wy, by = parameters["Wy"], parameters["by"]
        Y_hat = np.zeros((n_y, m, T_x))
        for t in range(T_x):
            aft, aft_prev, xt, zfxt = fcaches[t]
            abt, abt_next, xt, zbxt = bcaches[t]
            at = np.concatenate((aft, abt), axis=0)
            zyt = Wy @ at + by
            y_hat_t = softmax(zyt)
            Y_hat[:, :, t] = y_hat_t
            caches.append((fcaches[t], bcaches[t], (at, zyt, y_hat_t)))

        return AF, AB, Y_hat, caches
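
As a quick sanity check, and assuming the snippets in this chapter are combined into a single BRNN class (as in the complete code), the forward pass can be run on random data to verify the output shapes.

n_x, n_a, n_y, m, T_x = 4, 5, 2, 3, 6  # illustrative sizes
brnn = BRNN(n_a, n_x, n_y)
X = np.random.randn(n_x, m, T_x)
af0 = np.zeros((n_a, m))
abl = np.zeros((n_a, m))
AF, AB, Y_hat, caches = brnn.forward(X, af0, abl, brnn.parameters)
print(AF.shape, AB.shape, Y_hat.shape)       # (5, 3, 6) (5, 3, 6) (2, 3, 6)
print(np.allclose(Y_hat.sum(axis=0), 1.0))   # True - each softmax output sums to 1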

Loss Function

Since our BRNN, like the vanilla RNN, uses a softmax output layer to produce \hat{y}, we use cross-entropy as its loss function. For the formula and implementation of cross-entropy loss, please refer to the following article.
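
The compute_loss method is used later in the optimize step. As a reminder, here is a minimal sketch of the cross-entropy loss summed over classes and time steps (the referenced article has the full formula and derivation). It is deliberately not averaged over examples, which matches the backward pass below, where the output-layer gradient is simply \hat{y}^{<t>}-y^{<t>}.

class BRNN:
    def compute_loss(self, Y_hat, Y):
        """
        Cross-entropy loss summed over classes and time steps.

        Y_hat: (ndarray (n_y, m, T_x)) - predictions for each timestep
        Y: (ndarray (n_y, m, T_x)) - true labels for each timestep
        """
        # Small constant to avoid log(0).
        return -np.sum(Y * np.log(Y_hat + 1e-12))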

Backward Propagation

The backpropagation of a BRNN is a bit more complicated. We need to split \frac{\partial \mathcal{L}^{<t>}}{\partial a^{<t>}}, calculated in the output layer, into two parts: one for the forward direction RNN and the other for the backward direction RNN. The backpropagation through the forward direction RNN runs from the last time step back to the first, whereas the backpropagation through the backward direction RNN runs from the first time step to the last.

BRNN Cell Backward.

The following calculates the partial derivatives in the output layer.

\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}=\hat{y}^{<t>}-y^{<t>} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial W_y}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}\frac{\partial z_y^{<t>}}{\partial W_y}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}a^{<t>T} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial b_y}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}\frac{\partial z_y^{<t>}}{\partial b_y}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial a^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}}\frac{\partial z_y^{<t>}}{\partial a^{<t>}}=W_y^T\frac{\partial \mathcal{L}^{<t>}}{\partial z_y^{<t>}} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial a^{<t>}}=\begin{bmatrix} \frac{\partial \mathcal{L}^{<t>}}{\partial a_f^{<t>}} \\ \frac{\partial \mathcal{L}^{<t>}}{\partial a_b^{<t>}} \end{bmatrix}

The following calculates the partial derivatives of all parameters in the forward direction RNN.

\frac{\partial \mathcal{L}^{<t>}}{\partial a_f^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial a_f^{<t>}}+\frac{\partial \mathcal{L}^{<t+1>}}{\partial a_f^{<t>}} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial z_f^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial a_f^{<t>}}\frac{\partial a_f^{<t>}}{\partial z_f^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial a_f^{<t>}}\cdot(1-(a_f^{<t>})^2) \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial W_{fx}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_f^{<t>}}\frac{\partial z_f^{<t>}}{\partial W_{fx}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_f^{<t>}}x^{<t>T} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial W_{fa}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_f^{<t>}}\frac{\partial z_f^{<t>}}{\partial W_{fa}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_f^{<t>}}a_f^{<t-1>T} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial b_{fa}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_f^{<t>}}\frac{\partial z_f^{<t>}}{\partial b_{fa}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_f^{<t>}} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial a_f^{<t-1>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_f^{<t>}}\frac{\partial z_f^{<t>}}{\partial a_f^{<t-1>}}=W_{fa}^T\frac{\partial \mathcal{L}^{<t>}}{\partial z_f^{<t>}}

The following calculates the partial derivatives of all parameters in the backward direction RNN.

\frac{\partial \mathcal{L}^{<t>}}{\partial a_b^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial a_b^{<t>}}+\frac{\partial \mathcal{L}^{<t-1>}}{\partial a_b^{<t>}} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial z_b^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial a_b^{<t>}}\frac{\partial a_b^{<t>}}{\partial z_b^{<t>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial a_b^{<t>}}\cdot(1-(a_b^{<t>})^2) \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial W_{bx}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_b^{<t>}}\frac{\partial z_b^{<t>}}{\partial W_{bx}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_b^{<t>}}x^{<t>T} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial W_{ba}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_b^{<t>}}\frac{\partial z_b^{<t>}}{\partial W_{ba}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_b^{<t>}}a_b^{<t+1>T} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial b_{ba}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_b^{<t>}}\frac{\partial z_b^{<t>}}{\partial b_{ba}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_b^{<t>}} \\\\ \frac{\partial \mathcal{L}^{<t>}}{\partial a_b^{<t+1>}}=\frac{\partial \mathcal{L}^{<t>}}{\partial z_b^{<t>}}\frac{\partial z_b^{<t>}}{\partial a_b^{<t+1>}}=W_{ba}^T\frac{\partial \mathcal{L}^{<t>}}{\partial z_b^{<t>}}

The above shows how to compute all the partial derivatives at a single time step. Finally, we sum the parameter gradients over all time steps.

\displaystyle \frac{\partial \mathcal{L}}{\partial W_y}=\sum_{t=1}^T\frac{\partial \mathcal{L}^{<t>}}{\partial W_y},\frac{\partial \mathcal{L}}{\partial b_y}=\sum_{t=1}^T\frac{\partial \mathcal{L}^{<t>}}{\partial b_y} \\\\ \frac{\partial \mathcal{L}}{\partial W_{fx}}=\sum_{t=1}^T\frac{\partial \mathcal{L}^{<t>}}{\partial W_{fx}},\frac{\partial \mathcal{L}}{\partial W_{fa}}=\sum_{t=1}^T\frac{\partial \mathcal{L}^{<t>}}{\partial W_{fa}},\frac{\partial \mathcal{L}}{\partial b_{fa}}=\sum_{t=1}^T\frac{\partial \mathcal{L}^{<t>}}{\partial b_{fa}} \\\\ \frac{\partial \mathcal{L}}{\partial W_{bx}}=\sum_{t=1}^T\frac{\partial \mathcal{L}^{<t>}}{\partial W_{bx}},\frac{\partial \mathcal{L}}{\partial W_{ba}}=\sum_{t=1}^T\frac{\partial \mathcal{L}^{<t>}}{\partial W_{ba}},\frac{\partial \mathcal{L}}{\partial b_{ba}}=\sum_{t=1}^T\frac{\partial \mathcal{L}^{<t>}}{\partial b_{ba}}

The following is the implementation of backward propagation of BRNN.

class BRNN:
    def cell_backward_forward(self, daft, fcache, parameters):
        """
        Implement a single backward step for the BRNN-cell forward direction.

        Parameters
        ----------
        daft: (ndarray (n_a, m)) - gradient of the hidden state at timestep "t" in the forward direction
        fcache: (tuple) - cache from the forward direction
        parameters: (dict) - the parameters
            "Wfa": (ndarray (n_a, n_a)) - weights for the forward hidden state

        Returns
        -------
        gradients: (dict) - the gradients
            "dWfx": (ndarray (n_a, n_x)) - gradients for the forward input weights
            "dWfa": (ndarray (n_a, n_a)) - gradients for the forward hidden state weights
            "dbfa": (ndarray (n_a, 1)) - gradients for the forward bias
            "daft": (ndarray (n_a, m)) - gradient of the hidden state at timestep "t-1" in the forward direction
        """

        aft, aft_prev, xt, zfxt = fcache
        dfz = (1 - aft ** 2) * daft
        gradients = {
            "dbfa": np.sum(dfz, axis=1, keepdims=True),
            "dWfx": dfz @ xt.T,
            "dWfa": dfz @ aft_prev.T,
            "daft": parameters["Wfa"].T @ dfz,
        }
        return gradients

    def cell_backward_backward(self, dabt, bcache, parameters):
        """
        Implement a single backward step for the BRNN-cell backward direction.

        Parameters
        ----------
        dabt: (ndarray (n_a, m)) - gradient of the hidden state at timestep "t" in the backward direction
        bcache: (tuple) - cache from the backward direction
        parameters: (dict) - the parameters
            "Wba": (ndarray (n_a, n_a)) - weights for the backward hidden state

        Returns
        -------
        gradients: (dict) - the gradients
            "dWbx": (ndarray (n_a, n_x)) - gradients for the backward input weights
            "dWba": (ndarray (n_a, n_a)) - gradients for the backward hidden state weights
            "dbba": (ndarray (n_a, 1)) - gradients for the backward bias
            "dabt": (ndarray (n_a, m)) - gradient of the hidden state at timestep "t+1" in the backward direction
        """

        abt, abt_next, xt, zbxt = bcache
        dbz = (1 - abt ** 2) * dabt
        gradients = {
            "dbba": np.sum(dbz, axis=1, keepdims=True),
            "dWbx": dbz @ xt.T,
            "dWba": dbz @ abt_next.T,
            "dabt": parameters["Wba"].T @ dbz,
        }
        return gradients

    def backward(self, X, Y, parameters, Y_hat, caches):
        """
        Implement the backward propagation for the BRNN.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data
        Y: (ndarray (n_y, m, T_x)) - true labels
        parameters: (dict) - the parameters
            "Wfx": (ndarray (n_a, n_x)) - weights for the forward input
            "Wfa": (ndarray (n_a, n_a)) - weights for the forward hidden state
            "Wbx": (ndarray (n_a, n_x)) - weights for the backward input
            "Wba": (ndarray (n_a, n_a)) - weights for the backward hidden state
            "Wy": (ndarray (n_y, n_a * 2)) - weights for the output
            "bfa": (ndarray (n_a, 1)) - bias for the forward hidden state
            "bba": (ndarray (n_a, 1)) - bias for the backward hidden state
            "by": (ndarray (n_y, 1)) - bias for the output
        Y_hat: (ndarray (n_y, m, T_x)) - predictions for each timestep
        caches: (list) - caches for each timestep

        Returns
        -------
        gradients: (dict) - the gradients of the loss with respect to the parameters
        """

        n_x, m, T_x = X.shape
        n_a, n_x = parameters["Wfx"].shape

        Wfx, Wfa, bfa = parameters["Wfx"], parameters["Wfa"], parameters["bfa"]
        Wbx, Wba, bba = parameters["Wbx"], parameters["Wba"], parameters["bba"]
        Wy, by = parameters["Wy"], parameters["by"]

        gradients = {
            "dWfx": np.zeros_like(Wfx), "dWfa": np.zeros_like(Wfa), "dbfa": np.zeros_like(bfa),
            "dWbx": np.zeros_like(Wbx), "dWba": np.zeros_like(Wba), "dbba": np.zeros_like(bba),
            "dWy": np.zeros_like(Wy), "dby": np.zeros_like(by),
        }

        daf = np.zeros((n_a, m, T_x))
        dab = np.zeros((n_a, m, T_x))
        for t in range(T_x):
            _, _, (at, zyt, y_hat_t) = caches[t]
            dy = Y_hat[:, :, t] - Y[:, :, t]
            dWy = dy @ at.T
            dby = np.sum(dy, axis=1, keepdims=True)
            gradients["dWy"] += dWy
            gradients["dby"] += dby
            dat = Wy.T @ dy
            daf[:, :, t] = dat[:n_a, :]
            dab[:, :, t] = dat[n_a:, :]

        daft = np.zeros((n_a, m))
        for t in reversed(range(T_x)):
            fcaches, _, _ = caches[t]
            grads = self.cell_backward_forward(daf[:, :, t] + daft, fcaches, parameters)
            gradients["dWfx"] += grads["dWfx"]
            gradients["dWfa"] += grads["dWfa"]
            gradients["dbfa"] += grads["dbfa"]
            daft = grads["daft"]

        dabt = np.zeros_like(daft)
        for t in range(T_x):
            _, bcache, _ = caches[t]
            grads = self.cell_backward_backward(dab[:, :, t] + dabt, bcache, parameters)
            gradients["dWbx"] += grads["dWbx"]
            gradients["dWba"] += grads["dWba"]
            gradients["dbba"] += grads["dbba"]
            dabt = grads["dabt"]

        return gradients
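
A numerical gradient check is a handy way to verify the backward pass. The sketch below compares the analytic gradient of a single parameter entry against a centered finite difference; the helper name grad_check is hypothetical, and it assumes the snippets are combined into one class together with the compute_loss sketch above.

def grad_check(brnn, X, Y, af0, abl, key="Wfa", eps=1e-5):
    # Analytic gradient from backpropagation.
    _, _, Y_hat, caches = brnn.forward(X, af0, abl, brnn.parameters)
    grads = brnn.backward(X, Y, brnn.parameters, Y_hat, caches)
    analytic = grads["d" + key][0, 0]

    # Numerical gradient of the same entry via centered finite differences.
    original = brnn.parameters[key][0, 0]
    brnn.parameters[key][0, 0] = original + eps
    _, _, Y_hat_plus, _ = brnn.forward(X, af0, abl, brnn.parameters)
    loss_plus = brnn.compute_loss(Y_hat_plus, Y)
    brnn.parameters[key][0, 0] = original - eps
    _, _, Y_hat_minus, _ = brnn.forward(X, af0, abl, brnn.parameters)
    loss_minus = brnn.compute_loss(Y_hat_minus, Y)
    brnn.parameters[key][0, 0] = original
    numeric = (loss_plus - loss_minus) / (2 * eps)
    return analytic, numeric

The two returned values should agree to several decimal places; a large discrepancy usually points to an indexing mistake in the backward pass.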

Putting It All Together

The following code implements one complete training step. First, we run forward propagation on the training data and compute the loss; then we run backward propagation to obtain the gradients. To prevent exploding gradients, we clip the gradients before using them to update the parameters.

class BRNN:
    def optimize(self, X, Y, af_prev, ab_last, parameters, learning_rate, clip_value):
        """
        Implement the forward and backward propagation for the BRNN.

        Parameters
        ----------
        X: (ndarray (n_x, m, T_x)) - input data
        Y: (ndarray (n_y, m, T_x)) - true labels
        af_prev: (ndarray (n_a, m)) - initial hidden state for the forward direction
        ab_last: (ndarray (n_a, m)) - initial hidden state for the backward direction
        parameters: (dict) - the parameters
            "Wfx": (ndarray (n_a, n_x)) - weights for the forward input
            "Wfa": (ndarray (n_a, n_a)) - weights for the forward hidden state
            "Wbx": (ndarray (n_a, n_x)) - weights for the backward input
            "Wba": (ndarray (n_a, n_a)) - weights for the backward hidden state
            "Wy": (ndarray (n_y, n_a * 2)) - weights for the output
            "bfa": (ndarray (n_a, 1)) - bias for the forward hidden state
            "bba": (ndarray (n_a, 1)) - bias for the backward hidden state
            "by": (ndarray (n_y, 1)) - bias for the output
        learning_rate: (float) - learning rate
        clip_value: (float) - the maximum value to clip the gradients

        Returns
        -------
        af: (ndarray (n_a, m)) - hidden state at the last timestep for the forward direction
        ab: (ndarray (n_a, m)) - hidden state at the first timestep for the backward direction
        loss: (float) - the cross-entropy loss
        """

        AF, AB, Y_hat, caches = self.forward(X, af_prev, ab_last, parameters)
        loss = self.compute_loss(Y_hat, Y)
        gradients = self.backward(X, Y, parameters, Y_hat, caches)
        gradients = self.clip(gradients, clip_value)
        self.update_parameters(parameters, gradients, learning_rate)

        af = AF[:, :, -1]
        ab = AB[:, :, 0]
        return af, ab, loss
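
The clip and update_parameters methods used above are not listed in this chapter. A minimal sketch consistent with how optimize calls them might look like this.

class BRNN:
    def clip(self, gradients, clip_value):
        """Clip every gradient element-wise to [-clip_value, clip_value] to avoid exploding gradients."""
        for key in gradients:
            np.clip(gradients[key], -clip_value, clip_value, out=gradients[key])
        return gradients

    def update_parameters(self, parameters, gradients, learning_rate):
        """Plain gradient descent: each parameter p is updated as p -= learning_rate * dp."""
        for key in parameters:
            parameters[key] -= learning_rate * gradients["d" + key]
        return parameters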

Example

In the example below, the input to the BRNN model is a word, and the model predicts whether the word is a palindrome. Since the outcome is simply palindrome or not, we only use the prediction at the last time step and ignore the rest of the output.

Below is the training function of the BRNN model.

class BRNN:
    def preprocess_input(self, input, char_to_idx):
        """
        Preprocess the input text data.

        Parameters
        ----------
        input: (tuple) - tuple containing the input text data and the true label
        char_to_idx: (dict) - dictionary mapping characters to indices

        Returns
        -------
        X: (ndarray (n_x, 1, T_x)) - input data for each time step
        Y: (ndarray (n_y, 1, T_x)) - true labels for each time step
        """

        word, label = input

        n_x = len(char_to_idx)
        T_x = len(word)
        X = np.zeros((n_x, 1, T_x))
        for t, ch in enumerate(word):
            X[char_to_idx[ch], 0, t] = 1

        Y = np.zeros((2, 1, T_x))
        if label == 1:
            Y[0, 0, :] = 1.0
            Y[1, 0, :] = 0.0
        else:
            Y[0, 0, :] = 0.0
            Y[1, 0, :] = 1.0

        return X, Y

    def train(self, dataset, char_to_idx, num_iterations=100, learning_rate=0.01, clip_value=5):
        """
        Train the BRNN model on the given dataset.

        Parameters
        ----------
        dataset: (list) - list of tuples containing the input text data and the true labels
        char_to_idx: (dict) - dictionary mapping characters to indices
        num_iterations: (int) - number of iterations for the optimization loop
        learning_rate: (float) - learning rate for the optimization algorithm
        clip_value: (float) - maximum value for the gradients

        Returns
        -------
        losses: (list) - cross-entropy loss at each iteration
        """

        losses = []
        for i in range(num_iterations):
            np.random.shuffle(dataset)
            total_loss = 0.0
            af_prev = self.af0
            ab_next = self.abl

            for input in dataset:
                X, Y = self.preprocess_input(input, char_to_idx)
                af_prev, ab_next, loss = self.optimize(
                    X, Y, af_prev, ab_next, self.parameters, learning_rate, clip_value
                )
                total_loss += loss

            avg_loss = total_loss / len(dataset)
            losses.append(avg_loss)
            print(f"Iteration {i}, Loss: {avg_loss}")

        return losses

Finally, the following code shows how to train our model and use it to make predictions.

def get_palindrome_dataset():
    palindromes = ["racecar", "madam", "kayak", "rotator", "noon", "civic"]
    non_palindromes = ["hello", "abcde", "python", "world", "carrot", "banana"]
    dataset = []
    for palindrome in palindromes:
        dataset.append((palindrome, 1))
    for non_palindrome in non_palindromes:
        dataset.append((non_palindrome, 0))
    return dataset


if __name__ == "__main__":
    dataset = get_palindrome_dataset()

    chars = sorted(list(set(''.join([word for word, _ in dataset]))))
    vocab_size = len(chars)

    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}

    brnn = BRNN(64, vocab_size, 2)
    losses = brnn.train(dataset, char_to_idx, 300, 0.01, 5.0)

    correct = 0
    for word, label in dataset:
        X, Y = brnn.preprocess_input((word, label), char_to_idx)
        _, _, Y_hat, _ = brnn.forward(X, brnn.af0, brnn.abl, brnn.parameters)
        pred_probs = Y_hat[:, 0, -1]
        predicted_label = 1 if (pred_probs[0] > pred_probs[1]) else 0
        if predicted_label == label:
            correct += 1

    print(f"Accuracy: {correct / len(dataset)}")

Variants

Just as there are multiple variants of standard RNNs (like LSTM and GRU), there are corresponding BRNN versions, such as bi-directional LSTM (BiLSTM) and bi-directional gated recurrent unit (BiGRU). They can capture long-term dependencies in both directions. This approach often yields superior performance in tasks requiring long-range context.
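
As an illustration only (the from-scratch NumPy code in this chapter does not depend on it), deep learning frameworks expose these variants directly. In PyTorch, for example, passing bidirectional=True to nn.LSTM or nn.GRU yields a BiLSTM or BiGRU whose output concatenates the forward and backward hidden states at every time step.

import torch
import torch.nn as nn

seq_len, batch, n_x, n_a = 7, 3, 10, 16
bilstm = nn.LSTM(input_size=n_x, hidden_size=n_a, bidirectional=True)
x = torch.randn(seq_len, batch, n_x)
output, (h_n, c_n) = bilstm(x)
print(output.shape)  # torch.Size([7, 3, 32]) - 2 * n_a: both directions concatenated
print(h_n.shape)     # torch.Size([2, 3, 16]) - final hidden state of each direction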

Conclusion

BRNNs provide an effective way to capture both past and future context in sequential data. Whether you’re working with text, speech, or any type of time series, BRNNs can offer significant performance benefits over standard RNN architectures. By incorporating LSTM or GRU units, you further mitigate the vanishing gradient problem and gain the ability to model long-term dependencies in both directions.
