import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
sns.set_style('whitegrid')

1. BERT vs GPT

Key Differences

| Feature      | BERT            | GPT                            |
|--------------|-----------------|--------------------------------|
| Direction    | Bidirectional   | Unidirectional (left-to-right) |
| Objective    | Masked LM + NSP | Next token prediction          |
| Architecture | Encoder only    | Decoder only                   |
| Use case     | Understanding   | Generation                     |

BERT: Bidirectional Encoder Representations

Core idea: See context from both directions simultaneously!

Example: "The cat sat on the ___"

  • GPT: Only sees "The cat sat on the"

  • BERT: Sees full context (masked token)

Pre-training Tasks

  1. Masked Language Modeling (MLM)

    • Mask 15% of tokens

    • Predict masked tokens

  2. Next Sentence Prediction (NSP)

    • Given sentence A, is B next?

    • Binary classification

1.5. Masked Language Modeling: Deep Dive

The Masking Strategy

BERT masks 15% of tokens, but with a clever trick:

For each masked position:

  • 80% → Replace with [MASK]

  • 10% → Replace with random token

  • 10% → Keep original token

Why this complexity?

Problem with 100% [MASK]:

  • Fine-tuning never sees [MASK] token

  • Mismatch between pre-training and fine-tuning

Solution:

  • 80% [MASK]: Main training signal

  • 10% random: Forces model to check all tokens

  • 10% unchanged: Reduces pre-train/fine-tune mismatch

Mathematical Formulation

MLM Objective:

\[\mathcal{L}_{\text{MLM}} = -\mathbb{E}_{x \sim D} \left[ \sum_{i \in M} \log P(x_i | x_{\backslash M}) \right]\]

where:

  • \(M\) is the set of masked positions (15% of tokens)

  • \(x_{\backslash M}\) is the input sequence with the positions in \(M\) masked out

  • Model predicts \(x_i\) for each \(i \in M\)

Bidirectional context:

\[P(x_i | x_{\backslash M}) = \text{softmax}(W \cdot h_i)\]

where \(h_i\) is computed using both left and right context!

Comparison with Other Objectives

| Method      | Objective              | Context       | Training Signals per Sequence |
|-------------|------------------------|---------------|-------------------------------|
| BERT (MLM)  | Predict masked (15%)   | Bidirectional | \(0.15 \times L\)             |
| GPT (AR)    | Predict next token     | Left-only     | \(L\)                         |
| ELECTRA     | Detect replaced tokens | Bidirectional | \(L\) (all positions)         |
| XLNet (PLM) | Permutation LM         | Bidirectional | \(L\)                         |

Trade-off:

  • BERT: Fewer training signals but bidirectional

  • GPT: More signals but unidirectional

  • ELECTRA: Best of both worlds!

ELECTRA: Efficiently Learning an Encoder

Improvement over BERT:

Instead of predicting masked tokens, detect replaced tokens:

  1. Generator: Small MLM model replaces tokens

  2. Discriminator: Detect which tokens were replaced

\[\mathcal{L}_{\text{ELECTRA}} = -\sum_{i=1}^L \left[ \mathbb{1}(x_i = \tilde{x}_i) \log D(x, i) + \mathbb{1}(x_i \neq \tilde{x}_i) \log(1 - D(x, i)) \right]\]

where \(\tilde{x}_i\) is generator output, \(D(x, i)\) is discriminator.

Advantages:

  • Trains on all positions (not just 15%)

  • More sample efficient

  • Smaller models reach BERT-large performance
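The replaced-token-detection objective above reduces to a per-position binary cross-entropy. A minimal sketch, where `disc_logits` stands in for real discriminator outputs and the tensors are toy data:

```python
import torch
import torch.nn.functional as F

def electra_disc_loss(original_ids, corrupted_ids, disc_logits):
    """Replaced-token detection: one binary label per position.

    disc_logits: (batch, seq_len) raw scores from the discriminator.
    Label 1 = token was replaced by the generator, 0 = original.
    """
    labels = (original_ids != corrupted_ids).float()
    return F.binary_cross_entropy_with_logits(disc_logits, labels)

# Toy example: 2 sequences of length 5, one replacement in each
orig = torch.tensor([[5, 6, 7, 8, 9], [1, 2, 3, 4, 5]])
corr = torch.tensor([[5, 0, 7, 8, 9], [1, 2, 3, 0, 5]])
logits = torch.zeros(2, 5)  # an untrained discriminator
loss = electra_disc_loss(orig, corr, logits)
```

Unlike MLM, the loss is computed over every position, which is where the sample-efficiency gain comes from.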

Next Sentence Prediction (NSP)

Objective: Given sentences A and B, is B the actual next sentence?

Training data:

  • 50% positive: B actually follows A

  • 50% negative: B is random sentence

\[\mathcal{L}_{\text{NSP}} = -\log P(\text{IsNext} | [\text{CLS}], A, B)\]

Criticism: Too easy! Model often just uses topic matching.

RoBERTa finding: NSP might hurt performance!

Better Alternative: Sentence Order Prediction (SOP)

Used in: ALBERT, StructBERT

Instead of random sentence:

  • Positive: A then B (correct order)

  • Negative: B then A (swapped order)

Harder task: Requires understanding coherence, not just topics.
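SOP training pairs can be built directly from consecutive sentences of one document. A minimal sketch, assuming a hypothetical `make_sop_pairs` helper and toy sentences:

```python
import random

def make_sop_pairs(sentences, seed=0):
    """SOP pairs from one document: label 1 = correct order, 0 = swapped.

    `sentences` is an ordered list of sentences from a single document.
    """
    rng = random.Random(seed)
    pairs = []
    for a, b in zip(sentences, sentences[1:]):
        if rng.random() < 0.5:
            pairs.append((a, b, 1))  # keep correct order
        else:
            pairs.append((b, a, 0))  # swap -> negative example
    return pairs

doc = ["The cat sat.", "It purred.", "Then it slept."]
pairs = make_sop_pairs(doc)
```

Both sentences always come from the same document, so topic matching gives no signal; the model must judge order.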

Masking Patterns

Whole Word Masking:

Original BERT: Subword masking

Original: "playing"
Tokens: ["play", "##ing"]
BERT: ["[MASK]", "##ing"]  ← Leaks information!

Whole word masking:

["[MASK]", "[MASK]"]  ← Harder task

Span Masking (SpanBERT):

Mask contiguous spans instead of individual tokens:

Original: "The cat sat on the mat"
Masked: "The [MASK] [MASK] [MASK] the mat"
         (span of length 3)

Span lengths sampled from geometric distribution.
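Span masking can be sketched by repeatedly sampling a geometric span length and masking a contiguous block. The parameter values below are illustrative (SpanBERT uses p = 0.2 with spans clipped at length 10):

```python
import torch

def span_mask(seq_len, mask_ratio=0.15, p=0.2, max_span=10, seed=0):
    """Boolean mask with contiguous spans covering at least mask_ratio tokens.

    Span lengths are drawn from Geometric(p), clipped to max_span.
    """
    g = torch.Generator().manual_seed(seed)
    mask = torch.zeros(seq_len, dtype=torch.bool)
    budget = int(seq_len * mask_ratio)
    while mask.sum() < budget:
        # Geometric sample counts failures, so +1 gives a length >= 1
        length = min(int(torch.distributions.Geometric(p).sample().item()) + 1,
                     max_span)
        start = torch.randint(0, max(seq_len - length, 1), (1,), generator=g).item()
        mask[start:start + length] = True
    return mask

m = span_mask(100)
```

Note the final mask can slightly overshoot the budget since whole spans are added at once.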

Advantages:

  • More challenging task

  • Better for tasks requiring span understanding

  • Improves downstream performance

Dynamic Masking (RoBERTa)

BERT: Static masking (same mask every epoch)

RoBERTa: Dynamic masking (new mask each epoch)

Benefits:

  • Model sees more diverse training signals

  • Prevents overfitting to specific masks

  • Small but consistent improvement
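Dynamic masking amounts to sampling a fresh mask on every pass over the data instead of caching one mask per example. A minimal sketch:

```python
import torch

def random_mask(input_ids, mask_prob=0.15, generator=None):
    """Sample a fresh Bernoulli mask. Calling this inside the epoch loop
    (rather than once at preprocessing time) is dynamic masking."""
    probs = torch.full(input_ids.shape, mask_prob)
    return torch.bernoulli(probs, generator=generator).bool()

ids = torch.randint(3, 1000, (2, 16))
g = torch.Generator().manual_seed(0)
mask_epoch1 = random_mask(ids, generator=g)  # epoch 1 sees one mask...
mask_epoch2 = random_mask(ids, generator=g)  # ...epoch 2 almost surely another
```

Static masking, by contrast, would compute the mask once during preprocessing and reuse it every epoch.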

Pre-training Efficiency

Sample efficiency comparison:

| Method       | Compute | Data  | Performance  |
|--------------|---------|-------|--------------|
| BERT-base    | 1×      | 16GB  | Baseline     |
| RoBERTa      | 4×      | 160GB | +2% GLUE     |
| ELECTRA-base | 0.25×   | 16GB  | = BERT-large |
| DeBERTa      | 1.5×    | 160GB | +4% GLUE     |

Key insight: Training longer on more data consistently helps!

Practical Tips

Masking probability:

  • Standard: 15%

  • Higher rates (up to 40%) can work, especially for larger models

  • Lower (10%) for domain adaptation

Batch size:

  • BERT: 256 sequences

  • RoBERTa: 8192 (larger is better!)

  • Gradient accumulation if GPU memory limited

Sequence length:

  • Train at 128 tokens for the first ~90% of steps (faster)

  • Switch to 512 tokens for the final ~10% of steps

  • Longer sequences for long-form tasks

BERT Architecture Components

BERT's architecture is a stack of bidirectional Transformer encoder layers, meaning each token attends to all other tokens in the sequence (both left and right context). The input representation for each token is the sum of three embeddings: token embedding (the word or subword identity), segment embedding (which sentence the token belongs to, for sentence-pair tasks), and positional embedding (the position index in the sequence). This three-way embedding scheme allows BERT to handle diverse NLP tasks, from single-sentence classification to question-answering over sentence pairs, within a unified architecture. The bidirectional attention is what distinguishes BERT from autoregressive models like GPT, giving it a deeper contextual understanding of each token.

2.5. Segment Embeddings and Special Tokens

Segment Embeddings

BERT uses segment embeddings to distinguish sentence pairs:

\[\text{Input} = \text{Token Emb} + \text{Position Emb} + \text{Segment Emb}\]

Example:

Sentence A: "The cat sat"
Sentence B: "on the mat"

Input: [CLS] The cat sat [SEP] on the mat [SEP]
Segments: 0    0   0   0    0    1  1   1    1

Segment embedding:

\[\mathbf{s}_i = \begin{cases} \mathbf{s}_A & \text{if token from sentence A} \\ \mathbf{s}_B & \text{if token from sentence B} \end{cases}\]

where \(\mathbf{s}_A, \mathbf{s}_B \in \mathbb{R}^d\) are learned vectors.

Special Tokens

[CLS] - Classification Token:

  • Always first token

  • Aggregates sequence-level information

  • Used for classification tasks

Mathematical intuition:

Through self-attention, [CLS] can "collect" information from all tokens:

\[h_{[\text{CLS}]} = \text{Attention}(q_{[\text{CLS}]}, K_{\text{all}}, V_{\text{all}})\]

The [CLS] token's representation summarizes the entire sequence!

[SEP] - Separator Token:

  • Marks sentence boundaries

  • Helps model distinguish between segments

  • Essential for NSP task

[MASK] - Mask Token:

  • Used only during pre-training

  • Replaced with actual tokens during fine-tuning

[PAD] - Padding Token:

  • Pads sequences to same length in batch

  • Ignored via attention mask

Attention Masking

Padding mask prevents attention to [PAD]:

\[\begin{split}\text{mask}_{ij} = \begin{cases} 0 & \text{if } j \text{ is not [PAD]} \\ -\infty & \text{if } j \text{ is [PAD]} \end{cases}\end{split}\]

After softmax, \(\exp(-\infty) = 0\) → no attention to padding.

Full attention mask pattern (BERT):

        [CLS] The  cat  [SEP] mat [PAD]
[CLS]     ✓    ✓    ✓     ✓    ✓    ✗
The       ✓    ✓    ✓     ✓    ✓    ✗
cat       ✓    ✓    ✓     ✓    ✓    ✗
[SEP]     ✓    ✓    ✓     ✓    ✓    ✗
mat       ✓    ✓    ✓     ✓    ✓    ✗
[PAD]     ✗    ✗    ✗     ✗    ✗    ✗

✓ = Can attend (bidirectional)
✗ = Cannot attend (masked)

Compare to GPT (causal):

        The  cat  sat  on
The      ✓    ✗    ✗    ✗
cat      ✓    ✓    ✗    ✗
sat      ✓    ✓    ✓    ✗
on       ✓    ✓    ✓    ✓

(Lower triangular = causal)
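Both mask patterns above can be built as additive score masks: 0 where attention is allowed, a large negative number where it is forbidden. A minimal sketch using -1e9 as the "minus infinity" constant (the `BERT` class later in this notebook uses -10000 for the same purpose):

```python
import torch

def padding_additive_mask(attention_mask):
    """0/1 padding mask (batch, L) -> additive scores (batch, 1, 1, L):
    0 where attention is allowed, -1e9 at [PAD] positions."""
    return (1.0 - attention_mask.float())[:, None, None, :] * -1e9

def causal_additive_mask(L):
    """Lower-triangular causal mask (1, 1, L, L) for GPT-style attention."""
    allowed = torch.tril(torch.ones(L, L))
    return (1.0 - allowed)[None, None, :, :] * -1e9

pad = padding_additive_mask(torch.tensor([[1, 1, 1, 0]]))  # last token is [PAD]
causal = causal_additive_mask(4)
```

Either mask is simply added to the raw attention scores before softmax; the broadcastable shapes let one mask cover all heads.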

Position Embeddings: Absolute vs Relative

BERT: Absolute Learned Positions

\[p_i \in \mathbb{R}^d \quad \text{for } i = 0, 1, \ldots, 511\]

Each position gets its own learned embedding.

Advantages:

  • Simple implementation

  • Can learn task-specific patterns

Disadvantages:

  • Fixed max length (512 for BERT)

  • No extrapolation beyond training length

  • Doesn't capture relative distance

DeBERTa: Relative Position Bias

Add relative position bias to attention scores:

\[A_{ij} = \frac{(x_i W_Q)(x_j W_K)^T}{\sqrt{d}} + b_{i-j}\]

where \(b_{i-j}\) encodes relative distance.

Advantages:

  • Better length extrapolation

  • Captures relative relationships

  • More parameter efficient
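A simplified sketch of a learned relative-position bias added to attention scores. This uses a single shared bias table with clipped distances, which is illustrative rather than DeBERTa's exact disentangled formulation:

```python
import torch
import torch.nn as nn

class RelPosBias(nn.Module):
    """Learned bias b_{i-j}, with |i-j| clipped to max_dist.
    One shared table; real models typically learn one per head."""
    def __init__(self, max_dist=128):
        super().__init__()
        self.max_dist = max_dist
        self.bias = nn.Embedding(2 * max_dist + 1, 1)

    def forward(self, L):
        pos = torch.arange(L)
        rel = (pos[None, :] - pos[:, None]).clamp(-self.max_dist, self.max_dist)
        return self.bias(rel + self.max_dist).squeeze(-1)  # (L, L)

rp = RelPosBias()
scores = torch.zeros(8, 8) + rp(8)  # added to raw attention scores
```

Because the bias depends only on the offset i-j, any two pairs at the same distance share a parameter, which is what allows extrapolation beyond the training length.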

BERT vs RoBERTa vs DeBERTa:

| Feature    | BERT             | RoBERTa          | DeBERTa            |
|------------|------------------|------------------|--------------------|
| Position   | Absolute learned | Absolute learned | Relative bias      |
| Max length | 512              | 512              | Unlimited          |
| NSP        | Yes              | No               | No                 |
| Masking    | Static           | Dynamic          | Dynamic            |
| Vocab      | 30K WordPiece    | 50K BPE          | 128K SentencePiece |

Input Construction Examples

Single sentence classification:

[CLS] This movie was amazing [SEP]
  0     0     0     0    0      0

Sentence pair (NLI, QA):

[CLS] Is it raining? [SEP] Yes, take umbrella [SEP]
  0     0  0    0      0     1    1     1       1

Token classification (NER):

[CLS] John lives in Paris [SEP]
  0     0     0    0   0     0

Output labels:
  O     B-PER  O   O  B-LOC  O

Fine-tuning Architecture Patterns

Sequence classification:

Input → BERT → h[CLS] → Linear → Softmax

Token classification:

Input → BERT → h_i for all i → Linear → Softmax per token

Span extraction (QA):

Input → BERT → h_i → Start logits
             → h_i → End logits

Answer = tokens from argmax(start) to argmax(end)

Sentence pair scoring:

Input → BERT → h[CLS] → Linear → Sigmoid (similarity score)
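The sequence-classification pattern can be sketched as a small head module applied to the encoder output (`ClassificationHead` is an illustrative name, not a library class; the random tensor stands in for real encoder output):

```python
import torch
import torch.nn as nn

class ClassificationHead(nn.Module):
    """Sequence classification: take h[CLS] (first token) -> class logits."""
    def __init__(self, hidden_size, num_classes, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, sequence_output):     # (batch, L, hidden)
        cls = sequence_output[:, 0, :]      # h[CLS] summarizes the sequence
        return self.classifier(self.dropout(cls))

head = ClassificationHead(hidden_size=128, num_classes=2)
logits = head(torch.randn(4, 16, 128))  # pretend encoder output
```

Token classification would instead apply the linear layer to all positions; span extraction uses two such layers for start and end logits.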

Computational Complexity

Embedding layer:

  • Token: \(O(V \cdot d)\) parameters where \(V\) = vocab size

  • Position: \(O(L \cdot d)\) where \(L\) = max length

  • Segment: \(O(2 \cdot d)\) (just 2 segments)

Total embedding: ~24M parameters for BERT-base (the 30K × 768 token table dominates)

Self-attention dominates:

  • Per layer: \(O(L^2 \cdot d)\) time complexity

  • Total layers: 12 (base) or 24 (large)

  • Quadratic in sequence length!

class BERTEmbeddings(nn.Module):
    """BERT input embeddings: token + position + segment."""
    def __init__(self, vocab_size, hidden_size, max_position, dropout=0.1):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position, hidden_size)
        self.segment_embeddings = nn.Embedding(2, hidden_size)  # Two segments: A and B
        
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input_ids, segment_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids)
        
        # Sum embeddings
        embeddings = self.token_embeddings(input_ids)
        embeddings += self.position_embeddings(position_ids)
        embeddings += self.segment_embeddings(segment_ids)
        
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        
        return embeddings

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention."""
    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(hidden_size, hidden_size)
    
    def forward(self, x, attention_mask=None):
        batch_size, seq_length, hidden_size = x.size()
        
        # Linear projections
        Q = self.query(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
        
        if attention_mask is not None:
            scores = scores + attention_mask
        
        attn_probs = F.softmax(scores, dim=-1)
        attn_probs = self.dropout(attn_probs)
        
        context = torch.matmul(attn_probs, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_length, hidden_size)
        
        output = self.output(context)
        return output

class FeedForward(nn.Module):
    """Position-wise feed-forward network."""
    def __init__(self, hidden_size, intermediate_size, dropout=0.1):
        super().__init__()
        self.dense1 = nn.Linear(hidden_size, intermediate_size)
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.dense1(x)
        x = F.gelu(x)  # BERT uses GELU
        x = self.dropout(x)
        x = self.dense2(x)
        return x

class BERTLayer(nn.Module):
    """Single BERT transformer layer."""
    def __init__(self, hidden_size, num_heads, intermediate_size, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadSelfAttention(hidden_size, num_heads, dropout)
        self.attn_layer_norm = nn.LayerNorm(hidden_size)
        self.ffn = FeedForward(hidden_size, intermediate_size, dropout)
        self.ffn_layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, attention_mask=None):
        # Self-attention
        attn_output = self.attention(x, attention_mask)
        x = self.attn_layer_norm(x + self.dropout(attn_output))
        
        # Feed-forward
        ffn_output = self.ffn(x)
        x = self.ffn_layer_norm(x + self.dropout(ffn_output))
        
        return x

class BERT(nn.Module):
    """BERT model."""
    def __init__(self, vocab_size, hidden_size=768, num_layers=12, 
                 num_heads=12, intermediate_size=3072, max_position=512, dropout=0.1):
        super().__init__()
        
        self.embeddings = BERTEmbeddings(vocab_size, hidden_size, max_position, dropout)
        
        self.encoder_layers = nn.ModuleList([
            BERTLayer(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        ])
        
        # Pre-training heads
        self.mlm_head = nn.Linear(hidden_size, vocab_size)
        self.nsp_head = nn.Linear(hidden_size, 2)
    
    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        # Embeddings
        x = self.embeddings(input_ids, segment_ids)
        
        # Prepare attention mask
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (1.0 - attention_mask) * -10000.0
        
        # Transformer layers
        for layer in self.encoder_layers:
            x = layer(x, attention_mask)
        
        return x
    
    def get_mlm_logits(self, sequence_output):
        """Masked language modeling predictions."""
        return self.mlm_head(sequence_output)
    
    def get_nsp_logits(self, sequence_output):
        """Next sentence prediction."""
        # Use [CLS] token (first token)
        cls_output = sequence_output[:, 0, :]
        return self.nsp_head(cls_output)

# Create small BERT
vocab_size = 1000
model = BERT(vocab_size=vocab_size, hidden_size=128, num_layers=2, 
             num_heads=4, intermediate_size=512, max_position=128).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Masked Language Modeling

BERT's primary pre-training objective is Masked Language Modeling (MLM): randomly mask 15% of input tokens and train the model to predict them from their bidirectional context. Of the selected tokens, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged; this mixed strategy prevents the model from relying on the [MASK] token as a signal and improves robustness. The MLM loss is simply cross-entropy over the masked positions: \(\mathcal{L}_{\text{MLM}} = -\sum_{i \in \text{masked}} \log p(x_i | x_{\backslash i})\). This self-supervised objective forces the model to build rich, contextual representations of language that transfer powerfully to downstream tasks via fine-tuning.

3.5. BERT Training: Advanced Techniques

Pre-training at Scale

Original BERT setup:

  • Data: BookCorpus (800M words) + Wikipedia (2.5B words)

  • Batch size: 256 sequences

  • Steps: 1M steps

  • Time: 4 days on 16 Cloud TPUs (64 chips)

  • Cost: ~$7,000

RoBERTa improvements:

  • Data: 160GB of text (10× more)

  • Batch size: 8K sequences

  • Steps: 500K steps

  • Time: 1 day on 1024 V100 GPUs

  • Result: +2-3% on GLUE benchmark

Optimization Strategy

BERT optimizer: Adam with linear warmup and linear decay

\[\text{lr}(t) = \text{lr}_{\max} \cdot \min\left(\frac{t}{t_{\text{warmup}}}, \frac{T - t}{T - t_{\text{warmup}}}\right)\]

where \(T\) is the total number of training steps.

Schedule:

  1. Warmup: Linear increase from 0 to max LR

    • Typically 10K steps

    • Stabilizes training early on

  2. Decay: Linear decrease to 0

    • Or polynomial decay

    • Prevents overfitting

Why warmup?

  • Large batch sizes → noisy gradients early

  • Warmup stabilizes optimization

  • Prevents divergence from random init
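The linear warmup-then-decay schedule can be sketched with PyTorch's `LambdaLR`; the step counts and the one-layer model here are toy values:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # lr_max

warmup, total = 100, 1000

def lr_lambda(step):
    if step < warmup:
        return step / warmup                            # linear warmup 0 -> 1
    return max(0.0, (total - step) / (total - warmup))  # linear decay 1 -> 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

lrs = []
for step in range(total):
    optimizer.step()       # (backward pass omitted in this sketch)
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])
```

The learning rate climbs to lr_max over the first 100 steps, then decays linearly to 0 by the end of training.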

Learning Rate Sensitivity

| Task      | Pre-training LR | Fine-tuning LR |
|-----------|-----------------|----------------|
| Pre-train | 1e-4            | N/A            |
| GLUE      | N/A             | 2e-5 to 5e-5   |
| SQuAD     | N/A             | 3e-5           |
| NER       | N/A             | 5e-5           |

Rule of thumb:

  • Pre-training: Higher LR (1e-4)

  • Fine-tuning: Lower LR (2e-5 to 5e-5)

  • Smaller tasks → lower LR

Layer-wise Learning Rate Decay (LLRD)

Problem: All layers learn at same rate

Solution: Different LR for different layers

\[\text{lr}_{\ell} = \text{lr}_{\text{top}} \cdot \alpha^{L - \ell}\]

where:

  • \(\ell\) is layer index

  • \(L\) is total layers

  • \(\alpha \in [0.9, 0.95]\) is decay factor

Intuition:

  • Lower layers: More general features

  • Upper layers: Task-specific features

  • Lower layers need less updating during fine-tuning

Example (BERT-base, 12 layers, \(\alpha=0.95\)):

Layer 12 (output): lr = 2e-5
Layer 11: lr = 2e-5 × 0.95 = 1.9e-5
Layer 10: lr = 2e-5 × 0.95² = 1.8e-5
...
Layer 1: lr = 2e-5 × 0.95¹¹ = 1.1e-5
Embeddings: lr = 2e-5 × 0.95¹² = 1.1e-5

Impact: +0.5-1% on downstream tasks!
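LLRD can be implemented by giving each layer its own optimizer parameter group. A sketch on a toy 4-layer stack standing in for BERT's 12 encoder layers:

```python
import torch
import torch.nn as nn

# Toy "encoder": embeddings + 4 layers (stand-ins for BERT's modules)
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
embeddings = nn.Embedding(100, 8)

top_lr, alpha, L = 2e-5, 0.95, len(layers)

# Layer l gets lr = top_lr * alpha^(L-1-l): top layer full LR, lower layers less
groups = [{"params": layers[l].parameters(),
           "lr": top_lr * alpha ** (L - 1 - l)}
          for l in range(L)]
groups.append({"params": embeddings.parameters(),
               "lr": top_lr * alpha ** L})  # embeddings: smallest LR

optimizer = torch.optim.AdamW(groups)
lrs = [g["lr"] for g in optimizer.param_groups]
```

Each group keeps its own learning rate through training, so schedulers multiply every group by the same factor while the relative decay is preserved.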

Gradient Accumulation

Problem: Large batch sizes don't fit in GPU memory

Solution: Accumulate gradients over mini-batches

accumulation_steps = 4
optimizer.zero_grad()

for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()
    
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Effective batch size: batch_size × accumulation_steps

Example:

  • Actual batch: 16 (fits in GPU)

  • Accumulation: 16 steps

  • Effective batch: 256

Mixed Precision Training

Use FP16 instead of FP32:

\[\text{Memory} = \frac{\text{FP32 Memory}}{2}\]
\[\text{Speed} \approx 2-3\times \text{ faster}\]

Implementation:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    with autocast():  # FP16 forward pass
        loss = model(batch)
    
    scaler.scale(loss).backward()  # backward on scaled loss
    scaler.step(optimizer)         # unscales grads, FP32 update
    scaler.update()

Loss scaling prevents underflow in FP16.

Knowledge Distillation for BERT

Compress large BERT → small BERT

DistilBERT approach:

  1. Teacher: BERT-base (110M params)

  2. Student: DistilBERT (66M params, 6 layers)

Distillation loss:

\[\mathcal{L} = \alpha \mathcal{L}_{\text{CE}}(y, \hat{y}) + (1-\alpha) \mathcal{L}_{\text{KL}}(\sigma(z_s/T), \sigma(z_t/T))\]

where:

  • \(\mathcal{L}_{\text{CE}}\): Cross-entropy with true labels

  • \(\mathcal{L}_{\text{KL}}\): KL divergence between student/teacher logits

  • \(T\): Temperature (typically 2-4)

  • \(\alpha\): Weight (typically 0.5)

Results:

  • DistilBERT: 97% of BERT performance

  • 40% fewer parameters

  • 60% faster inference
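A minimal sketch of the distillation loss above on toy logits. Note the T² scaling of the KL term, standard practice from Hinton et al. to keep soft-target gradients comparable across temperatures (the formula above omits it):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Hard-label cross-entropy plus temperature-scaled KL to the teacher."""
    ce = F.cross_entropy(student_logits, labels)
    kl = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                  F.softmax(teacher_logits / T, dim=-1),
                  reduction="batchmean") * T * T
    return alpha * ce + (1 - alpha) * kl

# Toy batch: 8 examples, 10 classes
student = torch.randn(8, 10)
teacher = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = distillation_loss(student, teacher, labels)
```

When student and teacher logits agree exactly, the KL term vanishes and only the hard-label loss remains.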

TinyBERT: Even more aggressive

  • 4 layers, 312 hidden size

  • 13.3M parameters (7.5× smaller)

  • 95% of BERT performance

Continual Pre-training

Adapt to specific domain:

  1. Base pre-training: General corpus

  2. Domain pre-training: Domain-specific corpus

  3. Fine-tuning: Task-specific data

Example: BioBERT

BERT → PubMed + PMC (biomedical) → NER/QA tasks

Don't Forget! (Catastrophic forgetting)

When continually pre-training:

  • Mix in general domain data (10-20%)

  • Use lower learning rate

  • Shorter training

Practical Training Tips

Data preprocessing:

  • Remove duplicates (RoBERTa found ~30% duplicates!)

  • Filter for quality (perplexity, length)

  • Balance domains

Hyperparameters:

  • Warmup: 10% of total steps

  • Max LR: 1e-4 for pre-training

  • Batch size: As large as possible

  • Gradient clipping: 1.0

Monitoring:

  • MLM accuracy (should reach ~60-70%)

  • Perplexity on validation set

  • Loss curve (should be smooth)

Early stopping:

  • Validation loss plateaus

  • Downstream task performance

  • Typically 100K-1M steps sufficient

Computing Requirements

BERT-base from scratch:

  • 16 TPU chips × 4 days

  • Or ~256 GPU days

  • Or ~$5,000-10,000

Fine-tuning on single task:

  • 1 GPU × 1-4 hours

  • $1-5 on cloud

Recommendation:

  • Use pre-trained models when possible!

  • Only pre-train if domain very different

  • Fine-tuning is usually sufficient

def create_masked_lm_data(input_ids, mask_token_id=103, mask_prob=0.15):
    """Create MLM training data.
    
    For 15% of tokens:
    - 80%: Replace with [MASK]
    - 10%: Replace with random token
    - 10%: Keep unchanged
    """
    labels = input_ids.clone()
    masked_input = input_ids.clone()
    
    # Create mask
    probability_matrix = torch.full(input_ids.shape, mask_prob)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    
    # Don't mask special tokens (assume 0-2 are special)
    special_tokens_mask = input_ids < 3
    masked_indices &= ~special_tokens_mask
    
    # Set labels: -100 for unmasked positions
    labels[~masked_indices] = -100
    
    # 80% mask
    indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
    masked_input[indices_replaced] = mask_token_id
    
    # 10% random
    indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long)  # uses module-level vocab_size
    masked_input[indices_random] = random_words[indices_random]
    
    # 10% unchanged (remaining masked_indices)
    
    return masked_input, labels

# Example
input_ids = torch.randint(3, vocab_size, (4, 20))  # Batch of 4, seq len 20
masked_input, labels = create_masked_lm_data(input_ids)

print("Original:", input_ids[0, :10].tolist())
print("Masked:  ", masked_input[0, :10].tolist())
print("Labels:  ", labels[0, :10].tolist())

# Training loop (simplified)
def train_mlm(model, data_loader, epochs=3):
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    model.train()
    
    losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0
        
        for batch_idx, input_ids in enumerate(data_loader):
            input_ids = input_ids.to(device)
            
            # Create masked data
            masked_input, labels = create_masked_lm_data(input_ids)
            masked_input = masked_input.to(device)
            labels = labels.to(device)
            
            # Forward
            outputs = model(masked_input)
            logits = model.get_mlm_logits(outputs)
            
            # Loss
            loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100)
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        
        avg_loss = epoch_loss / len(data_loader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1} avg loss: {avg_loss:.4f}")
    
    return losses

# Create dummy data
dummy_data = torch.randint(3, vocab_size, (100, 32))  # 100 sequences
data_loader = torch.utils.data.DataLoader(dummy_data, batch_size=8, shuffle=True)

print("Training BERT with MLM...")
losses = train_mlm(model, data_loader, epochs=2)

# Plot
plt.figure(figsize=(8, 5))
plt.plot(losses, marker='o', linewidth=2)
plt.xlabel('Epoch', fontsize=11)
plt.ylabel('MLM Loss', fontsize=11)
plt.title('BERT Pre-training', fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()

Summary

BERT Architecture:

Embeddings: Token + Position + Segment

Encoder: Stack of transformer layers

  • Multi-head self-attention

  • Position-wise FFN

  • Layer normalization + residual

Pre-training:

  1. MLM: Predict masked tokens (15%)

  2. NSP: Binary classification for sentence pairs

Key Innovations:

  • Bidirectional: Full context

  • Deep: 12-24 layers

  • Pre-train then fine-tune: Transfer learning

BERT Variants:

  • RoBERTa: Remove NSP, larger batches

  • ALBERT: Parameter sharing, factorized embeddings

  • DistilBERT: Knowledge distillation (smaller)

  • ELECTRA: Replaced token detection

Applications:

  • Sentiment analysis

  • Named entity recognition

  • Question answering (SQuAD)

  • Text classification

Next Steps:

  • 06_vision_transformers.ipynb - ViT architecture

  • Explore GPT for generation tasks

  • Fine-tune BERT on specific datasets