Bidirectional Encoder Representations from Transformers (BERT) is a pre-training technology for natural language processing proposed by Google AI in 2018. BERT significantly advances the state of natural language processing by providing a deeper contextual understanding of language.
The complete code for this chapter can be found in .
BERT Architecture
The full name of BERT is bidirectional encoder representations from Transformers. As the name suggests, the core architecture of BERT is the encoder of Transformers. If you are not familiar with Transformers, please refer to the following article first.
The following figure shows the architectures of BERT and the Transformer side by side. It is clear that BERT consists of the Transformer encoder plus an output layer. BERT pre-trains deep bidirectional representations from an unlabeled corpus by jointly conditioning on both left and right context in all layers, which is exactly what the Transformer encoder does. The representations capture different semantic aspects of the input sequence. In the original Transformer, these representations are passed into the decoder's cross multi-head attention to predict the next output.
However, the pre-trained BERT model is essentially just the Transformer encoder and only outputs these representations. We can use the output representations for downstream tasks such as question answering and language inference. By simply adding an output layer and fine-tuning the pre-trained BERT model, we can produce models for these downstream tasks without major task-specific modifications to the architecture. Therefore, BERT involves two stages: pre-training and fine-tuning.
Input/Output Representations
In order to let BERT handle a variety of downstream tasks, the input sequence of BERT can be a single sentence or a pair of sentences (such as <question, answer>). BERT uses WordPiece embeddings, proposed by Google in 2016, with a vocabulary of 30,000 tokens. The first token of every input sequence is always a special classification token ([CLS]). The final hidden state corresponding to this [CLS] token is used as the aggregate sequence representation in classification tasks.
When the input sequence is a pair of sentences, we need to merge the two sentences and distinguish the sentences in the following two ways:
- First, use a special [SEP] token to separate them.
- Second, add a learned segment embedding to each token to indicate whether the token belongs to sentence A or sentence B.
For a given token, its input representation is the sum of its token embedding, segment embedding, and position embedding, as shown in the figure below.
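To make this concrete, the short sketch below lays out a hypothetical sentence pair and its segment ids; the sentences, tokens, and layout here are made up purely for illustration.

# Hypothetical sentence pair: A = "how are you", B = "i am fine" (illustrative only).
tokens         = ["[CLS]", "how", "are", "you", "[SEP]", "i", "am", "fine", "[SEP]"]
token_type_ids = [0,       0,     0,     0,     0,       1,   1,    1,      1]
positions      = [0,       1,     2,     3,     4,       5,   6,    7,      8]

# For each position i, the input representation fed to the encoder is
#   token_embedding(tokens[i]) + segment_embedding(token_type_ids[i]) + position_embedding(positions[i])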
BERT Implementation
Below is the implementation of the BERT model. If you don’t know Transformers yet or have trouble understanding the following implementation, please refer to the following article first.
This implementation is almost identical to the Transformer encoder, with two exceptions:
- In Embeddings, there is an additional token_type_embeddings layer, which is used to distinguish sentence A from sentence B.
- At the output, there is an additional pooler layer. It takes the first token of the output representations, which corresponds to the [CLS] token in the input sequence. As mentioned before, this is used as the aggregate sequence representation in classification tasks.
Finally, the BERT model outputs both the token-level output representations and the aggregate sequence representation.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Embeddings(nn.Module):
    def __init__(self, vocab_size, token_type_size, max_position_embeddings, hidden_dim, dropout_prob):
        super(Embeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_dim)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_dim)
        self.token_type_embeddings = nn.Embedding(token_type_size, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        """
        Compute the embeddings for the input tokens.

        Args
            input_ids: (batch_size, seq_len)
            token_type_ids: (batch_size, seq_len)
        Returns
            embeddings: (batch_size, seq_len, hidden_dim)
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (1, seq_len) -> (batch_size, seq_len)

        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, hidden_dim, dropout_prob):
        super(MultiHeadAttention, self).__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads
        self.all_head_size = hidden_dim

        self.query = nn.Linear(hidden_dim, self.all_head_size, bias=False)
        self.key = nn.Linear(hidden_dim, self.all_head_size, bias=False)
        self.value = nn.Linear(hidden_dim, self.all_head_size, bias=False)
        self.output = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.norm = nn.LayerNorm(hidden_dim, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, mask=None):
        """
        Multi-head attention forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)
            mask: (batch_size, 1, 1, seq_len) 0 for real tokens, a large negative value for padding tokens
        Returns
            hidden_states: (batch_size, seq_len, hidden_dim)
        """
        query = self.transpose_for_scores(self.query(hidden_states))  # (batch_size, num_heads, seq_len, head_size)
        key = self.transpose_for_scores(self.key(hidden_states))      # (batch_size, num_heads, seq_len, head_size)
        value = self.transpose_for_scores(self.value(hidden_states))  # (batch_size, num_heads, seq_len, head_size)

        # Scaled dot-product attention
        scores = query @ key.transpose(-2, -1) / math.sqrt(self.head_size)  # (batch_size, num_heads, seq_len, seq_len)
        if mask is not None:
            scores = scores + mask
        attention_weights = F.softmax(scores, dim=-1)  # (batch_size, num_heads, seq_len, seq_len)
        attention_weights = self.dropout(attention_weights)
        attention = attention_weights @ value  # (batch_size, num_heads, seq_len, head_size)

        # Concatenate heads
        attention = attention.transpose(1, 2).contiguous()  # (batch_size, seq_len, num_heads, head_size)
        new_shape = attention.size()[:-2] + (self.all_head_size,)
        attention = attention.view(*new_shape)  # (batch_size, seq_len, all_head_size)

        # Linear projection
        projection_output = self.output(attention)  # (batch_size, seq_len, hidden_dim)
        projection_output = self.dropout(projection_output)

        hidden_states = self.norm(hidden_states + projection_output)
        return hidden_states

    def transpose_for_scores(self, x):
        """
        Args
            x: (batch_size, seq_len, all_head_size)
        Returns
            (batch_size, num_heads, seq_len, head_size)
        """
        new_x_shape = x.size()[:-1] + (self.num_heads, self.head_size)
        x = x.view(*new_x_shape)  # (batch_size, seq_len, num_heads, head_size)
        return x.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, head_size)


class PositionwiseFeedForward(nn.Module):
    def __init__(self, hidden_dim, d_ff):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(hidden_dim, d_ff, bias=True)
        self.linear2 = nn.Linear(d_ff, hidden_dim, bias=True)
        self.activation = nn.GELU()

    def forward(self, hidden_states):
        """
        Feed-forward network forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)
        Returns
            hidden_states: (batch_size, seq_len, hidden_dim)
        """
        hidden_states = self.linear2(self.activation(self.linear1(hidden_states)))
        return hidden_states


class EncoderLayer(nn.Module):
    def __init__(self, num_heads, hidden_dim, d_ff, dropout_prob):
        super(EncoderLayer, self).__init__()
        self.multi_head_attention = MultiHeadAttention(num_heads, hidden_dim, dropout_prob)
        self.ffn = PositionwiseFeedForward(hidden_dim, d_ff)
        self.norm = nn.LayerNorm(hidden_dim, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, mask=None):
        """
        Encoder layer forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)
            mask: (batch_size, 1, 1, seq_len) 0 for real tokens, a large negative value for padding tokens
        Returns
            hidden_states: (batch_size, seq_len, hidden_dim)
        """
        # Multi-head attention
        attention_output = self.multi_head_attention(hidden_states, mask=mask)

        # Feed-forward network
        ffn_output = self.ffn(attention_output)
        ffn_output = self.dropout(ffn_output)
        hidden_states = self.norm(hidden_states + ffn_output)
        return hidden_states


class Encoder(nn.Module):
    def __init__(self, hidden_dim, num_layers, num_heads, d_ff, dropout_prob):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(
            [EncoderLayer(num_heads, hidden_dim, d_ff, dropout_prob) for _ in range(num_layers)]
        )

    def forward(self, hidden_states, mask=None):
        """
        Encoder forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)
            mask: (batch_size, 1, 1, seq_len) 0 for real tokens, a large negative value for padding tokens
        Returns
            hidden_states: (batch_size, seq_len, hidden_dim)
        """
        for layer in self.layers:
            hidden_states = layer(hidden_states, mask=mask)
        return hidden_states


class Pooler(nn.Module):
    def __init__(self, hidden_dim):
        super(Pooler, self).__init__()
        self.linear = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden_states):
        """
        Pooler forward pass.

        Args
            hidden_states: (batch_size, seq_len, hidden_dim)
        Returns
            pooled_output: (batch_size, hidden_dim)
        """
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.linear(first_token_tensor)
        pooled_output = torch.tanh(pooled_output)
        return pooled_output


class Bert(nn.Module):
    def __init__(
        self, vocab_size, token_type_size, max_position_embeddings, hidden_dim, num_layers, num_heads, d_ff,
        dropout_prob
    ):
        super(Bert, self).__init__()
        self.embeddings = Embeddings(vocab_size, token_type_size, max_position_embeddings, hidden_dim, dropout_prob)
        self.encoder = Encoder(hidden_dim, num_layers, num_heads, d_ff, dropout_prob)
        self.pooler = Pooler(hidden_dim)

    def forward(self, input_ids, token_type_ids=None, mask=None):
        """
        Forward pass for the BERT model.

        Args
            input_ids: (batch_size, seq_len)
            token_type_ids: (batch_size, seq_len)
            mask: (batch_size, seq_len) 1 for real tokens, 0 for padding tokens
        Returns
            encoder_output: (batch_size, seq_len, hidden_dim)
            pooled_output: (batch_size, hidden_dim)
        """
        if mask is not None:
            extended_mask = mask.unsqueeze(1).unsqueeze(2)
            extended_mask = extended_mask.to(dtype=torch.float32)
            # Convert 1 -> 0, 0 -> large negative (mask out)
            extended_mask = (1.0 - extended_mask) * -10000.0
        else:
            extended_mask = None

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoder_output = self.encoder(embedding_output, mask=extended_mask)
        pooled_output = self.pooler(encoder_output)
        return encoder_output, pooled_output
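As a quick sanity check of the implementation above, the snippet below instantiates a small Bert with arbitrary hyperparameters (chosen only for this sketch) and verifies the output shapes.

bert = Bert(
    vocab_size=100, token_type_size=2, max_position_embeddings=64,
    hidden_dim=32, num_layers=2, num_heads=2, d_ff=64, dropout_prob=0.1,
)
input_ids = torch.randint(0, 100, (2, 10))     # (batch_size=2, seq_len=10)
token_type_ids = torch.zeros_like(input_ids)   # single-sentence inputs
mask = torch.ones_like(input_ids)              # no padding

encoder_output, pooled_output = bert(input_ids, token_type_ids, mask)
print(encoder_output.shape)  # torch.Size([2, 10, 32])
print(pooled_output.shape)   # torch.Size([2, 32])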
Pre-training BERT
Pre-training BERT involves two tasks: Masked Language Modeling (MLM) and Next Sentence Prediction (NSP). We introduce each of them below.
Masked Language Modeling, MLM
We can reasonably infer that a deep bidirectional model is more powerful than either a left-to-right model or a shallow concatenation of a left-to-right and a right-to-left model. However, traditional conditional language models can only be trained left-to-right or right-to-left, because bidirectional conditioning would let each word indirectly "see itself" through the multi-layer context, making it trivial for the model to predict the target word.
In order to train deep bidirectional representations, we simply mask some tokens of the input sequence at random and then let the model predict those masked tokens. This is the Masked Language Modeling (MLM) task. The final hidden states corresponding to the masked tokens are fed into an output softmax over the vocabulary, as in a standard language model.
Although this method allows us to obtain a bidirectional pre-trained model, its disadvantage is that it creates a mismatch between pre-training and fine-tuning, because the special [MASK] token never appears during fine-tuning. To reduce the impact of this problem, we do not always replace the selected token with the [MASK] token.
When generating training data, we randomly select 15% of the token positions as prediction targets. If the i-th token is selected, we will:
- Replace the i-th token with the [MASK] token with a probability of 80%.
- Replace the i-th token with a random token with a probability of 10%.
- Keep the i-th token unchanged with a probability of 10%.
The following figure shows how the two sentences are combined and how training examples for MLM are generated in the above way.
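To make the 80%/10%/10% rule concrete, here is a small standalone sketch of how a single selected position could be handled. The helper name apply_mask_to_position is made up for this illustration; the full data pipeline later in this chapter applies the same rule inline.

import random

def apply_mask_to_position(token_id, mask_id, random_token_ids):
    """Apply the 80/10/10 masking rule to one selected position (illustrative helper)."""
    r = random.random()
    if r < 0.8:
        return mask_id                          # 80%: replace with [MASK]
    elif r < 0.9:
        return random.choice(random_token_ids)  # 10%: replace with a random token
    else:
        return token_id                         # 10%: keep the original token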
Next Sentence Prediction, NSP
Many downstream tasks, such as question answering and language inference, rely on understanding the relationship between two sentences, which is not directly modeled in traditional language model pre-training. In order to train a model that can understand sentence relationships, we also need to pre-train the Next Sentence Prediction (NSP) task.
When we select sentences A and B for each pre-training example, 50% of the time B is the sentence that actually follows A in the corpus (labeled IsNext), and 50% of the time B is a sentence randomly selected from the corpus (labeled NotNext). The final hidden state corresponding to [CLS] (that is, the aggregate sequence representation) is used for NSP.
In the figure below, we pair sentences from the corpus and combine them into training examples. The upper example uses the sentence that immediately follows in the actual corpus, so it is labeled IsNext. The lower example uses a randomly selected sentence from the corpus, labeled NotNext.
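The NSP decision itself is just a binary classification on top of the aggregate sequence representation. Assuming the imports and dimensions from the implementation above, a minimal sketch of such a head looks like this; the actual head used for pre-training appears later as part of the pre-training heads.

hidden_dim = 32
pooled_output = torch.randn(4, hidden_dim)  # stand-in for the pooled [CLS] output of a batch of 4 sentence pairs

nsp_head = nn.Linear(hidden_dim, 2)                 # two classes: NotNext (0) and IsNext (1)
seq_relationship_scores = nsp_head(pooled_output)   # (4, 2) logits fed into a cross-entropy loss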
Implementation
BERT uses WordPiece, but to simplify the example code, we simply split the text on whitespace and use a very small vocabulary, as follows.
tkn2idx = {
    "[PAD]": 0,
    "[CLS]": 1,
    "[SEP]": 2,
    "[MASK]": 3,
    "i": 4,
    "like": 5,
    "dogs": 6,
    "cats": 7,
    "they": 8,
    "are": 9,
    "playful": 10,
    "[UNK]": 11,
}
idx2tkn = {v: k for k, v in tkn2idx.items()}


def tokenize(text):
    tokens = text.split()
    token_ids = [tkn2idx.get(t, tkn2idx["[UNK]"]) for t in tokens]
    return token_ids
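For example, applying tokenize to a sentence from this vocabulary and to one containing an unknown word gives:

print(tokenize("i like dogs"))   # [4, 5, 6]
print(tokenize("i like birds"))  # [4, 5, 11] -> "birds" is out of vocabulary, mapped to [UNK]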
We then use the following corpus.
corpus = [
    "i like dogs",
    "they are playful",
    "i like cats",
    "they are cute",
]
Next, we use the following code to create a pre-training dataset. When selecting a sentence pair, there is a 50% chance of selecting the actual next sentence and a 50% chance of selecting a random sentence. In token_type_ids, 0 indicates that the token at that position in input_ids belongs to sentence A, and 1 indicates that it belongs to sentence B. In addition, mlm_labels uses -100 to indicate that the token at that position in input_ids is not masked; if a position is selected for masking, mlm_labels stores the original token at that position.
import random


def create_example_for_mlm_nsp(sentence_a, sentence_b, is_next, max_seq_len=12, mask_prob=0.15):
    cls_id = tkn2idx["[CLS]"]
    sep_id = tkn2idx["[SEP]"]
    mask_id = tkn2idx["[MASK]"]

    tokens_a = tokenize(sentence_a)
    tokens_b = tokenize(sentence_b)

    input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

    if len(input_ids) > max_seq_len:
        input_ids = input_ids[:max_seq_len]
        token_type_ids = token_type_ids[:max_seq_len]

    # -100 for non-masked positions, and the original token for masked positions
    mlm_labels = [-100] * len(input_ids)

    num_to_mask = max(1, int((len(input_ids) - 3) * mask_prob))  # 3 for [CLS], [SEP], [SEP]
    candidate_mask_positions = [i for i, tid in enumerate(input_ids) if tid not in [cls_id, sep_id]]
    random.shuffle(candidate_mask_positions)
    mask_positions = candidate_mask_positions[:num_to_mask]

    for pos in mask_positions:
        mlm_labels[pos] = input_ids[pos]

        # BERT strategy: 80% replace with [MASK], 10% random, 10% keep
        r = random.random()
        if r < 0.8:
            input_ids[pos] = mask_id
        elif r < 0.9:
            input_ids[pos] = random.randint(4, len(tkn2idx) - 2)  # exclude special tokens
        else:
            pass

    nsp_label = 1 if is_next else 0
    return input_ids, token_type_ids, mlm_labels, nsp_label


def build_pretraining_dataset(corpus, num_examples):
    dataset = []
    n = len(corpus)
    for _ in range(num_examples):
        idx_a = random.randint(0, n - 1)
        sentence_a = corpus[idx_a]

        # 50%: pick a real next sentence; 50%: pick a random sentence
        if random.random() < 0.5:
            idx_b = (idx_a + 1) % n
            sentence_b = corpus[idx_b]
            is_next = True
        else:
            idx_b = random.randint(0, n - 1)
            while idx_b == idx_a:
                idx_b = random.randint(0, n - 1)
            sentence_b = corpus[idx_b]
            is_next = False

        input_ids, token_type_ids, mlm_labels, nsp_label = create_example_for_mlm_nsp(sentence_a, sentence_b, is_next)
        dataset.append((input_ids, token_type_ids, mlm_labels, nsp_label))
    return dataset
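To see what these examples look like, the snippet below builds a tiny dataset and prints one example. Because masking and sentence selection are random, the exact values vary from run to run; the values shown in the comments are just one possibility.

random.seed(0)  # only to make this illustration repeatable
dataset = build_pretraining_dataset(corpus, num_examples=2)

input_ids, token_type_ids, mlm_labels, nsp_label = dataset[0]
print(input_ids)       # e.g. [1, 4, 5, 3, 2, 8, 9, 10, 2]  -> "[CLS] i like [MASK] [SEP] they are playful [SEP]"
print(token_type_ids)  # e.g. [0, 0, 0, 0, 0, 1, 1, 1, 1]
print(mlm_labels)      # e.g. [-100, -100, -100, 6, -100, -100, -100, -100, -100]  -> the masked token was "dogs"
print(nsp_label)       # 1 if sentence B really followed sentence A in the corpus, otherwise 0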
In the previous Bert code, Bert.forward() outputs the token-level representations and the final hidden state corresponding to [CLS]. For the MLM task, we want the model to predict the token at each masked position. For NSP, we want the model to predict whether the second sentence actually follows the first.
Therefore, in the following code, we add heads on top of these outputs: the output representations are used to predict the masked tokens, and the final hidden state corresponding to [CLS] is used to predict whether the second sentence is the next sentence.
In addition, since the output representations are projected back onto the vocabulary to predict the masked tokens, we tie the MLM head's predictions.weight to bert.embeddings.word_embeddings.weight.
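Below, BertForPreTraining wires the encoder together with a PreTrainingHeads module that holds the MLM and NSP heads. A minimal sketch of such a module is given here; the attribute names predictions and seq_relationship are assumptions chosen to match the description above, not names fixed by the text.

class PreTrainingHeads(nn.Module):
    def __init__(self, vocab_size, hidden_dim, word_embedding_weight):
        super(PreTrainingHeads, self).__init__()
        # MLM head: project each output representation back onto the vocabulary,
        # sharing its weight with the word embedding matrix (weight tying).
        self.predictions = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.predictions.weight = word_embedding_weight
        # NSP head: binary classification (NotNext / IsNext) from the pooled [CLS] output.
        self.seq_relationship = nn.Linear(hidden_dim, 2)

    def forward(self, sequence_output, pooled_output):
        """
        Args
            sequence_output: (batch_size, seq_len, hidden_dim)
            pooled_output: (batch_size, hidden_dim)
        Returns
            prediction_scores: (batch_size, seq_len, vocab_size)
            seq_relationship_scores: (batch_size, 2)
        """
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_scores = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_scores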
class BertForPreTraining(nn.Module):
    def __init__(
        self, vocab_size, token_type_size, max_position_embeddings, hidden_dim, num_layers, num_heads, d_ff,
        dropout_prob
    ):
        super(BertForPreTraining, self).__init__()
        self.bert = Bert(
            vocab_size, token_type_size, max_position_embeddings, hidden_dim, num_layers, num_heads, d_ff,
            dropout_prob
        )
        # Tying the MLM head's weight to the word embedding
        self.cls = PreTrainingHeads(vocab_size, hidden_dim, self.bert.embeddings.word_embeddings.weight)

    def forward(self, input_ids, token_type_ids=None, mask=None):
        """
        Pre-training BERT

        Args
            input_ids: (batch_size, seq_len)
            token_type_ids: (batch_size, seq_len)
            mask: (batch_size, seq_len)
        Returns
            prediction_scores: (batch_size, seq_len, vocab_size)
            seq_relationship_scores: (batch_size, 2)
        """
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, mask=mask)
        prediction_scores, seq_relationship_scores = self.cls(sequence_output, pooled_output)
        return prediction_scores, seq_relationship_scores
We use the following code to perform pre-training.
from torch.utils.data import DataLoader


def collate_pretraining_batch(examples):
    pad_id = tkn2idx["[PAD]"]
    max_len = max(len(ex[0]) for ex in examples)

    batch_input_ids = []
    batch_token_type_ids = []
    batch_mlm_labels = []
    batch_nsp_labels = []
    batch_mask = []

    for (input_ids, token_type_ids, mlm_labels, nsp_label) in examples:
        seq_len = len(input_ids)
        pad_len = max_len - seq_len

        batch_input_ids.append(input_ids + [pad_id] * pad_len)
        batch_token_type_ids.append(token_type_ids + [0] * pad_len)
        batch_mlm_labels.append(mlm_labels + [-100] * pad_len)
        batch_nsp_labels.append(nsp_label)
        batch_mask.append([1] * seq_len + [0] * pad_len)

    batch_input_ids = torch.tensor(batch_input_ids, dtype=torch.long)
    batch_token_type_ids = torch.tensor(batch_token_type_ids, dtype=torch.long)
    batch_mlm_labels = torch.tensor(batch_mlm_labels, dtype=torch.long)
    batch_nsp_labels = torch.tensor(batch_nsp_labels, dtype=torch.long)
    batch_mask = torch.tensor(batch_mask, dtype=torch.long)
    return batch_input_ids, batch_token_type_ids, batch_mlm_labels, batch_nsp_labels, batch_mask


def pretrain_bert():
    dataset = build_pretraining_dataset(corpus, num_examples=32)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_pretraining_batch)

    model = BertForPreTraining(
        vocab_size=len(tkn2idx),
        token_type_size=2,
        max_position_embeddings=64,
        hidden_dim=32,
        num_layers=2,
        num_heads=2,
        d_ff=64,
        dropout_prob=0.1,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    EPOCHS = 100
    for epoch in range(EPOCHS):
        total_loss = 0
        for batch in dataloader:
            input_ids, token_type_ids, mlm_labels, nsp_labels, mask = batch

            optimizer.zero_grad()
            prediction_scores, seq_relationship_scores = model(input_ids, token_type_ids, mask)
            mlm_loss = F.cross_entropy(prediction_scores.view(-1, len(tkn2idx)), mlm_labels.view(-1), ignore_index=-100)
            nsp_loss = F.cross_entropy(seq_relationship_scores.view(-1, 2), nsp_labels.view(-1))
            loss = mlm_loss + nsp_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.4f}")
    return model
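With the PreTrainingHeads sketch above, the weight tying mentioned earlier can be checked directly; the check below relies on that sketch's predictions attribute, which is an assumption of this chapter's code rather than a fixed name.

model = BertForPreTraining(
    vocab_size=len(tkn2idx), token_type_size=2, max_position_embeddings=64,
    hidden_dim=32, num_layers=2, num_heads=2, d_ff=64, dropout_prob=0.1,
)
# The MLM projection and the word embedding table share the same parameter object.
print(model.cls.predictions.weight is model.bert.embeddings.word_embeddings.weight)  # True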
In the following code, we test the pre-trained BERT.
def test_pretrain_bert(model):
    sent_a = "i like [MASK]"
    sent_b = "they are playful"

    input_ids, token_type_ids, mlm_labels, nsp_label = create_example_for_mlm_nsp(sent_a, sent_b, is_next=True)
    test_batch = collate_pretraining_batch([(input_ids, token_type_ids, mlm_labels, nsp_label)])
    input_ids_batch, token_type_ids_batch, mlm_labels_batch, nsp_labels_batch, mask_batch = test_batch

    model.eval()
    with torch.no_grad():
        prediction_scores, seq_relationship_scores = model(input_ids_batch, token_type_ids_batch, mask_batch)

    masked_index = (torch.tensor(input_ids) == tkn2idx["[MASK]"]).nonzero(as_tuple=True)[0]
    if len(masked_index) > 0:
        # We'll just look at the first masked token
        mask_position = masked_index[0].item()
        logits = prediction_scores[0, mask_position]  # shape [vocab_size]
        probs = F.softmax(logits, dim=-1)
        top5 = torch.topk(probs, 5)
        print("Top 5 predictions for [MASK]:")
        for prob, idx in zip(top5.values, top5.indices):
            print(f"  Token='{idx2tkn[idx.item()]}' prob={prob.item():.4f}")

    nsp_prob = F.softmax(seq_relationship_scores[0], dim=-1)
    print("NSP probabilities =", nsp_prob)
Fine-tuning BERT
In the fine-tuning stage, we fine-tune a pre-trained BERT, for example, the BERT we just pre-trained, or Google's pre-trained bert-base-uncased or bert-large-uncased, for a specific downstream task. We will next show how to fine-tune BERT into a sentiment classification model.
The following is the labeled data used for fine-tuning.
# 1: positive, 0: negative
sentiment_data = [
    ("i like dogs", 1),
    ("i like cats", 1),
    ("they are playful", 1),
    ("they are bad", 0),   # 'bad' not in vocab, will become [UNK]
    ("i like [UNK]", 0),   # random negative label
]
Then, we use the following code to create a dataset for fine-tuning. This is similar to building a pre-training dataset, but we use a single sentence instead of a sentence pair.
def create_example_for_classification(sentence):
    cls_id = tkn2idx["[CLS]"]
    sep_id = tkn2idx["[SEP]"]

    tokens = tokenize(sentence)
    input_ids = [cls_id] + tokens + [sep_id]
    token_type_ids = [0] * (len(tokens) + 2)
    return input_ids, token_type_ids


def build_sentiment_dataset(data):
    examples = []
    for sentence, label in data:
        input_ids, token_type_ids = create_example_for_classification(sentence)
        examples.append((input_ids, token_type_ids, label))
    return examples
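For instance, a single training sentence is packed as follows:

input_ids, token_type_ids = create_example_for_classification("i like dogs")
print(input_ids)       # [1, 4, 5, 6, 2]  -> "[CLS] i like dogs [SEP]"
print(token_type_ids)  # [0, 0, 0, 0, 0]  -> a single sentence, so every token belongs to segment A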
Similar to the pre-training tasks, sentiment classification also requires a task-specific layer on top of the BERT output. Our sentiment classification model predicts whether an entire sentence is positive or negative, so we use the output corresponding to [CLS] (that is, the aggregate sequence representation) to make the prediction.
class BertForSequenceClassification(nn.Module):
    def __init__(self, bert, num_labels, hidden_dim):
        super(BertForSequenceClassification, self).__init__()
        self.bert = bert
        # A classification head: we typically use the [CLS] pooled output
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids, token_type_ids=None, mask=None, labels=None):
        """
        Sequence classification with BERT

        Args
            input_ids: (batch_size, seq_len)
            token_type_ids: (batch_size, seq_len)
            mask: (batch_size, seq_len)
            labels: (batch_size)
        Returns
            logits: (batch_size, num_classes)
            loss: (optional) Cross entropy loss
        """
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids, mask=mask)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return logits, loss
We use the following code to fine-tune a pre-trained BERT.
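This is a minimal sketch of that fine-tuning loop: it reuses the pre-trained encoder inside BertForSequenceClassification and trains it together with the new classification head. The helper collate_classification_batch, the learning rate, and the epoch count are illustrative assumptions rather than values prescribed by the text.

def collate_classification_batch(examples):
    pad_id = tkn2idx["[PAD]"]
    max_len = max(len(ex[0]) for ex in examples)

    batch_input_ids, batch_token_type_ids, batch_labels, batch_mask = [], [], [], []
    for (input_ids, token_type_ids, label) in examples:
        pad_len = max_len - len(input_ids)
        batch_input_ids.append(input_ids + [pad_id] * pad_len)
        batch_token_type_ids.append(token_type_ids + [0] * pad_len)
        batch_labels.append(label)
        batch_mask.append([1] * len(input_ids) + [0] * pad_len)

    return (
        torch.tensor(batch_input_ids, dtype=torch.long),
        torch.tensor(batch_token_type_ids, dtype=torch.long),
        torch.tensor(batch_labels, dtype=torch.long),
        torch.tensor(batch_mask, dtype=torch.long),
    )


def fine_tune_bert(bert, hidden_dim=32):
    dataset = build_sentiment_dataset(sentiment_data)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_classification_batch)

    # Reuse the pre-trained encoder and add a fresh classification head.
    model = BertForSequenceClassification(bert, num_labels=2, hidden_dim=hidden_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    model.train()
    EPOCHS = 50
    for epoch in range(EPOCHS):
        total_loss = 0
        for input_ids, token_type_ids, labels, mask in dataloader:
            optimizer.zero_grad()
            logits, loss = model(input_ids, token_type_ids=token_type_ids, mask=mask, labels=labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {total_loss / len(dataloader):.4f}")
    return model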
Finally, we can use the following code to test our fine-tuned sentiment classification model.
def test_fine_tune_bert(model):
    text = "i like dogs"
    input_ids, token_type_ids = create_example_for_classification(text)
    mask = [1] * len(input_ids)

    input_ids_tensor = torch.tensor([input_ids], dtype=torch.long)
    token_type_ids_tensor = torch.tensor([token_type_ids], dtype=torch.long)
    mask_tensor = torch.tensor([mask], dtype=torch.long)

    model.eval()
    with torch.no_grad():
        logits, loss = model(input_ids_tensor, token_type_ids_tensor, mask_tensor)

    probs = F.softmax(logits, dim=-1)
    predicted_label = torch.argmax(probs, dim=-1).item()
    print("Probabilities =", probs)
    print("Predicted label =", predicted_label)
You can use the following code to perform pre-training and fine-tuning.
if __name__ == "__main__": pretrain_model = pretrain_bert() test_pretrain_bert(pretrain_model) fine_tune_model = fine_tune_bert(pretrain_model.bert) test_fine_tune_bert(fine_tune_model)
Conclusion
BERT is not only a technological innovation in the field of NLP, but also an important milestone that pushed language understanding in artificial intelligence into a new era. Through its bidirectional Transformer architecture and the pre-training plus fine-tuning paradigm, BERT provides remarkable accuracy and flexibility for a variety of language tasks.
References
- Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of NAACL-HLT.
- BERT Source Code: https://github.com/google-research/bert.