Lab 02: Kernel Validation

GEMM · Convolution · Attention · Softmax · LayerNorm

Role alignment: AMD Principal Staff – AI/ML Performance Validation
Reference: 02_kernel_validation.ipynb

What you will do:

  1. Validate GEMM correctness across dtypes (FP32, FP16, BF16) with proper tolerances

  2. Test softmax and layernorm edge cases

  3. Validate scaled dot-product attention (FlashAttention-style)

  4. Measure kernel efficiency vs roofline model

  5. Build a reusable KernelValidator class

The key insight: GPU results are NOT bit-identical to CPU results; you must use dtype-aware tolerances

Setup

The setup establishes the dtype-aware tolerance table that governs all kernel validation in this lab: FP32 uses atol=1e-4, FP16 uses atol=1e-2, and BF16 uses atol=1e-1. These tolerances reflect the fundamental precision limits of each data type – FP16 has only 10 mantissa bits (~3.3 decimal digits), so expecting better than 1e-2 absolute agreement with an FP32 reference is unrealistic. The tols() helper function provides a clean interface for looking up tolerances by dtype, establishing a pattern used throughout production validation code at AMD and NVIDIA.

import torch
import torch.nn.functional as F
import time
import math
from dataclasses import dataclass
from typing import Optional

if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

print(f'Device: {DEVICE}')
if DEVICE == 'cuda':
    print(f'GPU   : {torch.cuda.get_device_name(0)}')

# ---- Tolerance table (AMD/NVIDIA engineering standard) ----
# Lower precision = higher tolerance needed
TOLERANCES = {
    torch.float64:  {'atol': 1e-7,  'rtol': 1e-5},
    torch.float32:  {'atol': 1e-4,  'rtol': 1e-4},
    torch.float16:  {'atol': 1e-2,  'rtol': 1e-3},
    torch.bfloat16: {'atol': 1e-1,  'rtol': 1e-2},
}

def tols(dtype):
    return TOLERANCES.get(dtype, {'atol': 1e-2, 'rtol': 1e-2})
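As a quick sanity check on the table (an illustrative snippet, not part of the lab's harness), round-tripping a constant through each reduced-precision dtype produces errors on roughly the scale of the atol column:

```python
import torch

# Round-trip a value through each reduced-precision dtype and measure
# the representation error -- it lands roughly at the atol scale above.
x = torch.tensor(3.14159265, dtype=torch.float32)
for dt in [torch.float16, torch.bfloat16]:
    err = (x.to(dt).float() - x).abs().item()
    print(dt, err)
```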

Concept: Why Tolerances Matter

FP32  → ~7 decimal digits of precision   → atol=1e-4
FP16  → ~3 decimal digits of precision   → atol=1e-2
BF16  → ~2-3 decimal digits of precision → atol=1e-1
INT8  → integer quantization error       → much higher

GPU kernels reorder floating-point operations (floating-point addition is non-associative).
The reference is FP32 on CPU; the GPU result must be within tolerance, not bit-identical.
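That non-associativity shows up even with plain Python floats (IEEE-754 doubles): the same three values summed in two orders give two different results.

```python
# Same three values, two summation orders, two different results:
left  = (0.1 + 0.2) + 0.3   # 0.6000000000000001
right = 0.1 + (0.2 + 0.3)   # 0.6
print(left == right)         # False
```

A GPU reduction tree and a sequential CPU loop are just two such orders at scale, which is why bit-identical agreement is the wrong expectation.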

Exercise 2.1 – GEMM Correctness Sweep

Validate GPU matrix multiplication against CPU FP32 reference across:

  • Shapes: square (128→4096), non-square, tall, wide

  • Dtypes: FP32, FP16, BF16

Pass criteria (allclose-style): every element satisfies |out − ref| ≤ atol + rtol·|ref|

@dataclass
class ValidationResult:
    name: str
    passed: bool
    max_abs_error: float
    max_rel_error: float
    atol: float
    rtol: float
    notes: str = ''

    def __str__(self):
        status = 'PASS' if self.passed else 'FAIL'
        return (f'[{status}] {self.name:50s}  '
                f'abs={self.max_abs_error:.2e} (≤{self.atol:.1e})  '
                f'rel={self.max_rel_error:.2e} (≤{self.rtol:.1e})'
                + (f'  NOTE: {self.notes}' if self.notes else ''))


def validate_gemm(M, K, N, dtype, device=DEVICE):
    """
    Compare GPU GEMM result to FP32 CPU reference.
    """
    t = tols(dtype)

    # Reference: FP32 on CPU
    a_cpu = torch.randn(M, K, dtype=torch.float32)
    b_cpu = torch.randn(K, N, dtype=torch.float32)
    ref = torch.matmul(a_cpu, b_cpu)

    # GPU result in target dtype
    a_gpu = a_cpu.to(device=device, dtype=dtype)
    b_gpu = b_cpu.to(device=device, dtype=dtype)
    out_gpu = torch.matmul(a_gpu, b_gpu).to(dtype=torch.float32).cpu()

    abs_err = (out_gpu - ref).abs()
    rel_err = abs_err / (ref.abs() + 1e-8)

    max_abs = abs_err.max().item()
    max_rel = rel_err.max().item()
    # Allclose-style pass criterion: |out - ref| <= atol + rtol*|ref| per element.
    # (A bare max_rel <= rtol check fails spuriously wherever ref is near zero.)
    passed = bool((abs_err <= t['atol'] + t['rtol'] * ref.abs()).all())

    name = f'GEMM({M}x{K}x{N}) {str(dtype).split(".")[1]}'
    return ValidationResult(name, passed, max_abs, max_rel, t['atol'], t['rtol'])


# Sweep
results = []
shapes = [(128, 128, 128), (512, 512, 512), (1024, 1024, 1024),
          (2048, 2048, 2048), (1024, 4096, 1024), (256, 1024, 8192)]
dtypes = [torch.float32, torch.float16]
if DEVICE == 'cuda':
    dtypes.append(torch.bfloat16)

# Reduce on CPU
if DEVICE == 'cpu':
    shapes = [(128, 128, 128), (256, 256, 256), (512, 256, 512)]
    dtypes = [torch.float32]

for dtype in dtypes:
    for M, K, N in shapes:
        r = validate_gemm(M, K, N, dtype)
        results.append(r)
        print(r)

passed = sum(1 for r in results if r.passed)
print(f'\n{passed}/{len(results)} GEMM tests passed')
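The sweep's errors scale with the inner dimension K. A minimal sketch of why (hypothetical helper, not part of the lab's harness; it isolates only the error from rounding the inputs to FP16, keeping the accumulation itself in FP32):

```python
import torch

torch.manual_seed(0)

def fp16_input_rounding_err(K, M=64, N=64):
    # Round inputs to fp16, multiply in fp32: each K-term dot product
    # accumulates the input rounding noise roughly like sqrt(K).
    a, b = torch.randn(M, K), torch.randn(K, N)
    ref = a @ b
    out = a.half().float() @ b.half().float()
    return (out - ref).abs().max().item()

print(fp16_input_rounding_err(16))    # small
print(fp16_input_rounding_err(4096))  # markedly larger
```

This is one reason the tolerance table cannot be tightened for large-K shapes: even a perfect kernel inherits this input-rounding floor.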

Exercise 2.2 – Softmax Edge Cases

Softmax is numerically tricky. Test edge cases that expose GPU kernel bugs:

  • All zeros

  • All same value

  • One very large value (should dominate)

  • Mixed large + small (overflow risk in naive implementation)
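The overflow risk in the last bullet is exactly what the max-subtraction trick fixes. A minimal naive-vs-stable sketch (helper names are illustrative, not from the lab):

```python
import torch

def naive_softmax(x):
    e = torch.exp(x)                 # exp(1000) = inf -> inf/inf = nan
    return e / e.sum(dim=-1, keepdim=True)

def stable_softmax(x):
    # Subtract the row max first: every exponent is <= 0, so exp() cannot overflow.
    e = torch.exp(x - x.max(dim=-1, keepdim=True).values)
    return e / e.sum(dim=-1, keepdim=True)

x = torch.tensor([1000.0, 0.0, -1000.0])
print(naive_softmax(x))   # contains nan
print(stable_softmax(x))  # tensor([1., 0., 0.])
```

F.softmax uses the stable form, which is why the "one large value" cases below are expected to pass.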

def validate_softmax(input_cpu, name, expect_numerical_issue=False):
    """Compare GPU softmax to CPU FP32 reference.
    
    Args:
        input_cpu: Input tensor on CPU.
        name: Human-readable test name.
        expect_numerical_issue: If True, the test PASSES when NaN/Inf is
            detected in the output (i.e. we successfully caught instability).
    """
    ref = F.softmax(input_cpu, dim=-1)

    gpu_input = input_cpu.to(DEVICE)
    out = F.softmax(gpu_input, dim=-1).cpu()

    # Softmax properties:
    # 1. Output sums to 1
    # 2. All values in [0, 1]
    # 3. Close to CPU reference
    sum_check = abs(out.sum().item() - 1.0) < 1e-4
    range_check = (out >= 0).all() and (out <= 1).all()
    has_nan = torch.isnan(out).any().item()
    has_inf = torch.isinf(out).any().item()
    close = torch.allclose(out.float(), ref.float(), atol=1e-4)

    issues = []
    if not sum_check: issues.append(f'sum={out.sum().item():.6f} (expected 1.0)')
    if not range_check: issues.append('values outside [0,1]')
    if has_nan: issues.append('NaN detected')
    if has_inf: issues.append('Inf detected')
    if not close: issues.append(f'max_diff={(out.float()-ref.float()).abs().max().item():.2e}')

    numerically_stable = (not has_nan) and (not has_inf)

    if expect_numerical_issue:
        # We EXPECT instability here β€” pass means we correctly detected the problem
        passed = not numerically_stable  # True when NaN or Inf is present
        if passed:
            issues.append('EXPECTED: numerical instability correctly detected')
        else:
            issues.append('UNEXPECTED: no instability detected for pathological input')
    else:
        passed = sum_check and range_check and numerically_stable
    
    status = 'PASS' if passed else 'FAIL'
    print(f'[{status}] {name:45s}  sum_ok={sum_check}  range_ok={range_check}  '
          f'nan={has_nan}  {", ".join(issues) if issues else "all checks pass"}')
    return passed


SEQ = 2048
test_cases = [
    (torch.zeros(SEQ),                          'All zeros',                  False),
    (torch.ones(SEQ),                           'All ones (same value)',       False),
    (torch.full((SEQ,), 1000.0),                'All large same value',       False),
    (torch.cat([torch.tensor([1000.0]), torch.zeros(SEQ-1)]), 'One large, rest zeros', False),
    (torch.cat([torch.tensor([1000.0, -1000.0]), torch.randn(SEQ-2)]), 'Large + small mix', False),
    (torch.randn(SEQ) * 10,                     'High variance random',       False),
    (torch.full((SEQ,), float('inf')),          'All inf (must produce NaN)', True),
]

print('Softmax edge case validation:')
print('-' * 90)
for tensor, name, expect_issue in test_cases:
    try:
        validate_softmax(tensor, name, expect_numerical_issue=expect_issue)
    except Exception as e:
        print(f'[EXCEPTION] {name}: {e}')

Exercise 2.3 – LayerNorm Validation

LayerNorm is in every Transformer. Test:

  • Basic correctness

  • Near-zero variance (numerical instability risk)

  • Large hidden dimensions

def validate_layernorm(batch, seq, hidden, name, eps=1e-5):
    """Compare GPU LayerNorm to manual FP32 CPU computation."""
    x_cpu = torch.randn(batch, seq, hidden, dtype=torch.float32)
    weight = torch.ones(hidden)
    bias = torch.zeros(hidden)

    # Reference: manual FP32
    mean = x_cpu.mean(dim=-1, keepdim=True)
    var  = x_cpu.var(dim=-1, keepdim=True, unbiased=False)
    ref  = (x_cpu - mean) / (var + eps).sqrt() * weight + bias

    # GPU via torch.nn.functional
    x_gpu  = x_cpu.to(DEVICE)
    w_gpu  = weight.to(DEVICE)
    b_gpu  = bias.to(DEVICE)
    out    = F.layer_norm(x_gpu, [hidden], w_gpu, b_gpu, eps=eps).cpu()

    max_err = (out - ref).abs().max().item()
    passed  = max_err < 1e-3
    status  = 'PASS' if passed else 'FAIL'
    print(f'[{status}] {name:45s}  max_abs_err={max_err:.2e}')
    return passed


print('LayerNorm validation:')
print('-' * 70)
validate_layernorm(2, 128,  768,  'BERT-base (batch=2, seq=128, h=768)')
validate_layernorm(1, 2048, 4096, 'LLM-style (batch=1, seq=2048, h=4096)')
validate_layernorm(8, 512,  256,  'Small hidden (h=256)')
validate_layernorm(1, 1,    8192, 'Large hidden only (h=8192)')
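The near-zero-variance bullet above deserves its own check, which the random-input cases cannot exercise: a constant row has var = 0, and only eps keeps the division finite (illustrative snippet):

```python
import torch
import torch.nn.functional as F

# Constant input: x - mean == 0 everywhere and var == 0, so the output is
# 0 / sqrt(0 + eps) == 0 -- eps alone prevents a 0/0 NaN.
x = torch.full((2, 8), 3.0)
out = F.layer_norm(x, [8])
print(out)  # all zeros, no NaN
```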

Exercise 2.4 – Scaled Dot-Product Attention Validation

Test the attention kernel used in every LLM:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

Compare PyTorch's optimized SDPA (which uses FlashAttention on CUDA) against a manual reference.

def manual_attention(Q, K, V, mask=None):
    """Reference implementation in FP32."""
    d_k = Q.shape[-1]
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, V), attn_weights


def validate_attention(batch, heads, seq_len, head_dim, dtype, use_causal=False):
    t = tols(dtype)
    name = f'Attention(B={batch},H={heads},S={seq_len},D={head_dim}) {str(dtype).split(".")[1]}'
    if use_causal:
        name += ' causal'

    # Reference: FP32 CPU manual, with the same causal mask SDPA applies.
    # (Comparing an unmasked reference to a causal output is apples-to-oranges:
    # even the first query row differs, since unmasked it attends to ALL keys.)
    Q_cpu = torch.randn(batch, heads, seq_len, head_dim)
    K_cpu = torch.randn(batch, heads, seq_len, head_dim)
    V_cpu = torch.randn(batch, heads, seq_len, head_dim)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len)) if use_causal else None
    ref_out, _ = manual_attention(Q_cpu, K_cpu, V_cpu, mask=causal_mask)

    # GPU SDPA (dispatches to FlashAttention on CUDA compute capability >= 8.0)
    Q_gpu = Q_cpu.to(device=DEVICE, dtype=dtype)
    K_gpu = K_cpu.to(device=DEVICE, dtype=dtype)
    V_gpu = V_cpu.to(device=DEVICE, dtype=dtype)

    with torch.no_grad():
        out_gpu = F.scaled_dot_product_attention(
            Q_gpu, K_gpu, V_gpu,
            is_causal=use_causal
        ).to(torch.float32).cpu()

    max_err = (out_gpu - ref_out).abs().max().item()
    passed  = max_err <= t['atol'] * 10  # looser: error compounds through matmul -> softmax -> matmul
    no_nan  = not torch.isnan(out_gpu).any()
    no_inf  = not torch.isinf(out_gpu).any()
    overall = passed and no_nan and no_inf

    status = 'PASS' if overall else 'FAIL'
    notes = []
    if not passed: notes.append(f'max_err={max_err:.2e}>atol={t["atol"]*10:.1e}')
    if not no_nan: notes.append('NaN')
    if not no_inf: notes.append('Inf')
    print(f'[{status}] {name:60s}  max_err={max_err:.2e}  {" ".join(notes)}')
    return overall


print('Attention kernel validation:')
print('-' * 90)

attn_dtypes = [torch.float32, torch.float16]
if DEVICE == 'cuda':
    attn_dtypes.append(torch.bfloat16)
if DEVICE == 'cpu':
    attn_dtypes = [torch.float32]

configs = [
    (1, 8,  128,  64,  False),   # BERT-style
    (1, 32, 512,  128, False),   # LLM decode
    (1, 16, 2048, 64,  True),    # Causal LLM
]
if DEVICE == 'cpu':
    configs = [(1, 4, 64, 32, False), (1, 4, 64, 32, True)]

all_results = []
for dtype in attn_dtypes:
    for B, H, S, D, causal in configs:
        try:
            r = validate_attention(B, H, S, D, dtype, causal)
            all_results.append(r)
        except Exception as e:
            print(f'[ERROR] {e}')

print(f'\n{sum(all_results)}/{len(all_results)} attention tests passed')

Exercise 2.5 – Sequence Length Scaling

Test numerical error growth as sequence length increases.
Long sequences accumulate floating-point error; this is a real issue at seq_len > 8192.
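One mechanism behind that growth: FP16's spacing between representable values widens as magnitudes grow, so small contributions to a long accumulation can vanish entirely (illustrative):

```python
import torch

# At magnitude 2048, fp16 values are spaced 2.0 apart: adding 1 is a no-op.
a = torch.tensor(2048.0, dtype=torch.float16)
print((a + 1.0).item())  # 2048.0 -- the +1 is rounded away
print((a + 2.0).item())  # 2050.0
```

Long softmax reductions hit this once partial sums grow large, which is part of why FlashAttention keeps its running statistics in higher precision.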

seq_lengths = [64, 128, 256, 512, 1024, 2048] if DEVICE != 'cpu' else [32, 64, 128]
dtype = torch.float16 if DEVICE == 'cuda' else torch.float32

print(f'Attention error vs sequence length ({str(dtype).split(".")[1]}):')
print(f"{'Seq Len':>10}  {'Max Abs Error':>15}  {'Max Rel Error':>15}  Status")
print('-' * 60)

scaling_data = []

for seq_len in seq_lengths:
    B, H, D = 1, 8, 64
    Q_cpu = torch.randn(B, H, seq_len, D)
    K_cpu = torch.randn(B, H, seq_len, D)
    V_cpu = torch.randn(B, H, seq_len, D)

    ref, _ = manual_attention(Q_cpu, K_cpu, V_cpu)

    Q_g = Q_cpu.to(DEVICE, dtype=dtype)
    K_g = K_cpu.to(DEVICE, dtype=dtype)
    V_g = V_cpu.to(DEVICE, dtype=dtype)
    out = F.scaled_dot_product_attention(Q_g, K_g, V_g).float().cpu()

    abs_err = (out - ref).abs()
    rel_err = abs_err / (ref.abs() + 1e-8)
    max_abs = abs_err.max().item()
    max_rel = rel_err.max().item()
    ok = max_abs < 0.1
    scaling_data.append((seq_len, max_abs, max_rel))
    status = 'PASS' if ok else 'WARN'
    print(f'{seq_len:>10}  {max_abs:>15.4e}  {max_rel:>15.4e}  {status}')

print('\nNote: Error growing with seq_len is expected for FP16/BF16; watch for sudden jumps')

Exercise 2.6 – GEMM Performance vs Roofline

Measure achieved FLOPS for GEMM and compare to theoretical GPU peak.
This tells you whether a kernel is compute-bound or memory-bound.
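Compute- vs memory-bound can be estimated before measuring anything: compare the GEMM's arithmetic intensity (FLOPs per byte moved) against the GPU's balance point (peak FLOPS / peak bandwidth). A hedged sketch (hypothetical helper; assumes fp16 operands and the minimum traffic of one read of A and B plus one write of C):

```python
def arithmetic_intensity(M, K, N, bytes_per_elem=2):
    """FLOPs per byte for C = A @ B, counting one pass over A, B, and C."""
    flops = 2 * M * K * N                                # multiply + add
    bytes_moved = bytes_per_elem * (M * K + K * N + M * N)
    return flops / bytes_moved

# Large square GEMMs are compute-bound; skinny ones are memory-bound.
print(arithmetic_intensity(4096, 4096, 4096))  # ~1365 FLOP/byte
print(arithmetic_intensity(4096, 4096, 8))     # ~8 FLOP/byte
```

If the intensity falls below the balance point (on the order of a few hundred FLOP/byte for modern accelerators), no amount of kernel tuning will reach peak TFLOPS.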

def measure_gemm_tflops(M, K, N, dtype=torch.float16, iterations=100):
    """
    Measure GEMM throughput in TFLOPS.
    GEMM FLOPs = 2 * M * K * N (multiply + add for each element)
    """
    if DEVICE == 'cpu':
        dtype = torch.float32
        iterations = 10

    a = torch.randn(M, K, device=DEVICE, dtype=dtype)
    b = torch.randn(K, N, device=DEVICE, dtype=dtype)

    # Warmup
    for _ in range(10):
        _ = torch.matmul(a, b)
    if DEVICE == 'cuda':
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iterations):
        _ = torch.matmul(a, b)
    if DEVICE == 'cuda':
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    flops_per_iter = 2 * M * K * N
    total_flops    = flops_per_iter * iterations
    tflops         = total_flops / elapsed / 1e12

    del a, b
    if DEVICE == 'cuda':
        torch.cuda.empty_cache()
    return tflops


# Theoretical peak TFLOPS (FP16) for common GPUs
THEORETICAL_TFLOPS = {
    'NVIDIA H100':        989.4,
    'NVIDIA A100-SXM':    312.0,
    'NVIDIA RTX 4090':    165.2,
    'NVIDIA RTX 3090':     71.0,
    'AMD MI300X':         1307.4,
    'AMD MI250X':          383.0,
}

sizes = [(4096, 4096, 4096), (8192, 8192, 8192)] if DEVICE != 'cpu' else [(256, 256, 256)]
dtype = torch.float16 if DEVICE == 'cuda' else torch.float32

print(f'GEMM Performance ({str(dtype).split(".")[1]}):')
print(f"{'Shape':>25}  {'TFLOPS':>10}  {'% of Peak':>12}")
print('-' * 55)

if DEVICE == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    peak = next((v for k, v in THEORETICAL_TFLOPS.items() if k in gpu_name), 100.0)
    print(f'GPU: {gpu_name}  |  Theoretical FP16 peak: {peak} TFLOPS\n')
else:
    peak = 1.0

for M, K, N in sizes:
    achieved = measure_gemm_tflops(M, K, N, dtype)
    pct = achieved / peak * 100 if DEVICE == 'cuda' else 0
    status = 'PASS' if pct >= 60 or DEVICE != 'cuda' else 'WARN'
    print(f'{f"({M},{K},{N})":>25}  {achieved:>10.1f}  {pct:>11.1f}%  [{status}]')

print('\nRule of thumb: ≥60% of peak = efficient kernel; <40% = likely memory-bound or misconfigured')

Exercise 2.7 – KernelValidator Class (Reusable)

Build a reusable test harness; this is the kind of tooling AMD validation teams actually build and maintain.

import json
from pathlib import Path
from datetime import datetime

class KernelValidator:
    """
    Reusable kernel correctness + performance validator.
    Stores results and produces a structured report.
    """

    def __init__(self, device=DEVICE):
        self.device = device
        self.results = []

    def validate(
        self,
        kernel_fn,          # GPU kernel to test
        reference_fn,       # CPU FP32 reference
        inputs,             # dict of {name: tensor} on CPU FP32
        name: str,
        dtype=torch.float32,
        atol: Optional[float] = None,
        rtol: Optional[float] = None,
    ):
        t = tols(dtype)
        atol = t['atol'] if atol is None else atol  # `or` would discard an explicit 0.0
        rtol = t['rtol'] if rtol is None else rtol

        # CPU reference
        with torch.no_grad():
            ref = reference_fn(**inputs)

        # GPU execution
        gpu_inputs = {k: v.to(device=self.device, dtype=dtype) for k, v in inputs.items()}
        with torch.no_grad():
            out = kernel_fn(**gpu_inputs)
            if isinstance(out, tuple):
                out = out[0]  # take first output
        out = out.to(dtype=torch.float32).cpu()

        # Metrics
        abs_err = (out - ref).abs()
        max_abs = abs_err.max().item()
        max_rel = (abs_err / (ref.abs() + 1e-8)).max().item()
        no_nan  = not torch.isnan(out).any().item()
        no_inf  = not torch.isinf(out).any().item()
        passed  = (max_abs <= atol) and (max_rel <= rtol) and no_nan and no_inf

        result = {
            'name': name,
            'dtype': str(dtype).split('.')[-1],
            'passed': passed,
            'max_abs_error': round(max_abs, 8),
            'max_rel_error': round(max_rel, 8),
            'atol': atol,
            'rtol': rtol,
            'nan': not no_nan,
            'inf': not no_inf,
        }
        self.results.append(result)
        status = 'PASS' if passed else 'FAIL'
        print(f'[{status}] {name:50s}  abs={max_abs:.2e}  rel={max_rel:.2e}')
        return passed

    def report(self, path='lab02_kernel_report.json'):
        summary = {
            'generated_at': datetime.now().isoformat(),
            'device': self.device,
            'gpu': torch.cuda.get_device_name(0) if self.device == 'cuda' else self.device,
            'total': len(self.results),
            'passed': sum(1 for r in self.results if r['passed']),
            'failed': sum(1 for r in self.results if not r['passed']),
            'tests': self.results,
        }
        Path(path).write_text(json.dumps(summary, indent=2))
        print(f'\nReport: {summary["passed"]}/{summary["total"]} passed → {path}')
        return summary


# ---- Demo: use KernelValidator ----
validator = KernelValidator()

size = 512 if DEVICE != 'cpu' else 128
dtype = torch.float16 if DEVICE == 'cuda' else torch.float32

# Test 1: GEMM
validator.validate(
    kernel_fn   = lambda a, b: torch.matmul(a, b),
    reference_fn= lambda a, b: torch.matmul(a, b),
    inputs      = {'a': torch.randn(size, size), 'b': torch.randn(size, size)},
    name        = f'GEMM({size}x{size})',
    dtype       = dtype,
)

# Test 2: LayerNorm
hidden = 768
x = torch.randn(2, 128, hidden)
validator.validate(
    kernel_fn    = lambda x: F.layer_norm(x, [hidden]),
    reference_fn = lambda x: F.layer_norm(x, [hidden]),
    inputs       = {'x': x},
    name         = f'LayerNorm(hidden={hidden})',
    dtype        = dtype,
    atol         = 1e-2,
)

# Test 3: Softmax
validator.validate(
    kernel_fn    = lambda x: F.softmax(x, dim=-1),
    reference_fn = lambda x: F.softmax(x, dim=-1),
    inputs       = {'x': torch.randn(8, 512)},
    name         = 'Softmax(8x512)',
    dtype        = dtype,
    atol         = 1e-2,
)

summary = validator.report()

Summary

Topic               Key AMD Interview Point
-----------------   --------------------------------------------------------------
Tolerance table     FP16 atol=1e-2, BF16 atol=1e-1 – know these by heart
GEMM sweep          Test all dtype x shape combos, not just square
Softmax edges       +inf input produces NaN (inf/inf); validator must flag this as expected instability
Attention scaling   Error grows with seq_len in FP16 – FlashAttention mitigates this
TFLOPS efficiency   Compare achieved vs theoretical; <60% = investigate
KernelValidator     Pattern AMD/NVIDIA teams use: harness + report

Previous: lab_01_hardware_validation.ipynb
Next: lab_03_model_performance.ipynb – Model benchmarking, profiling, LLM throughput
Back to Overview: README.md