Attention Models

Photo by Anthony Tran on Unsplash
An attention mechanism is a method in deep learning that lets a model focus on the most relevant parts of its input when producing each piece of its output. Unlike traditional sequence models, which often struggle with longer inputs, attention allows a model to dynamically focus on different parts of the input sequence when generating each part of the output sequence.

The complete code for this chapter can be found in .

Attention Mechanism

Traditional neural sequence models (like simple encoder-decoders) read an entire input sequence and then attempt to compress it all into a single vector, or thought vector. This single vector is then used to produce every output token. The problem is that a single summary might lose important details about the input. If you don’t know much about sequence models, you can refer to the following article first.

Attention fixes this by allowing the model to look at different parts of the input separately each time it generates a piece of output. Instead of a single summary of the input, it uses a set of weights that say how much each part of the input should matter when predicting the next output. Attention models enable a neural network to weigh the importance of different elements in an input sequence, dynamically adjusting these weights based on context.
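To make this concrete, here is a minimal, hypothetical sketch (the scores, dimensions, and tensors are made up purely for illustration): a score per input position is turned into weights with a softmax, and the context vector is simply the weighted sum of the input states.

import torch

# Hypothetical scores for a 4-token input at one output step.
scores = torch.tensor([2.0, 0.5, -1.0, 0.1])
weights = torch.softmax(scores, dim=0)   # attention weights, sum to 1
states = torch.randn(4, 8)               # 4 input positions, each an 8-dim hidden state
context = weights @ states               # weighted sum -> context vector of shape (8,)
print(weights.sum(), context.shape)      # tensor(1.), torch.Size([8])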

Two prominent types of attention mechanisms are Bahdanau Attention and Luong Attention.

Bahdanau Attention

This mechanism was introduced by D. Bahdanau et al. in 2014 and is therefore called Bahdanau attention. It uses an alignment model to calculate alignment scores between the previous decoder hidden state and each encoder hidden state, and then generates a context vector by applying the resulting attention weights.

The figure below shows a Bahdanau attention model. We use a single layer of bi-directional GRU (BiGRU) in the encoder and a single layer of uni-directional GRU in the decoder. In the paper, Bahdanau uses a unit very similar to the GRU, so we will use the GRU directly here. Of course, the encoder and decoder can also use other RNN variants, or multi-layer RNNs. When the encoder uses a bi-directional RNN, we need to concatenate its forward and backward hidden states.

Bahdanau Attention.

After the encoder outputs all hidden states, we use the following formulas to calculate the context vector. For the context vector c_i, we first use the alignment model a to calculate the energy e_{ij} of each encoder hidden state h_j at time step i. The alignment model a scores how well the input around position j and the output at position i match; this e_{ij} is the alignment score. Then, we use a softmax to calculate the weight \alpha_{ij} of each encoder hidden state h_j. This weight \alpha_{ij} determines how much each encoder hidden state h_j contributes to the context vector c_i, so the weighted sum of the hidden states gives the final context vector.

c_i=\displaystyle\sum_{j=1}^{T_x}\alpha_{ij}h_j \\\\ \alpha_{ij}=\frac{\exp(e_{ij})}{\sum_{k=1}^{T_x}\exp(e_{ik})} \\\\ e_{ij}=a(s_{i-1},h_j) \\\\ a(s_{i-1},h_j)=v_a^T\tanh(W_as_{i-1}+U_ah_j)

Finally, the calculated context vector c_i, the previous output y_{i-1}, and the previous decoder hidden state s_{i-1} are used as the input of the decoder RNN at the current time step.

p(y_i|y_1,\cdots,y_{i-1},X)=g(y_{i-1},s_i,c_i) \\\\ s_i=f(s_{i-1},y_{i-1},c_i)

The alignment model a is a feed-forward neural network that is jointly trained with all the other components of the system. It first passes s_{i-1} and h_j through linear transformations and then adds the results together. Therefore, Bahdanau attention is also called additive attention.

Bahdanau Attention Implementation

Now let’s implement the Bahdanau attention model in the figure above. Below is the implementation of the encoder, which uses a bi-directional GRU (BiGRU). We use a fully connected layer at the end to convert the dimensions of the encoder hidden states to match the expected dimensions of the decoder.

import torch
import torch.nn as nn


class BahdanauAttentionEncoder(nn.Module):
    def __init__(self, input_dim, embed_dim, encoder_hidden_dim, decoder_hidden_dim):
        super(BahdanauAttentionEncoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.gru = nn.GRU(embed_dim, encoder_hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)

    def forward(self, src):
        """
        Args
            src: (batch_size, src_len)

        Returns
            all_hidden: (batch_size, src_len, encoder_hidden_dim * 2)
            decoder_hidden_init: (batch_size, decoder_hidden_dim)
        """

        embedded = self.embedding(src)  # (batch_size, src_len, embed_dim)
        # all_hidden: (batch_size, src_len, encoder_hidden_dim * 2)
        # final_hidden: (num_layers * 2, batch_size, encoder_hidden_dim)
        all_hidden, final_hidden = self.gru(embedded)

        # Map to decoder's initial hidden state
        final_forward_hidden = final_hidden[-2, :, :]
        final_backward_hidden = final_hidden[-1, :, :]
        hidden_cat = torch.cat((final_forward_hidden, final_backward_hidden), dim=1)
        decoder_hidden_init = self.fc(hidden_cat)  # (batch_size, decoder_hidden_dim)

        return all_hidden, decoder_hidden_init
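As a quick sanity check, we can pass a dummy batch through the encoder and confirm the shapes documented in the docstring. The vocabulary size, dimensions, batch size, and sequence length below are hypothetical.

import torch

encoder = BahdanauAttentionEncoder(input_dim=100, embed_dim=16, encoder_hidden_dim=32, decoder_hidden_dim=32)
src = torch.randint(0, 100, (4, 7))  # (batch_size=4, src_len=7)
all_hidden, decoder_hidden_init = encoder(src)
print(all_hidden.shape)           # torch.Size([4, 7, 64]), i.e. encoder_hidden_dim * 2
print(decoder_hidden_init.shape)  # torch.Size([4, 32])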

The following implements the calculation of the weights \alpha_{ij}.

class BahdanauAttention(nn.Module):
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim, attention_dim):
        super(BahdanauAttention, self).__init__()
        self.W = nn.Linear(decoder_hidden_dim, attention_dim)
        self.U = nn.Linear(encoder_hidden_dim * 2, attention_dim)
        self.v = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, decoder_hidden, encoder_all_hidden):
        """
        Args
            decoder_hidden: (batch_size, decoder_hidden_dim)
            encoder_all_hidden: (batch_size, src_len, encoder_hidden_dim * 2)

        Returns
            alpha: (batch_size, src_len)
        """

        Ws = self.W(decoder_hidden).unsqueeze(1)  # (batch_size, 1, attention_dim)
        Uh = self.U(encoder_all_hidden)  # (batch_size, src_len, attention_dim)

        energy = self.v(torch.tanh(Ws + Uh))  # (batch_size, src_len, 1)
        energy = energy.squeeze(2)  # (batch_size, src_len)

        alpha = torch.softmax(energy, dim=1)  # (batch_size, src_len)

        return alpha
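A quick check (again with hypothetical dimensions) confirms that the module returns one weight per source position and that the weights of each example sum to 1.

import torch

attention = BahdanauAttention(encoder_hidden_dim=32, decoder_hidden_dim=32, attention_dim=32)
decoder_hidden = torch.randn(4, 32)          # (batch_size, decoder_hidden_dim)
encoder_all_hidden = torch.randn(4, 7, 64)   # (batch_size, src_len, encoder_hidden_dim * 2)
alpha = attention(decoder_hidden, encoder_all_hidden)
print(alpha.shape)        # torch.Size([4, 7])
print(alpha.sum(dim=1))   # each row sums to 1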

The following is the implementation of the decoder. It uses a uni-directional GRU. After calculating the context vector c_i, the decoder concatenates it with the embedded input token and uses the result as the input of the GRU.

class BahdanauAttentionDecoder(nn.Module):
    def __init__(self, output_dim, embed_dim, decoder_hidden_dim, encoder_hidden_dim, attention_dim):
        super(BahdanauAttentionDecoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.attention = BahdanauAttention(encoder_hidden_dim, decoder_hidden_dim, attention_dim)
        self.gru = nn.GRU(
            embed_dim + encoder_hidden_dim * 2, decoder_hidden_dim, num_layers=1, batch_first=True, bidirectional=False,
        )
        self.fc = nn.Linear(decoder_hidden_dim, output_dim)

    def forward(self, input_token, decoder_hidden, encoder_all_hidden):
        """
        Args
            input_token: (batch_size)
            decoder_hidden: (batch_size, decoder_hidden_dim)
            encoder_all_hidden: (batch_size, src_len, encoder_hidden_dim * 2)

        Returns
            prediction: (batch_size, output_dim)
            final_hidden: (batch_size, decoder_hidden_dim)
        """

        input_token = input_token.unsqueeze(1)  # (batch_size, 1)
        embedded = self.embedding(input_token)  # (batch_size, 1, embed_dim)
        alpha = self.attention(decoder_hidden, encoder_all_hidden)  # (batch_size, src_len)
        alpha = alpha.unsqueeze(1)  # (batch_size, 1, src_len)
        context = torch.bmm(alpha, encoder_all_hidden)  # (batch_size, 1, encoder_hidden_dim * 2)

        rnn_input = torch.cat((embedded, context), dim=2)  # (batch_size, 1, embed_dim + encoder_hidden_dim * 2)

        decoder_hidden = decoder_hidden.unsqueeze(0)  # (1, batch_size, decoder_hidden_dim)
        # all_hidden: (batch_size, 1, decoder_hidden_dim)
        # final_hidden: (1, batch_size, decoder_hidden_dim)
        all_hidden, final_hidden = self.gru(rnn_input, decoder_hidden)

        final_hidden = final_hidden.squeeze(0)  # (batch_size, decoder_hidden_dim)

        prediction = self.fc(all_hidden.squeeze(1))  # (batch_size, output_dim)

        return prediction, final_hidden
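A single decoding step can be checked the same way; the dimensions below are hypothetical and chosen to match the encoder and attention checks above.

import torch

decoder = BahdanauAttentionDecoder(output_dim=120, embed_dim=16, decoder_hidden_dim=32, encoder_hidden_dim=32, attention_dim=32)
input_token = torch.randint(0, 120, (4,))    # previous output token for each example
decoder_hidden = torch.randn(4, 32)          # previous decoder hidden state
encoder_all_hidden = torch.randn(4, 7, 64)   # encoder outputs
prediction, decoder_hidden = decoder(input_token, decoder_hidden, encoder_all_hidden)
print(prediction.shape)      # torch.Size([4, 120])
print(decoder_hidden.shape)  # torch.Size([4, 32])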

The following integrates the encoder, attention, and decoder.

class BahdanauAttentionModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(BahdanauAttentionModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """
        Args
            src: (batch_size, src_len)
            tgt: (batch_size, tgt_len)
            teacher_forcing_ratio: float - probability to use teacher forcing

        Returns
            outputs: (batch_size, tgt_len, tgt_vocab_size)
        """

        batch_size, tgt_len = tgt.shape
        tgt_vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size)

        encoder_all_hidden, decoder_hidden = self.encoder(src)
        input_token = tgt[:, 0]
        for t in range(1, tgt_len):
            prediction, decoder_hidden = self.decoder(input_token, decoder_hidden, encoder_all_hidden)
            outputs[:, t] = prediction

            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            input_token = tgt[:, t] if teacher_force else prediction.argmax(1)

        return outputs
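Putting the pieces together on dummy data (hypothetical vocabulary sizes and dimensions), the model returns one distribution over the target vocabulary per target position.

import torch

encoder = BahdanauAttentionEncoder(100, 16, 32, 32)
decoder = BahdanauAttentionDecoder(120, 16, 32, 32, 32)
model = BahdanauAttentionModel(encoder, decoder)
src = torch.randint(0, 100, (4, 7))   # (batch_size, src_len)
tgt = torch.randint(0, 120, (4, 9))   # (batch_size, tgt_len)
outputs = model(src, tgt, teacher_forcing_ratio=0.5)
print(outputs.shape)  # torch.Size([4, 9, 120])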

Luong Attention

Introduced by M. Luong et al. in 2015, this mechanism is called Luong attention. The authors proposed three methods for calculating alignment scores, which mainly simplify the calculation of alignment scores between the encoder and decoder hidden states through dot products. In addition, they also introduced two types of attention: global attention and local attention.

The main idea of the global attention model is that when the decoder calculates the context vector c_i, it takes into account all the encoder hidden states, just like Bahdanau attention. Global attention has a drawback: every output word must attend to all input words, which is computationally expensive and becomes impractical when translating longer sequences. Local attention instead allows each output word to focus on only a small subset of the input words. This article will only cover global attention.

The figure below shows a Luong global attention model. We use a single layer of bi-directional LSTM (BiLSTM) in the encoder and a single layer of uni-directional LSTM in the decoder. Of course, the encoder and decoder can also use other RNN variants, or multi-layer RNNs. When the encoder uses a bi-directional RNN, we need to concatenate its forward and backward hidden states.

Luong Attention.

The method for calculating the context vector c_i and the weights \alpha_{ij} is the same as in Bahdanau attention. However, when calculating the alignment scores, Luong attention uses the current decoder hidden state s_i instead of the previous one.

c_i=\displaystyle\sum_{j=1}^{T_x}\alpha_{ij}h_j \\\\ \alpha_{ij}=\frac{\exp(e_{ij})}{\sum_{k=1}^{T_x}\exp(e_{ik})} \\\\ e_{ij}=score(s_{i},h_j) \\\\ score(s_{i},h_j)=\begin{cases} s_i^Th_j &dot \\ s_i^TW_ah_j &general \\ v_a^T\tanh(W_a\begin{bmatrix}s_i \\ h_j\end{bmatrix}) &concat \end{cases}

Let’s take a look at these three functions:

  • Dot: It computes the dot product of the current decoder hidden state s_i and each encoder hidden state h_j.
  • General: Based on the dot function, it introduces a learnable parameter W_a.
  • Concat: This is similar to Bahdanau’s alignment model, except that it uses the current decoder hidden state s_i and applies a single linear transformation W_a to the concatenation of s_i and h_j, rather than transforming them separately and adding the results (see the small sketch below).
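Here is a tiny sketch of the three score functions on raw tensors; the dimension d and the parameters W_a, W_c, and v_a are made up purely for illustration, and the module-based implementations follow below.

import torch

d = 4
s_i = torch.randn(d)           # current decoder hidden state
h_j = torch.randn(d)           # one encoder hidden state
W_a = torch.randn(d, d)        # hypothetical parameter for the general score
W_c = torch.randn(d, 2 * d)    # hypothetical parameter for the concat score
v_a = torch.randn(d)

dot_score = s_i @ h_j                                         # dot
general_score = s_i @ (W_a @ h_j)                             # general
concat_score = v_a @ torch.tanh(W_c @ torch.cat([s_i, h_j]))  # concat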

After deriving the context vector c_i, it is combined with s_i to predict the output directly. This is different from Bahdanau attention, where the context vector is used as part of the input to the decoder RNN.

p(y_i|y_1,\cdots,y_{i-1},X)=\text{softmax}(W_s\tilde{s_i}) \\\\ \tilde{s_i}=\tanh(W_c\begin{bmatrix}c_i \\ s_i\end{bmatrix})

Among the three alignment score functions, the dot function is the simplest and the most commonly cited. Therefore, Luong attention is also called dot attention or multiplicative attention.

Luong Attention Implementation

Now let’s implement the Luong attention model in the figure above. Below is the implementation of the encoder, which uses a bi-directional LSTM (BiLSTM). We use a fully connected layer at the end to convert the dimensions of the encoder hidden states to match the expected dimensions of the decoder. This encoder is basically the same as the Bahdanau attention encoder, except that it uses a different RNN internally.

import torch
import torch.nn as nn


class LuongAttentionEncoder(nn.Module):
    def __init__(self, input_dim, embed_dim, encoder_hidden_dim, decoder_hidden_dim):
        super(LuongAttentionEncoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.lstm = nn.LSTM(embed_dim, encoder_hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)

    def forward(self, src):
        """
        Args
            src: (batch_size, src_len)

        Returns
            all_hidden: (batch_size, src_len, decoder_hidden_dim)
            decoder_hidden_init: (batch_size, decoder_hidden_dim)
            decoder_cell_init: (batch_size, decoder_hidden_dim)
        """

        embedded = self.embedding(src)  # (batch_size, src_len, embed_dim)
        # all_hidden: (batch_size, src_len, encoder_hidden_dim * 2)
        # final_hidden, final_cell: (num_layers * 2, batch_size, encoder_hidden_dim)
        all_hidden, (final_hidden, final_cell) = self.lstm(embedded)

        # Map to decoder's initial hidden and cell states
        final_forward_hidden = final_hidden[-2, :, :]
        final_backward_hidden = final_hidden[-1, :, :]
        final_hidden = torch.cat((final_forward_hidden, final_backward_hidden), dim=1)
        final_forward_cell = final_cell[-2, :, :]
        final_backward_cell = final_cell[-1, :, :]
        final_cell = torch.cat((final_forward_cell, final_backward_cell), dim=1)
        decoder_hidden_init = self.fc(final_hidden)  # (batch_size, decoder_hidden_dim)
        decoder_cell_init = self.fc(final_cell)  # (batch_size, decoder_hidden_dim)

        b, s, d = all_hidden.shape
        all_hidden_2d = all_hidden.view(b * s, d)
        all_hidden_2d = self.fc(all_hidden_2d)
        all_hidden = all_hidden_2d.view(b, s, self.fc.out_features)  # (batch_size, src_len, decoder_hidden_dim)

        return all_hidden, decoder_hidden_init, decoder_cell_init
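As with the Bahdanau encoder, a quick shape check with hypothetical dimensions confirms that all returned tensors already have the decoder’s hidden dimension.

import torch

encoder = LuongAttentionEncoder(input_dim=100, embed_dim=16, encoder_hidden_dim=32, decoder_hidden_dim=32)
src = torch.randint(0, 100, (4, 7))  # (batch_size=4, src_len=7)
all_hidden, decoder_hidden_init, decoder_cell_init = encoder(src)
print(all_hidden.shape)           # torch.Size([4, 7, 32])
print(decoder_hidden_init.shape)  # torch.Size([4, 32])
print(decoder_cell_init.shape)    # torch.Size([4, 32])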

The following is the implementation of the dot score function.

class LuongDotAttention(nn.Module):
    def __init__(self):
        super(LuongDotAttention, self).__init__()

    def forward(self, decoder_hidden, encoder_all_hidden):
        """
        Args
            decoder_hidden: (batch_size, decoder_hidden_dim)
            encoder_all_hidden: (batch_size, src_len, decoder_hidden_dim)

        Returns
            alpha: (batch_size, src_len)
        """

        decoder_hidden = decoder_hidden.unsqueeze(1)  # (batch_size, 1, decoder_hidden_dim)
        scores = torch.bmm(decoder_hidden, encoder_all_hidden.transpose(1, 2))  # (batch_size, 1, src_len)
        scores = scores.squeeze(1)  # (batch_size, src_len)

        alpha = torch.softmax(scores, dim=1)  # (batch_size, src_len)

        return alpha

The following is the implementation of the general score function.

class LuongGeneralAttention(nn.Module):
    def __init__(self, decoder_hidden_dim):
        super(LuongGeneralAttention, self).__init__()
        self.W_a = nn.Linear(decoder_hidden_dim, decoder_hidden_dim, bias=False)

    def forward(self, decoder_hidden, encoder_all_hidden):
        """
        Args
            decoder_hidden: (batch_size, decoder_hidden_dim)
            encoder_all_hidden: (batch_size, src_len, decoder_hidden_dim)

        Returns
            alpha: (batch_size, src_len)
        """

        encoder_all_hidden_transformed = self.W_a(encoder_all_hidden)  # (batch_size, src_len, decoder_hidden_dim)
        decoder_hidden = decoder_hidden.unsqueeze(1)  # (batch_size, 1, decoder_hidden_dim)
        scores = torch.bmm(decoder_hidden, encoder_all_hidden_transformed.transpose(1, 2))  # (batch_size, 1, src_len)
        scores = scores.squeeze(1)  # (batch_size, src_len)

        alpha = torch.softmax(scores, dim=1)  # (batch_size, src_len)

        return alpha

The following is the implementation of the concat score function.

class LuongConcatAttention(nn.Module):
    def __init__(self, decoder_hidden_dim):
        super(LuongConcatAttention, self).__init__()
        self.W_a = nn.Linear(decoder_hidden_dim * 2, decoder_hidden_dim, bias=False)
        self.v = nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, decoder_hidden, encoder_all_hidden):
        """
        Args
            decoder_hidden: (batch_size, decoder_hidden_dim)
            encoder_all_hidden: (batch_size, src_len, decoder_hidden_dim)

        Returns
            alpha: (batch_size, src_len)
        """

        b, s, d = encoder_all_hidden.shape
        decoder_hidden_expanded = decoder_hidden.unsqueeze(1).expand(-1, s, -1)  # (batch_size, src_len, hidden_dim)
        concat_input = torch.cat((decoder_hidden_expanded, encoder_all_hidden), dim=2)  # (batch_size, src_len, hidden_dim * 2)
        concat_output = torch.tanh(self.W_a(concat_input))  # (batch_size, src_len, hidden_dim)
        scores = self.v(concat_output)  # (batch_size, src_len, 1)
        scores = scores.squeeze(2)  # (batch_size, src_len)

        alpha = torch.softmax(scores, dim=1)  # (batch_size, src_len)

        return alpha
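The three attention modules share the same call signature, so they are interchangeable in the decoder below. A quick check with hypothetical dimensions:

import torch

decoder_hidden = torch.randn(4, 32)          # (batch_size, decoder_hidden_dim)
encoder_all_hidden = torch.randn(4, 7, 32)   # (batch_size, src_len, decoder_hidden_dim)
for attention in (LuongDotAttention(), LuongGeneralAttention(32), LuongConcatAttention(32)):
    alpha = attention(decoder_hidden, encoder_all_hidden)
    print(type(attention).__name__, alpha.shape)  # each prints torch.Size([4, 7])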

The following is the implementation of the decoder. It uses a uni-directional LSTM. The decoder design is quite different from that of Bahdanau attention: it first calculates the current hidden state s_i and then uses it to calculate the context vector c_i.

class LuongAttentionDecoder(nn.Module):
    def __init__(self, attention, output_dim, embed_dim, decoder_hidden_dim):
        super(LuongAttentionDecoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.lstm = nn.LSTM(embed_dim, decoder_hidden_dim, num_layers=1, batch_first=True)
        self.attention = attention
        self.W_c = nn.Linear(decoder_hidden_dim * 2, decoder_hidden_dim)
        self.W_s = nn.Linear(decoder_hidden_dim, output_dim)

    def forward(self, input_token, decoder_hidden, decoder_cell, encoder_all_hidden):
        """
        Args
            input_token: (batch_size)
            decoder_hidden: (batch_size, decoder_hidden_dim)
            decoder_cell: (batch_size, decoder_hidden_dim)
            encoder_all_hidden: (batch_size, src_len, decoder_hidden_dim)

        Returns
            prediction: (batch_size, output_dim)
            final_hidden: (batch_size, hidden_dim)
            final_cell: (batch_size, hidden_dim)
        """

        input_token = input_token.unsqueeze(1)  # (batch_size, 1)
        embedded = self.embedding(input_token)  # (batch_size, 1, embed_dim)

        decoder_hidden = decoder_hidden.unsqueeze(0)  # (1, batch_size, hidden_dim)
        decoder_cell = decoder_cell.unsqueeze(0)  # (1, batch_size, hidden_dim)

        # decoder_all_hidden: (batch_size, 1, hidden_dim)
        # final_hidden, final_cell: (1, batch_size, hidden_dim)
        decoder_all_hidden, (final_hidden, final_cell) = self.lstm(embedded, (decoder_hidden, decoder_cell))

        final_hidden = final_hidden.squeeze(0)  # (batch_size, hidden_dim)
        final_cell = final_cell.squeeze(0)  # (batch_size, hidden_dim)

        alpha = self.attention(final_hidden, encoder_all_hidden)  # (batch_size, src_len)
        alpha = alpha.unsqueeze(1)  # (batch_size, 1, src_len)
        context = torch.bmm(alpha, encoder_all_hidden)  # (batch_size, 1, hidden_dim)
        context = context.squeeze(1)  # (batch_size, hidden_dim)

        hidden_cat = torch.cat((context, final_hidden), dim=1)  # (batch_size, hidden_dim * 2)
        luong_hidden = torch.tanh(self.W_c(hidden_cat))  # (batch_size, hidden_dim)

        prediction = self.W_s(luong_hidden)  # (batch_size, output_dim)

        return prediction, final_hidden, final_cell
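A single decoding step with hypothetical dimensions, using the dot score function:

import torch

decoder = LuongAttentionDecoder(LuongDotAttention(), output_dim=120, embed_dim=16, decoder_hidden_dim=32)
input_token = torch.randint(0, 120, (4,))    # previous output token for each example
decoder_hidden = torch.randn(4, 32)
decoder_cell = torch.randn(4, 32)
encoder_all_hidden = torch.randn(4, 7, 32)   # already projected to decoder_hidden_dim
prediction, decoder_hidden, decoder_cell = decoder(input_token, decoder_hidden, decoder_cell, encoder_all_hidden)
print(prediction.shape)  # torch.Size([4, 120])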

The following integrates the encoder, attention, and decoder.

class LuongAttentionModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(LuongAttentionModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """
        Args
            src: (batch_size, src_len)
            tgt: (batch_size, tgt_len)
            teacher_forcing_ratio: float - probability to use teacher forcing

        Returns
            outputs: (batch_size, tgt_len, tgt_vocab_size)
        """

        batch_size, tgt_len = tgt.shape
        tgt_vocab_size = self.decoder.W_s.out_features
        outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size)

        encoder_all_hidden, hidden, cell = self.encoder(src)
        input_token = tgt[:, 0]
        for t in range(1, tgt_len):
            output, hidden, cell = self.decoder(input_token, hidden, cell, encoder_all_hidden)
            outputs[:, t] = output

            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            input_token = tgt[:, t] if teacher_force else output.argmax(1)

        return outputs
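And an end-to-end check on dummy data with hypothetical vocabulary sizes and dimensions:

import torch

encoder = LuongAttentionEncoder(100, 16, 32, 32)
decoder = LuongAttentionDecoder(LuongDotAttention(), 120, 16, 32)
model = LuongAttentionModel(encoder, decoder)
src = torch.randint(0, 100, (4, 7))   # (batch_size, src_len)
tgt = torch.randint(0, 120, (4, 9))   # (batch_size, tgt_len)
outputs = model(src, tgt, teacher_forcing_ratio=0.5)
print(outputs.shape)  # torch.Size([4, 9, 120])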

Example

The following code shows how to use the implementation of Bahdanau attention and Luong attention in this article. When using the Luong attention model, we can also choose the score function to use.

import torch
import torch.nn as nn

from bahdanau_attention import *
from luong_attention import *

SOS_TOKEN = 0
EOS_TOKEN = 1
PAD_TOKEN = 2

english_sentences = [
    "hello world",
    "good morning",
    "i love you",
    "cat",
    "dog",
    "go home",
]

spanish_sentences = [
    "hola mundo",
    "buenos dias",
    "te amo",
    "gato",
    "perro",
    "ve a casa",
]


def build_vocab(sentences):
    vocab = list(set([word for sentence in sentences for word in sentence.split(" ")]))
    vocab = ["<sos>", "<eos>", "<pad>"] + vocab
    tkn2idx = {tkn: idx for idx, tkn in enumerate(vocab)}
    idx2tkn = {idx: tkn for tkn, idx in tkn2idx.items()}
    return vocab, tkn2idx, idx2tkn


def convert_sentence_to_idx(sentence, tkn2idx):
    sentence_idx = [tkn2idx[tkn] for tkn in sentence.split(" ")]
    sentence_idx.insert(0, SOS_TOKEN)
    sentence_idx.append(EOS_TOKEN)
    return sentence_idx


def pad_sentence(sentence, max_len):
    return sentence + [PAD_TOKEN] * (max_len - len(sentence))


src_vocab, src_tkn2idx, src_idx2tkn = build_vocab(english_sentences)
tgt_vocab, tgt_tkn2idx, tgt_idx2tkn = build_vocab(spanish_sentences)
src_sentences = [convert_sentence_to_idx(sentence, src_tkn2idx) for sentence in english_sentences]
tgt_sentences = [convert_sentence_to_idx(sentence, tgt_tkn2idx) for sentence in spanish_sentences]
src_max_len = max([len(sentence) for sentence in src_sentences])
tgt_max_len = max([len(sentence) for sentence in tgt_sentences])
src_sentences = [pad_sentence(sentence, src_max_len) for sentence in src_sentences]
tgt_sentences = [pad_sentence(sentence, tgt_max_len) for sentence in tgt_sentences]

ENCODER_EMBEDDING_DIM = 16
ENCODER_HIDDEN_DIM = 32

DECODER_EMBEDDING_DIM = 16
DECODER_HIDDEN_DIM = 32

ATTENTION_HIDDEN_DIM = 32

LEARNING_RATE = 0.01
EPOCHS = 50

using = "luong"  # "bahdanau" or "luong"
loung_attention = "concat"  # "general", "concat", or "dot"

if using == "bahdanau":
    encoder = BahdanauAttentionEncoder(len(src_tkn2idx), ENCODER_EMBEDDING_DIM, ENCODER_HIDDEN_DIM, DECODER_HIDDEN_DIM)
    decoder = BahdanauAttentionDecoder(
        len(tgt_tkn2idx), DECODER_EMBEDDING_DIM, DECODER_HIDDEN_DIM, ENCODER_HIDDEN_DIM, ATTENTION_HIDDEN_DIM
    )
    model = BahdanauAttentionModel(encoder, decoder)
else:
    if luong_attention == "general":
        attention = LuongGeneralAttention(DECODER_HIDDEN_DIM)
    elif luong_attention == "concat":
        attention = LuongConcatAttention(DECODER_HIDDEN_DIM)
    else:
        attention = LuongDotAttention()
    encoder = LuongAttentionEncoder(len(src_vocab), ENCODER_EMBEDDING_DIM, ENCODER_HIDDEN_DIM, DECODER_HIDDEN_DIM)
    decoder = LuongAttentionDecoder(attention, len(tgt_vocab), DECODER_EMBEDDING_DIM, DECODER_HIDDEN_DIM)
    model = LuongAttentionModel(encoder, decoder)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)


def train():
    training_pairs = []
    for src, tgt in zip(src_sentences, tgt_sentences):
        src_tensor = torch.tensor(src, dtype=torch.long)
        tgt_tensor = torch.tensor(tgt, dtype=torch.long)
        training_pairs.append((src_tensor, tgt_tensor))

    for epoch in range(EPOCHS):
        total_loss = 0
        for src_tensor, tgt_tensor in training_pairs:
            src_tensor = src_tensor.unsqueeze(0)
            tgt_tensor = tgt_tensor.unsqueeze(0)

            optimizer.zero_grad()
            outputs = model(src_tensor, tgt_tensor)

            outputs_dim = outputs.shape[-1]
            outputs = outputs[:, 1:, :].reshape(-1, outputs_dim)
            tgt = tgt_tensor[:, 1:].reshape(-1)
            loss = criterion(outputs, tgt)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch: {epoch + 1}, Loss: {total_loss}")


def translate_beam_search(sentence, beam_width=3, max_length=10):
    model.eval()

    src_idx = convert_sentence_to_idx(sentence, src_tkn2idx)
    src_idx = pad_sentence(src_idx, src_max_len)
    src_tensor = torch.tensor(src_idx, dtype=torch.long)

    with torch.no_grad():
        src_tensor = src_tensor.unsqueeze(0)
        if using == "bahdanau":
            encoder_all_hidden, decoder_hidden = model.encoder(src_tensor)
        else:
            encoder_all_hidden, decoder_hidden, decoder_cell = model.encoder(src_tensor)

    if using == "bahdanau":
        beam = [([SOS_TOKEN], decoder_hidden, None, 0.0)]
    else:
        beam = [([SOS_TOKEN], decoder_hidden, decoder_cell, 0.0)]

    completed_sentences = []

    for _ in range(max_length):
        new_beam = []
        for tokens, decoder_hidden, decoder_cell, score in beam:
            if tokens[-1] == EOS_TOKEN:
                completed_sentences.append((tokens, score))
                new_beam.append((tokens, decoder_hidden, decoder_cell, score))
                continue

            input_index = torch.tensor([tokens[-1]], dtype=torch.long)
            with torch.no_grad():
                if using == "bahdanau":
                    prediction, decoder_hidden = model.decoder(input_index, decoder_hidden, encoder_all_hidden)
                else:
                    prediction, decoder_hidden, decoder_cell = model.decoder(
                        input_index, decoder_hidden, decoder_cell, encoder_all_hidden
                    )

                log_probs = torch.log_softmax(prediction, dim=1).squeeze(0)

            topk = torch.topk(log_probs, beam_width)
            for tkn_idx, tkn_score in zip(topk.indices.tolist(), topk.values.tolist()):
                new_tokens = tokens + [tkn_idx]
                new_score = score + tkn_score
                new_beam.append((new_tokens, decoder_hidden, decoder_cell, new_score))

        new_beam.sort(key=lambda x: x[3], reverse=True)
        beam = new_beam[:beam_width]

    for tokens, decoder_hidden, decoder_cell, score in beam:
        if tokens[-1] != EOS_TOKEN:
            completed_sentences.append((tokens, score))

    completed_sentences.sort(key=lambda x: x[1], reverse=True)
    best_tokens, best_score = completed_sentences[0]

    if best_tokens[0] == SOS_TOKEN:
        best_tokens = best_tokens[1:]
    if EOS_TOKEN in best_tokens:
        best_tokens = best_tokens[:best_tokens.index(EOS_TOKEN)]

    return " ".join([tgt_idx2tkn[idx] for idx in best_tokens])


def test():
    test_sentences = [
        "hello world",
        "i love you",
        "cat",
        "go home",
    ]
    for sentence in test_sentences:
        translation = translate_beam_search(sentence)
        print(f"src: {sentence}, tgt: {translation}")


if __name__ == '__main__':
    train()
    test()
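For quick experiments, a greedy decoder (a simplified, hypothetical alternative to translate_beam_search above, which always picks the most likely token at each step) can be written as follows. It reuses the same model, vocabulary, and helper functions defined above.

def translate_greedy(sentence, max_length=10):
    model.eval()
    src_idx = pad_sentence(convert_sentence_to_idx(sentence, src_tkn2idx), src_max_len)
    src_tensor = torch.tensor(src_idx, dtype=torch.long).unsqueeze(0)

    with torch.no_grad():
        if using == "bahdanau":
            encoder_all_hidden, decoder_hidden = model.encoder(src_tensor)
            decoder_cell = None
        else:
            encoder_all_hidden, decoder_hidden, decoder_cell = model.encoder(src_tensor)

        tokens = [SOS_TOKEN]
        for _ in range(max_length):
            input_index = torch.tensor([tokens[-1]], dtype=torch.long)
            if using == "bahdanau":
                prediction, decoder_hidden = model.decoder(input_index, decoder_hidden, encoder_all_hidden)
            else:
                prediction, decoder_hidden, decoder_cell = model.decoder(
                    input_index, decoder_hidden, decoder_cell, encoder_all_hidden
                )
            next_token = prediction.argmax(dim=1).item()  # greedily pick the most likely token
            if next_token == EOS_TOKEN:
                break
            tokens.append(next_token)

    return " ".join(tgt_idx2tkn[idx] for idx in tokens[1:])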

Conclusion

The alignment model used by Bahdanau attention is additive, so it is more computationally intensive but generally more expressive, making it suitable when complex and subtle alignment relationships need to be captured. The multiplicative approach used by Luong attention is computationally simpler, faster, and more scalable, making it a better fit when fast training and inference are required or when computing resources are limited.
