Lab 02: Kernel Validation
GEMM · Convolution · Attention · Softmax · LayerNorm
Role alignment: AMD Principal Staff – AI/ML Performance Validation
Reference: 02_kernel_validation.ipynb
What you will do:
Validate GEMM correctness across dtypes (FP32, FP16, BF16) with proper tolerances
Test softmax and layernorm edge cases
Validate scaled dot-product attention (FlashAttention-style)
Measure kernel efficiency vs roofline model
Build a reusable KernelValidator class
The key insight: GPU results are NOT bit-identical to CPU results; you must use dtype-aware tolerances.
Setup
The setup establishes the dtype-aware tolerance table that governs all kernel validation in this lab: FP32 uses atol=1e-4, FP16 uses atol=1e-2, and BF16 uses atol=1e-1. These tolerances reflect the fundamental precision limits of each data type: FP16 has only 10 mantissa bits (~3.3 decimal digits), so expecting better than 1e-2 absolute agreement with an FP32 reference is unrealistic. The tols() helper provides a clean interface for looking up tolerances by dtype, a pattern used throughout production validation code at AMD and NVIDIA.
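To make these precision limits concrete, here is a quick check (a supplementary sketch, not part of the lab's original cells) of FP16 resolution: with 10 mantissa bits, values above 2048 are spaced 2.0 apart, and even 0.1 is not exactly representable.

```python
import torch

# FP16 has 10 mantissa bits. At magnitude 2048 the gap between adjacent
# representable values (the ULP) is 2.0, so adding 1.0 is a no-op.
a = torch.tensor(2048.0, dtype=torch.float16)
print((a + 1 == a).item())   # True

# 0.1 has no exact FP16 representation; the nearest value is stored.
b = torch.tensor(0.1, dtype=torch.float16)
print(b.item())              # 0.0999755859375
```

This is exactly why an FP16 GEMM cannot be expected to agree with an FP32 reference beyond roughly 1e-2 absolute error.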
import torch
import torch.nn.functional as F
import time
import math
from dataclasses import dataclass
from typing import Optional

if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

print(f'Device: {DEVICE}')
if DEVICE == 'cuda':
    print(f'GPU : {torch.cuda.get_device_name(0)}')

# ---- Tolerance table (AMD/NVIDIA engineering standard) ----
# Lower precision = higher tolerance needed
TOLERANCES = {
    torch.float64: {'atol': 1e-7, 'rtol': 1e-5},
    torch.float32: {'atol': 1e-4, 'rtol': 1e-4},
    torch.float16: {'atol': 1e-2, 'rtol': 1e-3},
    torch.bfloat16: {'atol': 1e-1, 'rtol': 1e-2},
}

def tols(dtype):
    return TOLERANCES.get(dtype, {'atol': 1e-2, 'rtol': 1e-2})
Concept: Why Tolerances Matter
FP32 → ~7 decimal digits of precision → atol=1e-4
FP16 → ~3 decimal digits of precision → atol=1e-2
BF16 → ~2-3 decimal digits of precision → atol=1e-1
INT8 → integer quantization error → much higher
GPU kernels reorder floating-point operations (floating-point addition is non-associative).
The reference is FP32 on CPU; the GPU result must be within tolerance, not bit-identical.
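The non-associativity point is easy to demonstrate. The sketch below (supplementary to the lab) sums the same FP32 values in three different orders, mimicking how a parallel GPU reduction differs from a sequential CPU loop:

```python
import torch

# Floating-point addition is non-associative: the same FP32 values
# summed in different orders give slightly different results. GPU
# kernels parallelize reductions (a different order than a sequential
# CPU loop), which is why validation uses tolerances, not equality.
torch.manual_seed(0)
x = torch.randn(100_000, dtype=torch.float32)
fwd = x.sum()
rev = x.flip(0).sum()                       # reversed order
tree = x.view(100, 1000).sum(dim=1).sum()   # blocked (tree-like) order
print(f'fwd={fwd.item():.6f}  rev={rev.item():.6f}  tree={tree.item():.6f}')
print(f'|fwd - tree| = {(fwd - tree).abs().item():.2e}')  # small, often nonzero
```

The three sums agree to several significant digits but are rarely bit-for-bit equal, which is the whole case for torch.allclose-style comparisons.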
Exercise 2.1 – GEMM Correctness Sweep
Validate GPU matrix multiplication against a CPU FP32 reference across:
Shapes: square (128–4096), non-square, tall, wide
Dtypes: FP32, FP16, BF16
Pass criteria: max_abs_error ≤ atol and max_rel_error ≤ rtol
@dataclass
class ValidationResult:
    name: str
    passed: bool
    max_abs_error: float
    max_rel_error: float
    atol: float
    rtol: float
    notes: str = ''

    def __str__(self):
        status = 'PASS' if self.passed else 'FAIL'
        return (f'[{status}] {self.name:50s} '
                f'abs={self.max_abs_error:.2e} (≤{self.atol:.1e}) '
                f'rel={self.max_rel_error:.2e} (≤{self.rtol:.1e})'
                + (f' NOTE: {self.notes}' if self.notes else ''))
def validate_gemm(M, K, N, dtype, device=DEVICE):
    """Compare GPU GEMM result to an FP32 CPU reference."""
    t = tols(dtype)
    # Reference: FP32 on CPU
    a_cpu = torch.randn(M, K, dtype=torch.float32)
    b_cpu = torch.randn(K, N, dtype=torch.float32)
    ref = torch.matmul(a_cpu, b_cpu)
    # GPU result in target dtype
    a_gpu = a_cpu.to(device=device, dtype=dtype)
    b_gpu = b_cpu.to(device=device, dtype=dtype)
    out_gpu = torch.matmul(a_gpu, b_gpu).to(dtype=torch.float32).cpu()
    abs_err = (out_gpu - ref).abs()
    rel_err = abs_err / (ref.abs() + 1e-8)
    max_abs = abs_err.max().item()
    max_rel = rel_err.max().item()
    passed = (max_abs <= t['atol']) and (max_rel <= t['rtol'])
    name = f'GEMM({M}x{K}x{N}) {str(dtype).split(".")[1]}'
    return ValidationResult(name, passed, max_abs, max_rel, t['atol'], t['rtol'])
# Sweep
results = []
shapes = [(128, 128, 128), (512, 512, 512), (1024, 1024, 1024),
          (2048, 2048, 2048), (1024, 4096, 1024), (256, 1024, 8192)]
dtypes = [torch.float32, torch.float16]
if DEVICE == 'cuda':
    dtypes.append(torch.bfloat16)
# Reduce the workload on CPU
if DEVICE == 'cpu':
    shapes = [(128, 128, 128), (256, 256, 256), (512, 256, 512)]
    dtypes = [torch.float32]

for dtype in dtypes:
    for M, K, N in shapes:
        r = validate_gemm(M, K, N, dtype)
        results.append(r)
        print(r)

passed = sum(1 for r in results if r.passed)
print(f'\n{passed}/{len(results)} GEMM tests passed')
Exercise 2.2 – Softmax Edge Cases
Softmax is numerically tricky. Test edge cases that expose GPU kernel bugs:
All zeros
All same value
One very large value (should dominate)
Mixed large + small (overflow risk in naive implementation)
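The overflow risk in the last bullet is worth seeing once. The sketch below (supplementary, not from the lab) contrasts a naive softmax with the max-subtraction form that library kernels use:

```python
import torch

# A naive softmax computes exp(x) directly; exp(1000) overflows to inf
# in FP32, giving inf/inf = NaN. Subtracting the row max first keeps
# every exponent <= 0, so nothing overflows.
x = torch.tensor([1000.0, 0.0, -1000.0])

naive = torch.exp(x) / torch.exp(x).sum()                       # first entry is NaN (inf/inf)
stable = torch.exp(x - x.max()) / torch.exp(x - x.max()).sum()

print(naive)    # contains NaN
print(stable)   # tensor([1., 0., 0.])
```

Any GPU kernel that skips the max-subtraction trick fails exactly the "large value" cases below, which is why they are in the test list.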
def validate_softmax(input_cpu, name, expect_numerical_issue=False):
    """Compare GPU softmax to CPU FP32 reference.

    Args:
        input_cpu: Input tensor on CPU.
        name: Human-readable test name.
        expect_numerical_issue: If True, the test PASSES when NaN/Inf is
            detected in the output (i.e. we successfully caught instability).
    """
    ref = F.softmax(input_cpu, dim=-1)
    gpu_input = input_cpu.to(DEVICE)
    out = F.softmax(gpu_input, dim=-1).cpu()
    # Softmax properties:
    # 1. Output sums to 1
    # 2. All values in [0, 1]
    # 3. Close to CPU reference
    sum_check = abs(out.sum().item() - 1.0) < 1e-4
    range_check = (out >= 0).all() and (out <= 1).all()
    has_nan = torch.isnan(out).any().item()
    has_inf = torch.isinf(out).any().item()
    close = torch.allclose(out.float(), ref.float(), atol=1e-4)
    issues = []
    if not sum_check: issues.append(f'sum={out.sum().item():.6f} (expected 1.0)')
    if not range_check: issues.append('values outside [0,1]')
    if has_nan: issues.append('NaN detected')
    if has_inf: issues.append('Inf detected')
    if not close: issues.append(f'max_diff={(out.float()-ref.float()).abs().max().item():.2e}')
    numerically_stable = (not has_nan) and (not has_inf)
    if expect_numerical_issue:
        # We EXPECT instability here; a pass means we correctly detected the problem
        passed = not numerically_stable  # True when NaN or Inf is present
        if passed:
            issues.append('EXPECTED: numerical instability correctly detected')
        else:
            issues.append('UNEXPECTED: no instability detected for pathological input')
    else:
        passed = sum_check and range_check and numerically_stable
    status = 'PASS' if passed else 'FAIL'
    print(f'[{status}] {name:45s} sum_ok={sum_check} range_ok={range_check} '
          f'nan={has_nan} {", ".join(issues) if issues else "all checks pass"}')
    return passed
SEQ = 2048
test_cases = [
    (torch.zeros(SEQ), 'All zeros', False),
    (torch.ones(SEQ), 'All ones (same value)', False),
    (torch.full((SEQ,), 1000.0), 'All large same value', False),
    (torch.cat([torch.tensor([1000.0]), torch.zeros(SEQ - 1)]), 'One large, rest zeros', False),
    (torch.cat([torch.tensor([1000.0, -1000.0]), torch.randn(SEQ - 2)]), 'Large + small mix', False),
    (torch.randn(SEQ) * 10, 'High variance random', False),
    (torch.full((SEQ,), float('inf')), 'All inf (must produce NaN)', True),
]

print('Softmax edge case validation:')
print('-' * 90)
for tensor, name, expect_issue in test_cases:
    try:
        validate_softmax(tensor, name, expect_numerical_issue=expect_issue)
    except Exception as e:
        print(f'[EXCEPTION] {name}: {e}')
Exercise 2.3 – LayerNorm Validation
LayerNorm is in every Transformer. Test:
Basic correctness
Near-zero variance (numerical instability risk)
Large hidden dimensions
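The near-zero-variance risk is exactly what the eps term guards against. A quick supplementary check (assuming PyTorch's default eps=1e-5):

```python
import torch
import torch.nn.functional as F

# For a constant row the variance is 0, so normalizing without eps
# divides 0 by 0 and produces NaN. The eps term keeps the denominator
# positive.
x = torch.full((1, 8), 3.0)   # zero-variance row

no_eps = (x - x.mean(-1, keepdim=True)) / x.var(-1, unbiased=False, keepdim=True).sqrt()
print(torch.isnan(no_eps).any().item())   # True: 0/0

out = F.layer_norm(x, [8])    # default eps=1e-5
print(out)                    # all zeros, no NaN
```

A kernel that fuses the variance and rsqrt steps can mishandle this path, so constant inputs belong in any LayerNorm test suite.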
def validate_layernorm(batch, seq, hidden, name, eps=1e-5):
    """Compare GPU LayerNorm to a manual FP32 CPU computation."""
    x_cpu = torch.randn(batch, seq, hidden, dtype=torch.float32)
    weight = torch.ones(hidden)
    bias = torch.zeros(hidden)
    # Reference: manual FP32
    mean = x_cpu.mean(dim=-1, keepdim=True)
    var = x_cpu.var(dim=-1, keepdim=True, unbiased=False)
    ref = (x_cpu - mean) / (var + eps).sqrt() * weight + bias
    # GPU via torch.nn.functional
    x_gpu = x_cpu.to(DEVICE)
    w_gpu = weight.to(DEVICE)
    b_gpu = bias.to(DEVICE)
    out = F.layer_norm(x_gpu, [hidden], w_gpu, b_gpu, eps=eps).cpu()
    max_err = (out - ref).abs().max().item()
    passed = max_err < 1e-3
    status = 'PASS' if passed else 'FAIL'
    print(f'[{status}] {name:45s} max_abs_err={max_err:.2e}')
    return passed

print('LayerNorm validation:')
print('-' * 70)
validate_layernorm(2, 128, 768, 'BERT-base (batch=2, seq=128, h=768)')
validate_layernorm(1, 2048, 4096, 'LLM-style (batch=1, seq=2048, h=4096)')
validate_layernorm(8, 512, 256, 'Small hidden (h=256)')
validate_layernorm(1, 1, 8192, 'Large hidden only (h=8192)')
Exercise 2.4 – Scaled Dot-Product Attention Validation
Test the attention kernel used in every LLM:
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
Compare PyTorch's optimized SDPA (which dispatches to FlashAttention on supported CUDA GPUs) against a manual reference.
def manual_attention(Q, K, V, mask=None):
    """Reference attention implementation in FP32."""
    d_k = Q.shape[-1]
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, V), attn_weights
def validate_attention(batch, heads, seq_len, head_dim, dtype, use_causal=False):
    t = tols(dtype)
    name = f'Attention(B={batch},H={heads},S={seq_len},D={head_dim}) {str(dtype).split(".")[1]}'
    if use_causal:
        name += ' causal'
    # Reference: manual FP32 on CPU
    Q_cpu = torch.randn(batch, heads, seq_len, head_dim)
    K_cpu = torch.randn(batch, heads, seq_len, head_dim)
    V_cpu = torch.randn(batch, heads, seq_len, head_dim)
    ref_out, _ = manual_attention(Q_cpu, K_cpu, V_cpu)
    # GPU SDPA (dispatches to FlashAttention on CUDA GPUs with
    # compute capability >= 8.0)
    Q_gpu = Q_cpu.to(device=DEVICE, dtype=dtype)
    K_gpu = K_cpu.to(device=DEVICE, dtype=dtype)
    V_gpu = V_cpu.to(device=DEVICE, dtype=dtype)
    with torch.no_grad():
        out_gpu = F.scaled_dot_product_attention(
            Q_gpu, K_gpu, V_gpu,
            is_causal=use_causal
        ).to(torch.float32).cpu()
    # Compare. The reference is non-causal, so with is_causal=True only
    # the first token matches it (that token attends solely to itself);
    # compare just that slice rather than the full output.
    compare_ref = ref_out
    compare_out = out_gpu
    if use_causal:
        compare_ref = ref_out[:, :, :1, :]
        compare_out = out_gpu[:, :, :1, :]
    max_err = (compare_out - compare_ref).abs().max().item()
    passed = max_err <= t['atol'] * 10  # slightly looser: the softmax chain compounds rounding
    no_nan = not torch.isnan(out_gpu).any()
    no_inf = not torch.isinf(out_gpu).any()
    overall = passed and no_nan and no_inf
    status = 'PASS' if overall else 'FAIL'
    notes = []
    if not passed: notes.append(f'max_err={max_err:.2e}>atol={t["atol"]*10:.1e}')
    if not no_nan: notes.append('NaN')
    if not no_inf: notes.append('Inf')
    print(f'[{status}] {name:60s} max_err={max_err:.2e} {" ".join(notes)}')
    return overall
print('Attention kernel validation:')
print('-' * 90)
attn_dtypes = [torch.float32, torch.float16]
if DEVICE == 'cuda':
    attn_dtypes.append(torch.bfloat16)
if DEVICE == 'cpu':
    attn_dtypes = [torch.float32]

configs = [
    (1, 8, 128, 64, False),    # BERT-style
    (1, 32, 512, 128, False),  # LLM decode
    (1, 16, 2048, 64, True),   # Causal LLM
]
if DEVICE == 'cpu':
    configs = [(1, 4, 64, 32, False), (1, 4, 64, 32, True)]

all_results = []
for dtype in attn_dtypes:
    for B, H, S, D, causal in configs:
        try:
            r = validate_attention(B, H, S, D, dtype, causal)
            all_results.append(r)
        except Exception as e:
            print(f'[ERROR] {e}')

print(f'\n{sum(all_results)}/{len(all_results)} attention tests passed')
Exercise 2.5 – Sequence Length Scaling
Test numerical error growth as sequence length increases.
Long sequences accumulate floating-point error; this is a real issue at seq_len > 8192.
seq_lengths = [64, 128, 256, 512, 1024, 2048] if DEVICE != 'cpu' else [32, 64, 128]
dtype = torch.float16 if DEVICE == 'cuda' else torch.float32

print(f'Attention error vs sequence length ({str(dtype).split(".")[1]}):')
print(f"{'Seq Len':>10} {'Max Abs Error':>15} {'Max Rel Error':>15} Status")
print('-' * 60)

scaling_data = []
for seq_len in seq_lengths:
    B, H, D = 1, 8, 64
    Q_cpu = torch.randn(B, H, seq_len, D)
    K_cpu = torch.randn(B, H, seq_len, D)
    V_cpu = torch.randn(B, H, seq_len, D)
    ref, _ = manual_attention(Q_cpu, K_cpu, V_cpu)
    Q_g = Q_cpu.to(DEVICE, dtype=dtype)
    K_g = K_cpu.to(DEVICE, dtype=dtype)
    V_g = V_cpu.to(DEVICE, dtype=dtype)
    out = F.scaled_dot_product_attention(Q_g, K_g, V_g).float().cpu()
    abs_err = (out - ref).abs()
    rel_err = abs_err / (ref.abs() + 1e-8)
    max_abs = abs_err.max().item()
    max_rel = rel_err.max().item()
    ok = max_abs < 0.1
    scaling_data.append((seq_len, max_abs, max_rel))
    status = 'PASS' if ok else 'WARN'
    print(f'{seq_len:>10} {max_abs:>15.4e} {max_rel:>15.4e} {status}')

print('\nNote: error growing with seq_len is expected for FP16/BF16; watch for sudden jumps')
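The "sudden jump" failure mode typically comes from accumulation. A supplementary sketch of the worst case, a sequential FP16 accumulator that stalls completely once its addend drops below half a ULP:

```python
import torch

# Sequentially accumulate 0.1 in FP16. Once the running sum reaches 256,
# the FP16 spacing (ULP) there is 0.25, so adding ~0.1 rounds back to
# the same value and the sum stops growing entirely.
acc = torch.tensor(0.0, dtype=torch.float16)
step = torch.tensor(0.1, dtype=torch.float16)
for _ in range(20000):
    acc = acc + step
print(acc.item())   # 256.0, not ~2000
```

Long-reduction kernels avoid this by accumulating in FP32 even when inputs are FP16; an error curve that suddenly jumps with seq_len often means an accumulator precision bug.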
Exercise 2.6 – GEMM Performance vs Roofline
Measure achieved FLOPS for GEMM and compare to theoretical GPU peak.
This tells you whether a kernel is compute-bound or memory-bound.
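Before measuring, it helps to know what to expect. A back-of-envelope sketch of arithmetic intensity (FLOPs per byte moved), under the simplifying assumption that each FP16 matrix crosses the memory bus exactly once:

```python
# GEMM roofline back-of-envelope: arithmetic intensity = FLOPs / bytes.
# For C = A @ B with A (MxK), B (KxN), C (MxN) in FP16 (2 bytes/elem):
def arithmetic_intensity(M, K, N, bytes_per_elem=2):
    flops = 2 * M * K * N                                   # multiply + add
    bytes_moved = bytes_per_elem * (M * K + K * N + M * N)  # read A, B; write C
    return flops / bytes_moved

for n in (128, 1024, 4096):
    print(n, round(arithmetic_intensity(n, n, n), 1))  # works out to n/3 for square GEMM
```

Intensity grows roughly linearly with size, so small GEMMs are memory-bound no matter how well tuned the kernel is; that is why the measurements below use 4096-class shapes.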
def measure_gemm_tflops(M, K, N, dtype=torch.float16, iterations=100):
    """
    Measure GEMM throughput in TFLOPS.
    GEMM FLOPs = 2 * M * K * N (a multiply and an add per output contribution)
    """
    if DEVICE == 'cpu':
        dtype = torch.float32
        iterations = 10
    a = torch.randn(M, K, device=DEVICE, dtype=dtype)
    b = torch.randn(K, N, device=DEVICE, dtype=dtype)
    # Warmup
    for _ in range(10):
        _ = torch.matmul(a, b)
    if DEVICE == 'cuda':
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iterations):
        _ = torch.matmul(a, b)
    if DEVICE == 'cuda':
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    flops_per_iter = 2 * M * K * N
    total_flops = flops_per_iter * iterations
    tflops = total_flops / elapsed / 1e12
    del a, b
    if DEVICE == 'cuda':
        torch.cuda.empty_cache()
    return tflops
# Theoretical peak TFLOPS (dense FP16) for common GPUs
THEORETICAL_TFLOPS = {
    'NVIDIA H100': 989.4,
    'NVIDIA A100-SXM': 312.0,
    'NVIDIA RTX 4090': 165.2,
    'NVIDIA RTX 3090': 71.0,
    'AMD MI300X': 1307.4,
    'AMD MI250X': 383.0,
}

sizes = [(4096, 4096, 4096), (8192, 8192, 8192)] if DEVICE != 'cpu' else [(256, 256, 256)]
dtype = torch.float16 if DEVICE == 'cuda' else torch.float32

print(f'GEMM Performance ({str(dtype).split(".")[1]}):')
print(f"{'Shape':>25} {'TFLOPS':>10} {'% of Peak':>12}")
print('-' * 55)
if DEVICE == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    peak = next((v for k, v in THEORETICAL_TFLOPS.items() if k in gpu_name), 100.0)
    print(f'GPU: {gpu_name} | Theoretical FP16 peak: {peak} TFLOPS\n')
else:
    peak = 1.0

for M, K, N in sizes:
    achieved = measure_gemm_tflops(M, K, N, dtype)
    pct = achieved / peak * 100 if DEVICE == 'cuda' else 0
    status = 'PASS' if pct >= 60 or DEVICE != 'cuda' else 'WARN'
    print(f'{f"({M},{K},{N})":>25} {achieved:>10.1f} {pct:>11.1f}% [{status}]')

print('\nRule of thumb: ≥60% of peak = efficient kernel; <40% = likely memory-bound or misconfigured')
Exercise 2.7 – KernelValidator Class (Reusable)
Build a reusable test harness; this is what AMD validation teams actually build and maintain.
import json
from pathlib import Path
from datetime import datetime

class KernelValidator:
    """
    Reusable kernel correctness + performance validator.
    Stores results and produces a structured report.
    """
    def __init__(self, device=DEVICE):
        self.device = device
        self.results = []

    def validate(
        self,
        kernel_fn,      # GPU kernel to test
        reference_fn,   # CPU FP32 reference
        inputs,         # dict of {name: tensor} on CPU FP32
        name: str,
        dtype=torch.float32,
        atol: Optional[float] = None,
        rtol: Optional[float] = None,
    ):
        t = tols(dtype)
        # Compare against None (plain `or` would discard an explicit atol=0.0)
        atol = t['atol'] if atol is None else atol
        rtol = t['rtol'] if rtol is None else rtol
        # CPU reference
        with torch.no_grad():
            ref = reference_fn(**inputs)
        # GPU execution
        gpu_inputs = {k: v.to(device=self.device, dtype=dtype) for k, v in inputs.items()}
        with torch.no_grad():
            out = kernel_fn(**gpu_inputs)
        if isinstance(out, tuple):
            out = out[0]  # take first output
        out = out.to(dtype=torch.float32).cpu()
        # Metrics
        abs_err = (out - ref).abs()
        max_abs = abs_err.max().item()
        max_rel = (abs_err / (ref.abs() + 1e-8)).max().item()
        no_nan = not torch.isnan(out).any().item()
        no_inf = not torch.isinf(out).any().item()
        passed = (max_abs <= atol) and (max_rel <= rtol) and no_nan and no_inf
        result = {
            'name': name,
            'dtype': str(dtype).split('.')[-1],
            'passed': passed,
            'max_abs_error': round(max_abs, 8),
            'max_rel_error': round(max_rel, 8),
            'atol': atol,
            'rtol': rtol,
            'nan': not no_nan,
            'inf': not no_inf,
        }
        self.results.append(result)
        status = 'PASS' if passed else 'FAIL'
        print(f'[{status}] {name:50s} abs={max_abs:.2e} rel={max_rel:.2e}')
        return passed

    def report(self, path='lab02_kernel_report.json'):
        summary = {
            'generated_at': datetime.now().isoformat(),
            'device': self.device,
            'gpu': torch.cuda.get_device_name(0) if self.device == 'cuda' else self.device,
            'total': len(self.results),
            'passed': sum(1 for r in self.results if r['passed']),
            'failed': sum(1 for r in self.results if not r['passed']),
            'tests': self.results,
        }
        Path(path).write_text(json.dumps(summary, indent=2))
        print(f'\nReport: {summary["passed"]}/{summary["total"]} passed → {path}')
        return summary
# ---- Demo: use KernelValidator ----
validator = KernelValidator()
size = 512 if DEVICE != 'cpu' else 128
dtype = torch.float16 if DEVICE == 'cuda' else torch.float32

# Test 1: GEMM
validator.validate(
    kernel_fn=lambda a, b: torch.matmul(a, b),
    reference_fn=lambda a, b: torch.matmul(a, b),
    inputs={'a': torch.randn(size, size), 'b': torch.randn(size, size)},
    name=f'GEMM({size}x{size})',
    dtype=dtype,
)

# Test 2: LayerNorm
hidden = 768
x = torch.randn(2, 128, hidden)
validator.validate(
    kernel_fn=lambda x: F.layer_norm(x, [hidden]),
    reference_fn=lambda x: F.layer_norm(x, [hidden]),
    inputs={'x': x},
    name=f'LayerNorm(hidden={hidden})',
    dtype=dtype,
    atol=1e-2,
)

# Test 3: Softmax
validator.validate(
    kernel_fn=lambda x: F.softmax(x, dim=-1),
    reference_fn=lambda x: F.softmax(x, dim=-1),
    inputs={'x': torch.randn(8, 512)},
    name='Softmax(8x512)',
    dtype=dtype,
    atol=1e-2,
)

summary = validator.report()
Summary

| Topic | Key AMD Interview Point |
|---|---|
| Tolerance table | FP16 atol=1e-2, BF16 atol=1e-1; know these by heart |
| GEMM sweep | Test all dtype x shape combos, not just square |
| Softmax edges | Pathological inputs (huge values, inf) must be caught; check sum, range, NaN/Inf |
| Attention scaling | Error grows with seq_len in FP16; FlashAttention mitigates this |
| TFLOPS efficiency | Compare achieved vs theoretical; <60% = investigate |
| KernelValidator | Pattern AMD/NVIDIA teams use: harness + report |
Previous: lab_01_hardware_validation.ipynb
Next: lab_03_model_performance.ipynb – Model benchmarking, profiling, LLM throughput
Back to Overview: README.md