Bidirectional Encoder Representations from Transformers (BERT)

Photo by Maarten van den Heuvel on Unsplash
Bidirectional Encoder Representations from Transformers (BERT) is a pre-training technique for natural language processing proposed by Google AI in 2018. BERT significantly advanced natural language processing by providing a deeper, contextual understanding of language.



The complete code is available for download at .

BERT Architecture

BERT stands for Bidirectional Encoder Representations from Transformers. As the name suggests, the core of BERT is the Transformer encoder. If you are not yet familiar with Transformers, please refer to the following article first.

The figure below shows the architectures of BERT and the Transformer. As you can see, BERT is simply the Transformer's encoder plus an output layer. BERT pre-trains deep bidirectional representations from unlabeled corpora by jointly conditioning on both left and right context in all layers, which is exactly what the Transformer's encoder does. These representations capture different semantic aspects of the input sequence. In the Transformer, they are fed into the decoder's cross multi-head attention to predict the next output.

However, a pre-trained BERT model consists only of the Transformer's encoder and outputs only these representations. We can use this output for downstream tasks such as question answering and language inference. We only need to add one extra output layer on top of the pre-trained BERT model and fine-tune it to obtain models for these downstream tasks, without substantial task-specific architecture modifications. BERT therefore involves two stages: pre-training and fine-tuning.

Transformers vs. BERT (source from the Transformers paper).

Input/Output Representations

To let BERT handle a variety of downstream tasks, its input sequence can be either a single sentence or a pair of sentences (e.g., <question, answer>). BERT uses WordPiece embeddings, proposed by Google in 2016, with a 30,000-token vocabulary. The first token of every input sequence is always a special classification token ([CLS]), and the final hidden state corresponding to [CLS] is used as the aggregate sequence representation for classification tasks.

When the input is a pair of sentences, we pack the two sentences into one sequence and distinguish them in two ways:

  • First, we separate them with a special [SEP] token;
  • Second, we add a learned embedding to every token indicating whether it belongs to sentence A or sentence B.

For a given token, its input representation is constructed by summing the corresponding token embedding, segment embedding, and position embedding, as shown in the figure below.

BERT input representation (source from the BERT paper).
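
As a concrete illustration, using the toy sentences from the implementation section below, the pair "i like dogs" / "they are playful" is packed as:

Tokens:     [CLS]   i   like   dogs   [SEP]   they   are   playful   [SEP]
Segment:      0     0     0      0      0       1      1       1       1
Position:     0     1     2      3      4       5      6       7       8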

BERT Implementation

Below is an implementation of the BERT model. If you are not yet familiar with Transformers or find the implementation hard to follow, please refer to the following article first.

This implementation is essentially the Transformer encoder, except for the following two differences:

  • In Embeddings, there is an extra token_type_embeddings, used to distinguish sentence A from sentence B.
  • At the output, there is an extra pooler layer. It takes the first token of the output representations, which corresponds to the [CLS] token of the input sequence. As mentioned earlier, this is used as the aggregate sequence representation in classification tasks.

In the end, the BERT model outputs both the sequence of representations and the aggregate sequence representation.

import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


class Embeddings(nn.Module):
    def __init__(self, vocab_size, token_type_size, max_position_embeddings, hidden_dim, dropout_prob):
        super(Embeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_dim)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_dim)
        self.token_type_embeddings = nn.Embedding(token_type_size, hidden_dim)

        self.norm = nn.LayerNorm(hidden_dim, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        """
        Compute the embeddings for the input tokens.

        Args
            input_ids: (batch_size, seq_len)
            token_type_ids: (batch_size, seq_len)

        Returns
            embeddings: (batch_size, seq_len, hidden_dim)
        """

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (1, seq_len) -> (batch_size, seq_len)

        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, hidden_dim, dropout_prob):
        super(MultiHeadAttention, self).__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads
        self.all_head_size = hidden_dim

        self.query = nn.Linear(hidden_dim, self.all_head_size, bias=False)
        self.key = nn.Linear(hidden_dim, self.all_head_size, bias=False)
        self.value = nn.Linear(hidden_dim, self.all_head_size, bias=False)

        self.output = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.norm = nn.LayerNorm(hidden_dim, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, mask=None):
        """
        Multi-head attention forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)
            mask: (batch_size, 1, 1, seq_len)
                  0 for real tokens, a large negative value (e.g. -10000) for padding tokens

        Returns
            hidden_states: (batch_size, seq_len, hidden_dim)
        """

        query = self.transpose_for_scores(self.query(hidden_states))  # (batch_size, num_heads, seq_len, head_size)
        key = self.transpose_for_scores(self.key(hidden_states))  # (batch_size, num_heads, seq_len, head_size)
        value = self.transpose_for_scores(self.value(hidden_states))  # (batch_size, num_heads, seq_len, head_size)

        # Scaled dot-product attention
        scores = query @ key.transpose(-2, -1) / math.sqrt(self.head_size)  # (batch_size, num_heads, seq_len, seq_len)
        if mask is not None:
            scores = scores + mask
        attention_weights = F.softmax(scores, dim=-1)  # (batch_size, num_heads, seq_len, seq_len)
        attention_weights = self.dropout(attention_weights)
        attention = attention_weights @ value  # (batch_size, num_heads, seq_len, head_size)

        # Concatenate heads
        attention = attention.transpose(1, 2).contiguous()  # (batch_size, seq_len, num_heads, head_size)
        new_shape = attention.size()[:-2] + (self.all_head_size,)
        attention = attention.view(*new_shape)  # (batch_size, seq_len, all_head_size)

        # Linear projection
        projection_output = self.output(attention)  # (batch_size, seq_len, hidden_dim)
        projection_output = self.dropout(projection_output)

        hidden_states = self.norm(hidden_states + projection_output)
        return hidden_states

    def transpose_for_scores(self, x):
        """
        Args
            x: (batch_size, seq_len, all_head_size)
        Returns
            (batch_size, num_heads, seq_len, head_size)
        """

        new_x_shape = x.size()[:-1] + (self.num_heads, self.head_size)
        x = x.view(*new_x_shape)  # (batch_size, seq_len, num_heads, head_size)
        return x.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, head_size)


class PositionwiseFeedForward(nn.Module):
    def __init__(self, hidden_dim, d_ff):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(hidden_dim, d_ff, bias=True)
        self.linear2 = nn.Linear(d_ff, hidden_dim, bias=True)
        self.activation = nn.GELU()

    def forward(self, hidden_states):
        """
        Feed-forward network forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)

        Returns
            hidden_states: (batch_size, seq_len, hidden_dim)
        """

        hidden_states = self.linear2(self.activation(self.linear1(hidden_states)))
        return hidden_states


class EncoderLayer(nn.Module):
    def __init__(self, num_heads, hidden_dim, d_ff, dropout_prob):
        super(EncoderLayer, self).__init__()
        self.multi_head_attention = MultiHeadAttention(num_heads, hidden_dim, dropout_prob)
        self.ffn = PositionwiseFeedForward(hidden_dim, d_ff)
        self.norm = nn.LayerNorm(hidden_dim, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, mask=None):
        """
        Encoder layer forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)
            mask: (batch_size, 1, 1, seq_len)
                  0 for real tokens, a large negative value (e.g. -10000) for padding tokens

        Returns
            hidden_states: (batch_size, seq_len, hidden_dim)
        """

        # Multi-head attention
        attention_output = self.multi_head_attention(hidden_states, mask=mask)

        # Feed-forward network
        ffn_output = self.ffn(attention_output)
        ffn_output = self.dropout(ffn_output)
        hidden_states = self.norm(hidden_states + ffn_output)

        return hidden_states


class Encoder(nn.Module):
    def __init__(self, hidden_dim, num_layers, num_heads, d_ff, dropout_prob):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(
            [EncoderLayer(num_heads, hidden_dim, d_ff, dropout_prob) for _ in range(num_layers)]
        )

    def forward(self, hidden_states, mask=None):
        """
        Encoder forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)
            mask: (batch_size, 1, 1, seq_len)
                  0 for real tokens, a large negative value (e.g. -10000) for padding tokens

        Returns
            hidden_states: (batch_size, seq_len, hidden_dim)
        """

        for layer in self.layers:
            hidden_states = layer(hidden_states, mask=mask)
        return hidden_states


class Pooler(nn.Module):
    def __init__(self, hidden_dim):
        super(Pooler, self).__init__()
        self.linear = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden_states):
        """
        Pooler forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)

        Returns
            pooled_output: (batch_size, hidden_dim)
        """

        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.linear(first_token_tensor)
        pooled_output = torch.tanh(pooled_output)
        return pooled_output


class Bert(nn.Module):
    def __init__(
        self, vocab_size, token_type_size, max_position_embeddings, hidden_dim, num_layers, num_heads, d_ff,
        dropout_prob
    ):
        super(Bert, self).__init__()
        self.embeddings = Embeddings(vocab_size, token_type_size, max_position_embeddings, hidden_dim, dropout_prob)
        self.encoder = Encoder(hidden_dim, num_layers, num_heads, d_ff, dropout_prob)
        self.pooler = Pooler(hidden_dim)

    def forward(self, input_ids, token_type_ids=None, mask=None):
        """
        Forward pass for the BERT model.

        Args
            input_ids: (batch_size, seq_len)
            token_type_ids: (batch_size, seq_len)
            mask: (batch_size, seq_len)

        Returns
            encoder_output: (batch_size, seq_len, hidden_dim)
            pooled_output: (batch_size, hidden_dim)
        """

        if mask is not None:
            extended_mask = mask.unsqueeze(1).unsqueeze(2)
            extended_mask = extended_mask.to(dtype=torch.float32)
            # Convert 1 -> 0, 0 -> large negative (mask out)
            extended_mask = (1.0 - extended_mask) * -10000.0
        else:
            extended_mask = None

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoder_output = self.encoder(embedding_output, mask=extended_mask)
        pooled_output = self.pooler(encoder_output)

        return encoder_output, pooled_output
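
As a quick sanity check, here is a minimal usage sketch (with an arbitrary small configuration) showing the expected output shapes:

bert = Bert(
    vocab_size=100, token_type_size=2, max_position_embeddings=64,
    hidden_dim=32, num_layers=2, num_heads=2, d_ff=64, dropout_prob=0.1,
)
input_ids = torch.randint(0, 100, (4, 10))  # (batch_size=4, seq_len=10)
mask = torch.ones(4, 10, dtype=torch.long)  # 1 = real token, 0 = padding
encoder_output, pooled_output = bert(input_ids, mask=mask)
print(encoder_output.shape)  # torch.Size([4, 10, 32])
print(pooled_output.shape)   # torch.Size([4, 32])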

Pre-training BERT

Pre-training BERT consists of two tasks: masked language modeling (MLM) and next sentence prediction (NSP). We describe each task in detail below.

Masked Language Modeling (MLM)

It is reasonable to assume that a deep bidirectional model is strictly more powerful than either a left-to-right model or a shallow concatenation of a left-to-right and a right-to-left model. However, standard conditional language models can only be trained left-to-right or right-to-left, because bidirectional conditioning would let each word indirectly "see" the word it is supposed to predict, allowing the model to trivially predict the target word from the multi-layered context.

To train deep bidirectional representations, we simply mask some of the tokens in the input sequence at random and have the model predict those masked tokens. This is called the masked language modeling (masked LM, MLM) task. The final hidden states corresponding to the masked tokens are fed into an output softmax over the vocabulary, as in a standard language model.

Although this approach lets us obtain a bidirectional pre-trained model, its drawback is a mismatch between pre-training and fine-tuning, because the special [MASK] token never appears during fine-tuning. To reduce this effect, we do not always replace the chosen tokens with the [MASK] token.

When generating the training data, we randomly select 15% of the token positions for prediction. If the i-th token is selected, we:

  • keep the i-th token unchanged 10% of the time;
  • replace the i-th token with the [MASK] token 80% of the time;
  • replace the i-th token with a random token 10% of the time.

The figure below shows how two sentences are combined and how a training example for MLM is generated using the procedure above.

An example of MLM.
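
As a textual illustration (using the toy sentences from the implementation section below), one possible MLM training example is:

Input:       [CLS]   i   like   [MASK]   [SEP]   they   are   playful   [SEP]
MLM label:     -     -     -      dogs      -      -      -       -        -

Only the masked position carries a label (the original token "dogs"); all other positions are ignored when computing the MLM loss.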

Next Sentence Prediction (NSP)

Many downstream tasks, such as question answering and language inference, depend on understanding the relationship between two sentences, which is not directly captured by standard language model pre-training. To train a model that understands sentence relationships, we also pre-train on a next sentence prediction (NSP) task.

When choosing sentences A and B for each pre-training example, 50% of the time B is the sentence that actually follows A in the corpus (labeled IsNext), and 50% of the time B is a random sentence from the corpus (labeled NotNext). The final hidden state corresponding to [CLS] (the aggregate sequence representation) is used for NSP.

In the figure below, for each first sentence we pick a second sentence and combine them into training examples. The upper half picks the sentence that actually follows in the corpus, so it is labeled IsNext; the lower half picks a random sentence from the corpus, labeled NotNext.

An example of NSP.
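
As a textual illustration, with the toy corpus from the implementation section below:

Sentence A: "i like dogs"   Sentence B: "they are playful"  ->  IsNext
Sentence A: "i like dogs"   Sentence B: "i like cats"       ->  NotNext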

Implementation

BERT uses WordPiece, but to keep the example code simple, we tokenize on whitespace and use a very small vocabulary, as shown below.

tkn2idx = {
    "[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[MASK]": 3,
    "i": 4, "like": 5, "dogs": 6, "cats": 7,
    "they": 8, "are": 9, "playful": 10,
    "[UNK]": 11,
}

idx2tkn = {v: k for k, v in tkn2idx.items()}


def tokenize(text):
    tokens = text.split()
    token_ids = [tkn2idx.get(t, tkn2idx["[UNK]"]) for t in tokens]
    return token_ids
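
For example, with the vocabulary above ("birds" is a hypothetical out-of-vocabulary word):

print(tokenize("i like dogs"))   # [4, 5, 6]
print(tokenize("i like birds"))  # [4, 5, 11] -- "birds" maps to [UNK]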

We then use the following corpus.

corpus = [
    "i like dogs",
    "they are playful",
    "i like cats",
    "they are cute"
]

Next, we build the pre-training dataset with the following code. When choosing a sentence pair, the actual next sentence is picked with 50% probability and a random sentence with 50% probability. In token_type_ids, 0 means the token at that position in input_ids belongs to sentence A, and 1 means it belongs to sentence B. In mlm_labels, -100 means the token at that position in input_ids was not selected for masking; if the token at that position was selected (e.g., replaced by the [MASK] token), mlm_labels stores the original token at that position.

def create_example_for_mlm_nsp(sentence_a, sentence_b, is_next, max_seq_len=12, mask_prob=0.15):
    cls_id = tkn2idx["[CLS]"]
    sep_id = tkn2idx["[SEP]"]
    mask_id = tkn2idx["[MASK]"]

    tokens_a = tokenize(sentence_a)
    tokens_b = tokenize(sentence_b)

    input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

    if len(input_ids) > max_seq_len:
        input_ids = input_ids[:max_seq_len]
        token_type_ids = token_type_ids[:max_seq_len]

    # -100 for non-masked positions, and the original token for masked positions
    mlm_labels = [-100] * len(input_ids)

    num_to_mask = max(1, int((len(input_ids) - 3) * mask_prob))  # 3 for [CLS], [SEP], [SEP]
    candidate_mask_positions = [i for i, tid in enumerate(input_ids) if tid not in [cls_id, sep_id]]
    random.shuffle(candidate_mask_positions)
    mask_positions = candidate_mask_positions[:num_to_mask]

    for pos in mask_positions:
        mlm_labels[pos] = input_ids[pos]

        # BERT strategy: 80% replace with [MASK], 10% random, 10% keep
        r = random.random()
        if r < 0.8:
            input_ids[pos] = mask_id
        elif r < 0.9:
            input_ids[pos] = random.randint(4, len(tkn2idx) - 2)  # exclude special tokens
        else:
            pass

    nsp_label = 1 if is_next else 0
    return input_ids, token_type_ids, mlm_labels, nsp_label


def build_pretraining_dataset(corpus, num_examples):
    dataset = []
    n = len(corpus)
    for _ in range(num_examples):
        idx_a = random.randint(0, n - 1)
        sentence_a = corpus[idx_a]

        # 50%: pick a real next sentence; 50%: pick a random sentence
        if random.random() < 0.5:
            idx_b = (idx_a + 1) % n
            sentence_b = corpus[idx_b]
            is_next = True
        else:
            idx_b = random.randint(0, n - 1)
            while idx_b == idx_a:
                idx_b = random.randint(0, n - 1)
            sentence_b = corpus[idx_b]
            is_next = False

        input_ids, token_type_ids, mlm_labels, nsp_label = create_example_for_mlm_nsp(sentence_a, sentence_b, is_next)
        dataset.append((input_ids, token_type_ids, mlm_labels, nsp_label))

    return dataset
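
As a rough illustration of the resulting format (the masked position is chosen at random, and 20% of the time the selected token is kept or replaced by a random token instead of [MASK]), a call such as create_example_for_mlm_nsp("i like dogs", "they are playful", is_next=True) might return:

input_ids      = [1, 4, 5, 3, 2, 8, 9, 10, 2]   # [CLS] i like [MASK] [SEP] they are playful [SEP]
token_type_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1]
mlm_labels     = [-100, -100, -100, 6, -100, -100, -100, -100, -100]  # 6 = "dogs"
nsp_label      = 1                              # IsNext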

In the Bert code above, Bert.forward() outputs the output representations and the final hidden state corresponding to [CLS]. For the MLM task, we want the model to predict the tokens at the masked positions. For NSP, we want the model to predict whether the second sentence actually follows the first.

Therefore, in the code below, we add an output layer that maps the output representations to predictions for the masked tokens and uses the final hidden state corresponding to [CLS] to predict whether the second sentence is the next sentence.

In addition, since the output representations are projected back onto the vocabulary to predict the masked tokens, we tie the model's bert.embeddings.word_embeddings.weight to predictions.weight.
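
The PreTrainingHeads module used by BertForPreTraining below does not appear earlier in this listing. A minimal sketch of such a head, assuming an MLM projection tied to the word embedding matrix plus a binary NSP classifier as described above, could look like the following (the head in the downloadable code may differ, e.g., it may apply a dense transform before the vocabulary projection):

class PreTrainingHeads(nn.Module):
    def __init__(self, vocab_size, hidden_dim, word_embedding_weight):
        super(PreTrainingHeads, self).__init__()
        # MLM head: project each token representation back onto the vocabulary,
        # tying its weight to the word embedding matrix
        self.predictions = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.predictions.weight = word_embedding_weight
        # NSP head: binary classification (IsNext / NotNext) from the pooled [CLS] output
        self.seq_relationship = nn.Linear(hidden_dim, 2)

    def forward(self, sequence_output, pooled_output):
        """
        Args
            sequence_output: (batch_size, seq_len, hidden_dim)
            pooled_output: (batch_size, hidden_dim)

        Returns
            prediction_scores: (batch_size, seq_len, vocab_size)
            seq_relationship_scores: (batch_size, 2)
        """

        prediction_scores = self.predictions(sequence_output)
        seq_relationship_scores = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_scores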

class BertForPreTraining(nn.Module):
    def __init__(
        self, vocab_size, token_type_size, max_position_embeddings, hidden_dim, num_layers, num_heads, d_ff,
        dropout_prob
    ):
        super(BertForPreTraining, self).__init__()
        self.bert = Bert(
            vocab_size, token_type_size, max_position_embeddings, hidden_dim, num_layers, num_heads, d_ff, dropout_prob
        )
        # Tying the MLM head's weight to the word embedding
        self.cls = PreTrainingHeads(vocab_size, hidden_dim, self.bert.embeddings.word_embeddings.weight)

    def forward(self, input_ids, token_type_ids=None, mask=None):
        """
        Pre-training BERT

        Args
            input_ids: (batch_size, seq_len)
            token_type_ids: (batch_size, seq_len)
            mask: (batch_size, seq_len)

        Returns
            prediction_scores: (batch_size, seq_len, vocab_size)
            seq_relationship_scores: (batch_size, 2)
        """

        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, mask=mask)
        prediction_scores, seq_relationship_scores = self.cls(sequence_output, pooled_output)
        return prediction_scores, seq_relationship_scores


We run pre-training with the following code.

def collate_pretraining_batch(examples):
    pad_id = tkn2idx["[PAD]"]
    max_len = max(len(ex[0]) for ex in examples)

    batch_input_ids = []
    batch_token_type_ids = []
    batch_mlm_labels = []
    batch_nsp_labels = []
    batch_mask = []

    for (input_ids, token_type_ids, mlm_labels, nsp_label) in examples:
        seq_len = len(input_ids)
        pad_len = max_len - seq_len
        batch_input_ids.append(input_ids + [pad_id] * pad_len)
        batch_token_type_ids.append(token_type_ids + [0] * pad_len)
        batch_mlm_labels.append(mlm_labels + [-100] * pad_len)
        batch_nsp_labels.append(nsp_label)
        batch_mask.append([1] * seq_len + [0] * pad_len)

    batch_input_ids = torch.tensor(batch_input_ids, dtype=torch.long)
    batch_token_type_ids = torch.tensor(batch_token_type_ids, dtype=torch.long)
    batch_mlm_labels = torch.tensor(batch_mlm_labels, dtype=torch.long)
    batch_nsp_labels = torch.tensor(batch_nsp_labels, dtype=torch.long)
    batch_mask = torch.tensor(batch_mask, dtype=torch.long)
    return batch_input_ids, batch_token_type_ids, batch_mlm_labels, batch_nsp_labels, batch_mask


def pretrain_bert():
    dataset = build_pretraining_dataset(corpus, num_examples=32)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_pretraining_batch)

    model = BertForPreTraining(
        vocab_size=len(tkn2idx),
        token_type_size=2,
        max_position_embeddings=64,
        hidden_dim=32,
        num_layers=2,
        num_heads=2,
        d_ff=64,
        dropout_prob=0.1,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    EPOCHS = 100

    for epoch in range(EPOCHS):
        total_loss = 0
        for batch in dataloader:
            input_ids, token_type_ids, mlm_labels, nsp_labels, mask = batch
            optimizer.zero_grad()

            prediction_scores, seq_relationship_scores = model(input_ids, token_type_ids, mask)

            mlm_loss = F.cross_entropy(prediction_scores.view(-1, len(tkn2idx)), mlm_labels.view(-1), ignore_index=-100)
            nsp_loss = F.cross_entropy(seq_relationship_scores.view(-1, 2), nsp_labels.view(-1))
            loss = mlm_loss + nsp_loss

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.4f}")

    return model

In the following code, we test the pre-trained BERT.

def test_pretrain_bert(model):
    sent_a = "i like [MASK]"
    sent_b = "they are playful"

    input_ids, token_type_ids, mlm_labels, nsp_label = create_example_for_mlm_nsp(sent_a, sent_b, is_next=True)
    test_batch = collate_pretraining_batch([(input_ids, token_type_ids, mlm_labels, nsp_label)])
    input_ids_batch, token_type_ids_batch, mlm_labels_batch, nsp_labels_batch, mask_batch = test_batch

    model.eval()
    with torch.no_grad():
        prediction_scores, seq_relationship_scores = model(input_ids_batch, token_type_ids_batch, mask_batch)

    masked_index = (torch.tensor(input_ids) == tkn2idx["[MASK]"]).nonzero(as_tuple=True)[0]
    if len(masked_index) > 0:
        # We'll just look at the first masked token
        mask_position = masked_index[0].item()
        logits = prediction_scores[0, mask_position]  # shape [vocab_size]
        probs = F.softmax(logits, dim=-1)
        top5 = torch.topk(probs, 5)
        print("Top 5 predictions for [MASK]:")
        for prob, idx in zip(top5.values, top5.indices):
            print(f"  Token='{idx2tkn[idx.item()]}' prob={prob.item():.4f}")

    nsp_prob = F.softmax(seq_relationship_scores[0], dim=-1)
    print("NSP probabilities =", nsp_prob)

Fine-tuning BERT

In the fine-tuning stage, we fine-tune a pre-trained BERT, for example the one we just pre-trained, or Google's pre-trained bert-base-uncased or bert-large-uncased. We can fine-tune a pre-trained BERT for a specific downstream task. Next, we show how to fine-tune BERT into a sentiment classification model.

Below is the data used for fine-tuning.

# 1: positive, 0: negative
sentiment_data = [
    ("i like dogs", 1),
    ("i like cats", 1),
    ("they are playful", 1),
    ("they are bad", 0),  # 'bad' not in vocab, will become [UNK]
    ("i like [UNK]", 0),  # random negative label
]

Then, we build the fine-tuning dataset with the following code. It is similar to building the pre-training dataset, except that we use a single sentence instead of a sentence pair.

def create_example_for_classification(sentence):
    cls_id = tkn2idx["[CLS]"]
    sep_id = tkn2idx["[SEP]"]

    tokens = tokenize(sentence)

    input_ids = [cls_id] + tokens + [sep_id]
    token_type_ids = [0] * (len(tokens) + 2)

    return input_ids, token_type_ids


def build_sentiment_dataset(data):
    examples = []
    for sentence, label in data:
        input_ids, token_type_ids = create_example_for_classification(sentence)
        examples.append((input_ids, token_type_ids, label))
    return examples

As with the pre-training tasks, sentiment classification also needs a task-specific layer on top of the BERT model's output. Our sentiment classification model predicts whether a sentence is positive or negative, i.e., it makes a prediction about the whole sentence. Therefore, we use the output corresponding to [CLS] (the aggregate sequence representation) for the prediction.

class BertForSequenceClassification(nn.Module):
    def __init__(self, bert, num_labels, hidden_dim):
        super(BertForSequenceClassification, self).__init__()
        self.bert = bert
        # A classification head: we typically use the [CLS] pooled output
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids, token_type_ids=None, mask=None, labels=None):
        """
        Sequence classification with BERT

        Args
            input_ids: (batch_size, seq_len)
            token_type_ids: (batch_size, seq_len)
            mask: (batch_size, seq_len)
            labels: (batch_size)

        Returns
            logits: (batch_size, num_classes)
            loss: (optional) Cross entropy loss
        """

        sequence_output, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids, mask=mask)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return logits, loss

We fine-tune a pre-trained BERT with the following code.
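
A minimal sketch of this step is shown below, assuming a simple collate helper for the classification examples, hidden_dim=32 to match the toy pre-trained model above, and the same optimizer settings as in pre-training; the exact code in the downloadable source may differ.

def collate_classification_batch(examples):
    pad_id = tkn2idx["[PAD]"]
    max_len = max(len(ex[0]) for ex in examples)

    batch_input_ids, batch_token_type_ids, batch_labels, batch_mask = [], [], [], []
    for (input_ids, token_type_ids, label) in examples:
        pad_len = max_len - len(input_ids)
        batch_input_ids.append(input_ids + [pad_id] * pad_len)
        batch_token_type_ids.append(token_type_ids + [0] * pad_len)
        batch_labels.append(label)
        batch_mask.append([1] * len(input_ids) + [0] * pad_len)

    return (
        torch.tensor(batch_input_ids, dtype=torch.long),
        torch.tensor(batch_token_type_ids, dtype=torch.long),
        torch.tensor(batch_labels, dtype=torch.long),
        torch.tensor(batch_mask, dtype=torch.long),
    )


def fine_tune_bert(bert):
    dataset = build_sentiment_dataset(sentiment_data)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_classification_batch)

    # Reuse the pre-trained BERT body and add a classification head on top
    model = BertForSequenceClassification(bert, num_labels=2, hidden_dim=32)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    EPOCHS = 50

    for epoch in range(EPOCHS):
        total_loss = 0
        for input_ids, token_type_ids, labels, mask in dataloader:
            optimizer.zero_grad()
            logits, loss = model(input_ids, token_type_ids=token_type_ids, mask=mask, labels=labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {total_loss / len(dataloader):.4f}")

    return model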


Finally, we can test the fine-tuned sentiment classification model with the following code.

def test_fine_tune_bert(model):
    text = "i like dogs"

    input_ids, token_type_ids = create_example_for_classification(text)
    mask = [1] * len(input_ids)

    input_ids_tensor = torch.tensor([input_ids], dtype=torch.long)
    token_type_ids_tensor = torch.tensor([token_type_ids], dtype=torch.long)
    mask_tensor = torch.tensor([mask], dtype=torch.long)

    model.eval()
    with torch.no_grad():
        logits, loss = model(input_ids_tensor, token_type_ids_tensor, mask_tensor)

    probs = F.softmax(logits, dim=-1)
    predicted_label = torch.argmax(probs, dim=-1).item()

    print("Probabilities =", probs)
    print("Predicted label =", predicted_label)

You can run pre-training and fine-tuning with the following code.

if __name__ == "__main__":
    pretrain_model = pretrain_bert()
    test_pretrain_bert(pretrain_model)
    fine_tune_model = fine_tune_bert(pretrain_model.bert)
    test_fine_tune_bert(fine_tune_model)

Conclusion

BERT is not only a technical innovation in NLP; it is also a milestone that pushed machine language understanding into a new era. Through its bidirectional Transformer architecture and the pre-training and fine-tuning paradigm, BERT delivers a new level of accuracy and flexibility across a wide range of language tasks.

References

  • Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. 2018.
  • Ashish Vaswani, et al. Attention Is All You Need. 2017.
