Section 2: Kernel Validation

GEMM, Convolution, Attention, Softmax & LayerNorm

Duration: 6 hours
Difficulty: Intermediate–Advanced

2.1 Why Kernel Validation Matters

Every ML model ultimately executes a small set of compute kernels on hardware:

| Kernel | Used By | % of Compute |
| --- | --- | --- |
| GEMM (General Matrix Multiply) | Every layer (linear, attention) | 60–80% |
| Convolution | CNN layers (ResNet, YOLO, etc.) | 50–70% in CV |
| Attention (Scaled Dot-Product) | Transformers (LLMs, ViT) | 30–50% |
| Softmax | Attention, classification heads | 5–10% |
| LayerNorm / RMSNorm | Every transformer block | 5–10% |

If any of these kernels produces incorrect results, the model can silently output wrong answers. Unlike a crash, numerical errors are insidious: the model runs to completion but its results are subtly wrong.

2.2 Numerical Precision Fundamentals

Data Types in AI Compute

| Type | Bits | Exponent | Mantissa | Range | Precision |
| --- | --- | --- | --- | --- | --- |
| FP32 | 32 | 8 | 23 | ±3.4e38 | ~7 decimal digits |
| FP16 | 16 | 5 | 10 | ±65504 | ~3.3 decimal digits |
| BF16 | 16 | 8 | 7 | ±3.4e38 | ~2.4 decimal digits |
| FP8 (E4M3) | 8 | 4 | 3 | ±448 | ~1.2 decimal digits |
| FP8 (E5M2) | 8 | 5 | 2 | ±57344 | ~1 decimal digit |
| INT8 | 8 | N/A | N/A | -128 to 127 | Exact (integer) |
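The range and precision columns follow directly from the exponent and mantissa widths. The sketch below is a plain-Python illustration, not part of the original material. It assumes an IEEE-style layout in which the top exponent code is reserved for inf/NaN. That holds for FP32, FP16, BF16 and FP8 E5M2; FP8 E4M3 is left out because it reclaims most of that exponent code to reach ±448. The helper name `ieee_like_limits` is hypothetical.

```python
import math

def ieee_like_limits(exp_bits, man_bits):
    """Max finite value and decimal digits for an IEEE-style binary format.

    Assumes the top exponent code is reserved for inf/NaN (true for FP32,
    FP16, BF16 and FP8 E5M2, but not for FP8 E4M3).
    """
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = (2 ** exp_bits - 2) - bias            # largest usable exponent
    max_value = (2 - 2.0 ** -man_bits) * 2.0 ** max_exp
    digits = (man_bits + 1) * math.log10(2)         # +1 for the implicit leading bit
    return max_value, digits

# (exponent bits, mantissa bits) taken from the table above
FORMATS = {"FP32": (8, 23), "FP16": (5, 10), "BF16": (8, 7), "FP8 E5M2": (5, 2)}

for name, (e, m) in FORMATS.items():
    max_value, digits = ieee_like_limits(e, m)
    print(f"{name:9s} max={max_value:.4g}  digits={digits:.1f}")
```

The digit count is simply the number of significand bits times log10(2), which is why BF16 trades precision for FP32-like range.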

Tolerance Guidelines for Validation

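# Standard tolerances for kernel validation
TOLERANCES = {
    "fp32": {"atol": 1e-5, "rtol": 1e-4},
    "fp16": {"atol": 1e-3, "rtol": 1e-2},
    "bf16": {"atol": 1e-2, "rtol": 5e-2},
    "fp8_e4m3": {"atol": 5e-2, "rtol": 1e-1},
    "fp8_e5m2": {"atol": 1e-1, "rtol": 2e-1},
    "int8": {"atol": 1, "rtol": 0},  # Integer: exact or off-by-one
}

# The validators below look tolerances up with str(dtype).split('.')[-1],
# which yields "float16", "bfloat16", "float32". Alias those names so the
# lookup does not raise KeyError.
TOLERANCES["float32"] = TOLERANCES["fp32"]
TOLERANCES["float16"] = TOLERANCES["fp16"]
TOLERANCES["bfloat16"] = TOLERANCES["bf16"]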

Understanding atol vs rtol

Numerical validation of GPU kernels relies on two tolerance parameters: absolute tolerance (atol) and relative tolerance (rtol). The combined check is \(|\text{actual} - \text{expected}| \leq \text{atol} + \text{rtol} \times |\text{expected}|\). Absolute tolerance dominates when values are near zero (where relative error would explode), while relative tolerance dominates for large values where a fixed absolute threshold would be too strict. For FP16 kernels, atol=1e-3 and rtol=1e-2 are commonly used starting points because the 10-bit FP16 mantissa gives only ~3.3 decimal digits of precision; adjust them per kernel as accumulation depth requires. The validate_close function below reports not just pass/fail but also the number and percentage of element-level violations, which helps distinguish a systemic kernel bug (many violations) from an isolated edge case (few violations near the boundary).

import torch

def validate_close(actual, expected, atol, rtol, name="tensor"):
    """Validate two tensors are close within tolerance.

    The check is: |actual - expected| <= atol + rtol * |expected|
    """
    if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
        diff = (actual - expected).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()
        threshold = atol + rtol * expected.abs()
        num_violations = (diff > threshold).sum().item()
        total = actual.numel()
        print(f"FAIL [{name}]: max_diff={max_diff:.6f}, "
              f"mean_diff={mean_diff:.6f}, "
              f"violations={num_violations}/{total} "
              f"({100*num_violations/total:.2f}%)")
        return False
    print(f"PASS [{name}]")
    return True
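To make the interplay of atol and rtol concrete, here is a scalar version of the same check in plain Python, evaluated with the FP16 tolerances from the table. The function name `within_tolerance` and the sample values are illustrative assumptions, not part of the original harness.

```python
def within_tolerance(actual, expected, atol, rtol):
    """Scalar form of the check |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)

FP16_ATOL, FP16_RTOL = 1e-3, 1e-2

# Near zero the rtol term vanishes, so the allowed error is just atol (1e-3).
near_zero_ok = within_tolerance(0.0005, 0.0, FP16_ATOL, FP16_RTOL)   # error 5e-4
near_zero_bad = within_tolerance(0.002, 0.0, FP16_ATOL, FP16_RTOL)   # error 2e-3

# At magnitude 1000 the allowed error is about rtol * 1000 = 10.
large_ok = within_tolerance(1005.0, 1000.0, FP16_ATOL, FP16_RTOL)    # error 5
large_bad = within_tolerance(1020.0, 1000.0, FP16_ATOL, FP16_RTOL)   # error 20

print(near_zero_ok, near_zero_bad, large_ok, large_bad)
```

An error of 5 is accepted at magnitude 1000 while an error of 0.002 is rejected at zero, which is exactly the asymmetry the combined formula is designed to produce.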

2.3 GEMM Validation

GEMM is the most critical kernel: it dominates compute time in most neural networks.

Basic GEMM Correctness Test

import torch

def validate_gemm(M, N, K, dtype=torch.float16, device='cuda'):
    """Validate GPU GEMM against CPU FP32 reference."""
    # Generate inputs
    torch.manual_seed(42)
    A = torch.randn(M, K, device='cpu', dtype=torch.float32)
    B = torch.randn(K, N, device='cpu', dtype=torch.float32)

    # CPU reference (FP32 β€” highest accuracy)
    C_ref = torch.matmul(A, B)

    # GPU computation in target dtype
    A_gpu = A.to(device=device, dtype=dtype)
    B_gpu = B.to(device=device, dtype=dtype)
    C_gpu = torch.matmul(A_gpu, B_gpu)

    # Compare (bring GPU result back to CPU FP32)
    C_gpu_fp32 = C_gpu.float().cpu()

    tol = TOLERANCES[str(dtype).split('.')[-1]]
    return validate_close(C_gpu_fp32, C_ref, **tol, name=f"GEMM_{M}x{N}x{K}_{dtype}")


# Test matrix: varying sizes and dtypes
test_cases = [
    # (M, N, K, dtype)
    (128, 128, 128, torch.float16),
    (1024, 1024, 1024, torch.float16),
    (4096, 4096, 4096, torch.float16),
    (8192, 8192, 8192, torch.float16),
    (1, 1024, 4096, torch.float16),     # Skinny matrix (inference)
    (32, 50257, 4096, torch.float16),   # Vocab projection
    (128, 128, 128, torch.bfloat16),
    (4096, 4096, 4096, torch.bfloat16),
]

for M, N, K, dtype in test_cases:
    validate_gemm(M, N, K, dtype)

Edge Cases for GEMM

Production GEMM kernels must handle inputs that stress numerical edge cases beyond standard random data. Zero matrices verify the kernel produces exact zeros (no residual accumulation). Identity matrices test that \(I \cdot B = B\) holds within tolerance. Large uniform values (e.g., 1000.0) stress the accumulator: multiplying two 256x256 matrices of 1000s produces elements of \(256 \times 10^6\), and even a single product (\(10^6\)) exceeds the FP16 maximum (\(\text{max} = 65504\)). Small values probe the subnormal range and test whether the kernel flushes subnormals to zero or preserves them. These two cases only bite once the inputs are cast to FP16; the function below builds FP32 tensors, where both pass trivially. Non-square shapes and transposed inputs exercise different memory access patterns and tiling strategies in the GPU kernel, which are common sources of bugs in optimized GEMM implementations like cuBLAS and hipBLAS.

def gemm_edge_cases(device='cuda'):
    """Test GEMM with tricky inputs."""
    cases = {
        "zeros": (torch.zeros(256, 256), torch.zeros(256, 256)),
        "identity": (torch.eye(256), torch.randn(256, 256)),
        "large_values": (torch.full((256, 256), 1000.0), torch.full((256, 256), 1000.0)),
        "small_values": (torch.full((256, 256), 1e-6), torch.full((256, 256), 1e-6)),
        "mixed_sign": (torch.randn(256, 256), -torch.randn(256, 256)),
        "non_square": (torch.randn(1, 4096), torch.randn(4096, 50257)),
        "transpose_A": (torch.randn(256, 512).T, torch.randn(256, 256)),
    }

    for name, (A, B) in cases.items():
        ref = torch.matmul(A, B)  # CPU FP32 reference
        # Honor the `device` argument instead of hard-coding .cuda()
        gpu_result = torch.matmul(A.to(device), B.to(device)).cpu()
        passed = torch.allclose(gpu_result, ref, atol=1e-4, rtol=1e-3)
        print(f"{'PASS' if passed else 'FAIL'} [GEMM edge: {name}]")
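The overflow and subnormal behaviour described above can be checked without a GPU. The sketch below uses the standard library's `struct` half-precision format (`"e"`) to round Python floats to FP16. The helper name `to_fp16` is hypothetical, and this only models storage rounding; it does not model how a particular GEMM kernel accumulates.

```python
import struct

def to_fp16(value):
    """Round a Python float to the nearest FP16 value, or return None on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        return None  # magnitude is too large for FP16 (max finite value 65504)

product = 1000.0 * 1000.0      # one term of the "large_values" dot product
row_sum = 256 * product        # full 256-term accumulation

print(to_fp16(product))        # a single product already overflows FP16
print(to_fp16(row_sum))
print(to_fp16(1e-6))           # subnormal in FP16: stored with reduced precision
print(to_fp16(1e-6 * 1e-6))    # far below the smallest FP16 subnormal
```

This is why FP16 GEMM kernels accumulate in a wider type: the individual products of the "large_values" case cannot be represented in FP16 at all, and the "small_values" products underflow to zero.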

Batched GEMM

Batched GEMM (torch.bmm) computes independent matrix multiplications across a batch dimension simultaneously, which is the core operation inside multi-head attention. For a transformer with \(h\) heads and batch size \(b\), the attention score computation \(QK^T\) is a batched GEMM whose operands have shapes \((b \times h, s, d_k)\) and \((b \times h, d_k, s)\) and whose result has shape \((b \times h, s, s)\), where \(s\) is the sequence length and \(d_k\) is the head dimension. Validation must cover realistic transformer shapes: for example, Llama-2 70B uses 64 heads with \(d_k = 128\) and sequence lengths up to 4096. The batch dimension changes how the GPU schedules thread blocks and manages shared memory, so bugs that do not appear in a single GEMM can surface in batched mode, particularly at large batch counts where the kernel may switch to a different tiling algorithm.

def validate_batched_gemm(batch, M, N, K, dtype=torch.float16):
    """Validate batched GEMM (used in multi-head attention)."""
    torch.manual_seed(42)
    A = torch.randn(batch, M, K, dtype=torch.float32)
    B = torch.randn(batch, K, N, dtype=torch.float32)

    C_ref = torch.bmm(A, B)
    C_gpu = torch.bmm(A.to('cuda', dtype), B.to('cuda', dtype)).float().cpu()

    tol = TOLERANCES[str(dtype).split('.')[-1]]
    return validate_close(C_gpu, C_ref, **tol,
                          name=f"BatchGEMM_{batch}x{M}x{N}x{K}")

# Multi-head attention shapes: (batch*heads, seq_len, head_dim)
validate_batched_gemm(96, 2048, 128, 128)   # 12 heads, batch=8
validate_batched_gemm(256, 4096, 128, 128)  # 32 heads, batch=8
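Before running attention-shaped batched GEMMs it helps to write down the operand shapes and the size of the score matrix, since the \(s \times s\) result is what dominates memory. The following is a minimal sketch in plain Python; `attention_bmm_shapes` is a hypothetical helper, and the 2-byte element size assumes FP16 or BF16 storage.

```python
def attention_bmm_shapes(batch, heads, seq_len, head_dim, bytes_per_elem=2):
    """Operand/result shapes for the two batched GEMMs inside attention."""
    bh = batch * heads
    scores = {
        "A": (bh, seq_len, head_dim),   # Q
        "B": (bh, head_dim, seq_len),   # K^T
        "C": (bh, seq_len, seq_len),    # Q @ K^T
    }
    context = {
        "A": (bh, seq_len, seq_len),    # softmax(scores)
        "B": (bh, seq_len, head_dim),   # V
        "C": (bh, seq_len, head_dim),   # attention output
    }
    score_bytes = bh * seq_len * seq_len * bytes_per_elem
    return scores, context, score_bytes

# Llama-2 70B-like shape from the text: 64 heads, d_k = 128, s = 4096, batch 1
scores, context, score_bytes = attention_bmm_shapes(1, 64, 4096, 128)
print(scores["C"], f"{score_bytes / 2**30:.1f} GiB for the score matrix")
```

At this shape the materialized score matrix alone is 2 GiB in FP16, so a CPU FP32 reference for it needs twice that plus temporaries.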

2.4 Convolution Validation

2D Convolution Correctness

Convolution kernels are the backbone of computer vision models, and hardware vendors must validate them across the full range of configurations used in production architectures. The reference is always FP32 on CPU using torch.nn.functional.conv2d, and the GPU result in the target dtype must match within the appropriate tolerance. Key configurations include ResNet’s 7x7 stem convolution with stride 2 (which tests large kernel + strided access), 3x3 residual blocks with padding (the most common case), and downsampling convolutions with stride 2 (which test non-unit stride address generation). The cuDNN and MIOpen libraries select different algorithms (Winograd, FFT, implicit GEMM) depending on input size and kernel shape, so each configuration may exercise a completely different code path on the GPU – validation must cover them all.

import torch
import torch.nn.functional as F

def validate_conv2d(batch, in_c, out_c, H, W, kernel, stride=1, padding=0,
                    dtype=torch.float16):
    """Validate GPU conv2d against CPU FP32 reference."""
    torch.manual_seed(42)
    x = torch.randn(batch, in_c, H, W, dtype=torch.float32)
    w = torch.randn(out_c, in_c, kernel, kernel, dtype=torch.float32)
    b = torch.randn(out_c, dtype=torch.float32)

    # CPU reference
    y_ref = F.conv2d(x, w, b, stride=stride, padding=padding)

    # GPU in target dtype
    y_gpu = F.conv2d(
        x.to('cuda', dtype), w.to('cuda', dtype), b.to('cuda', dtype),
        stride=stride, padding=padding
    ).float().cpu()

    tol = TOLERANCES[str(dtype).split('.')[-1]]
    return validate_close(y_gpu, y_ref, **tol,
                          name=f"Conv2d_{batch}x{in_c}x{H}x{W}_k{kernel}")

# Common architectures
test_cases = [
    # (batch, in_c, out_c, H, W, kernel, stride, padding)
    (1, 3, 64, 224, 224, 7, 2, 3),       # ResNet first conv
    (8, 64, 128, 56, 56, 3, 1, 1),       # ResNet block
    (8, 256, 512, 28, 28, 3, 2, 1),      # Downsampling conv
    (1, 3, 32, 640, 640, 3, 1, 1),       # YOLO input
    (8, 512, 512, 14, 14, 3, 1, 1),      # Deep feature maps
]

for case in test_cases:
    validate_conv2d(*case)
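A cheap sanity check before comparing values is to confirm the output shape each configuration should produce. The sketch below applies the usual output-size formula for a convolution with dilation 1, \(\lfloor (H + 2p - k) / s \rfloor + 1\), to the same test cases. The helper name `conv2d_output_size` is an assumption introduced for illustration.

```python
def conv2d_output_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv with dilation 1: floor((H + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

# (batch, in_c, out_c, H, W, kernel, stride, padding), as in test_cases above
cases = [
    (1, 3, 64, 224, 224, 7, 2, 3),    # ResNet first conv
    (8, 64, 128, 56, 56, 3, 1, 1),    # ResNet block
    (8, 256, 512, 28, 28, 3, 2, 1),   # Downsampling conv
    (1, 3, 32, 640, 640, 3, 1, 1),    # YOLO input
    (8, 512, 512, 14, 14, 3, 1, 1),   # Deep feature maps
]

expected_shapes = []
for batch, in_c, out_c, H, W, k, s, p in cases:
    shape = (batch, out_c, conv2d_output_size(H, k, s, p), conv2d_output_size(W, k, s, p))
    expected_shapes.append(shape)
    print(shape)
```

If the GPU result's shape differs from these, the failure is in stride or padding handling, not in numerics.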

Depthwise Convolution (Mobile Architectures)

Depthwise convolutions apply a separate filter per input channel (groups=channels), reducing computation from \(O(C_{in} \times C_{out} \times K^2 \times H \times W)\) to \(O(C \times K^2 \times H \times W)\). Paired with a 1x1 pointwise convolution they form the depthwise separable convolution, the foundation of MobileNet, EfficientNet, and other mobile-optimized architectures increasingly deployed on edge AI accelerators. Depthwise convolutions have very different memory access patterns than standard convolutions (low arithmetic intensity and high memory bandwidth demand), which means they exercise a different kernel implementation on the GPU. Validating them separately ensures the hardware vendor’s depthwise-specific kernel path is numerically correct.

def validate_depthwise_conv(batch, channels, H, W, kernel=3):
    """Validate depthwise conv (MobileNet, EfficientNet)."""
    torch.manual_seed(42)
    x = torch.randn(batch, channels, H, W, dtype=torch.float32)
    w = torch.randn(channels, 1, kernel, kernel, dtype=torch.float32)

    y_ref = F.conv2d(x, w, groups=channels, padding=kernel//2)
    y_gpu = F.conv2d(
        x.cuda().half(), w.cuda().half(), groups=channels, padding=kernel//2
    ).float().cpu()

    return validate_close(y_gpu, y_ref, atol=1e-3, rtol=1e-2,
                          name=f"DepthwiseConv_{channels}ch")
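The complexity reduction quoted above can be made concrete by counting multiply-accumulates. This is a rough sketch: `conv_macs` is a hypothetical helper, bias is ignored, and the layer dimensions are made up for illustration.

```python
def conv_macs(c_in, c_out, kernel, h_out, w_out, groups=1):
    """Multiply-accumulate count for a 2D convolution (bias ignored)."""
    return (c_in // groups) * c_out * kernel * kernel * h_out * w_out

# Hypothetical layer: 256 channels, 3x3 kernel, 56x56 output
C, K, H, W = 256, 3, 56, 56

standard = conv_macs(C, C, K, H, W)              # every output sees every input channel
depthwise = conv_macs(C, C, K, H, W, groups=C)   # each output sees one input channel

print(f"standard:  {standard:,} MACs")
print(f"depthwise: {depthwise:,} MACs")
print(f"ratio:     {standard // depthwise}x")
```

The work drops by a factor of C while the activations read and written stay the same size, which is why depthwise kernels tend to be memory-bound.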

2.5 Attention Kernel Validation

Scaled Dot-Product Attention

Scaled dot-product attention computes \(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\), and it is one of the most performance-critical kernels in transformer models. The reference implementation computes the full attention matrix in FP32, while the GPU version dispatches to optimized backends like FlashAttention (which tiles the computation to avoid materializing the \(O(s^2)\) attention matrix). Validation must cover both non-causal (BERT-style bidirectional) and causal (GPT-style autoregressive) masking, since the causal mask changes the softmax normalization for each row. The test matrix below covers configurations from BERT-base (12 heads, seq 512) through Llama 13B (40 heads, seq 4096), spanning a representative range of production transformer shapes.

import math

def validate_attention(batch, heads, seq_len, head_dim, dtype=torch.float16,
                       use_causal_mask=False):
    """Validate scaled dot-product attention against reference."""
    torch.manual_seed(42)
    Q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float32)
    K = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float32)
    V = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float32)

    # Reference implementation (CPU, FP32)
    scale = 1.0 / math.sqrt(head_dim)
    attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * scale

    if use_causal_mask:
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        attn_weights.masked_fill_(mask, float('-inf'))

    attn_weights = torch.softmax(attn_weights, dim=-1)
    y_ref = torch.matmul(attn_weights, V)

    # GPU implementation
    Q_gpu = Q.to('cuda', dtype)
    K_gpu = K.to('cuda', dtype)
    V_gpu = V.to('cuda', dtype)

    # Use PyTorch's SDPA (dispatches to FlashAttention/efficient kernels)
    y_gpu = torch.nn.functional.scaled_dot_product_attention(
        Q_gpu, K_gpu, V_gpu, is_causal=use_causal_mask
    ).float().cpu()

    tol = TOLERANCES[str(dtype).split('.')[-1]]
    causal_str = "_causal" if use_causal_mask else ""
    return validate_close(y_gpu, y_ref, **tol,
                          name=f"Attn_{batch}x{heads}x{seq_len}x{head_dim}{causal_str}")


# Test common transformer configurations
configs = [
    # (batch, heads, seq_len, head_dim, dtype, causal)
    (1, 12, 512, 64, torch.float16, False),     # BERT-base
    (1, 12, 512, 64, torch.float16, True),      # GPT-2
    (1, 32, 2048, 128, torch.float16, True),     # Llama 7B
    (1, 40, 4096, 128, torch.float16, True),     # Llama 13B
    (1, 32, 2048, 128, torch.bfloat16, True),    # BF16 variant
    (8, 32, 512, 128, torch.float16, True),      # Batched
]

for b, h, s, d, dt, causal in configs:
    validate_attention(b, h, s, d, dt, causal)
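The effect of the causal mask on per-row normalization is easiest to see on a tiny example. The sketch below is a plain-Python, single-head reimplementation on nested lists, written only for illustration (the function name and toy inputs are assumptions). Row \(i\) normalizes over positions \(0..i\) only, so the first output row must equal the first row of \(V\).

```python
import math

def causal_attention(Q, K, V):
    """Single-head scaled dot-product attention with a causal mask (nested lists)."""
    d_k = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        # Row i may only attend to positions 0..i; later positions are masked out.
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d_k)
                  for j in range(i + 1)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)                      # normalizer covers i + 1 entries only
        weights = [w / total for w in weights]
        out.append([sum(w * V[j][c] for j, w in enumerate(weights))
                    for c in range(len(V[0]))])
    return out

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

out = causal_attention(Q, K, V)
print(out[0])   # first row sees only position 0, so it equals V[0]
```

A useful targeted check on a GPU kernel follows from this: with `is_causal=True`, output row 0 should match `V[..., 0, :]` up to dtype rounding, regardless of Q and K.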

FlashAttention vs Standard Attention

PyTorch’s scaled_dot_product_attention dispatches to one of several backends; the three compared here are the math (naive) implementation, FlashAttention, and the memory-efficient attention kernel. Each backend uses a different algorithm with different numerical characteristics: FlashAttention performs online softmax with block-level rescaling, which introduces small rounding differences compared to the standard global softmax. This cross-comparison validates that the backends produce results within tolerance of each other, which is essential because PyTorch automatically selects the backend based on input shape, dtype, and available hardware. A mismatch beyond tolerance points to a kernel bug in one of the backends, which is a release blocker for any hardware vendor shipping optimized attention implementations.

from torch.nn.attention import sdpa_kernel, SDPBackend

def compare_attention_backends(batch, heads, seq_len, head_dim):
    """Compare different attention implementations."""
    torch.manual_seed(42)
    Q = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
    K = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
    V = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)

    # Math backend (naive)
    with sdpa_kernel(SDPBackend.MATH):
        y_math = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)

    # Flash Attention backend
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        y_flash = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)

    # Memory-efficient backend
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        y_efficient = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)

    # Cross-compare
    print("Flash vs Math:", torch.allclose(y_flash, y_math, atol=1e-3, rtol=1e-2))
    print("Efficient vs Math:", torch.allclose(y_efficient, y_math, atol=1e-3, rtol=1e-2))
    print("Flash vs Efficient:", torch.allclose(y_flash, y_efficient, atol=1e-3, rtol=1e-2))
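The "online softmax with block-level rescaling" mentioned above can be sketched in a few lines of plain Python. This is a simplified illustration of the normalizer only: FlashAttention itself fuses the same running rescale into the accumulation of the output with \(V\). The function names and the sample logits are assumptions.

```python
import math

def softmax_global(xs):
    """Standard softmax: one global max, one global sum."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_online(xs, block_size):
    """Softmax whose max and normalizer are built block by block.

    The running sum is rescaled whenever a new block raises the running max.
    """
    running_max = -math.inf
    running_sum = 0.0
    for start in range(0, len(xs), block_size):
        block = xs[start:start + block_size]
        new_max = max(running_max, max(block))
        # Re-express the old sum relative to the new max, then add this block.
        running_sum *= math.exp(running_max - new_max)
        running_sum += sum(math.exp(x - new_max) for x in block)
        running_max = new_max
    return [math.exp(x - running_max) / running_sum for x in xs]

logits = [0.5, -1.0, 3.0, 2.0, 7.0, -4.0, 1.5]
reference = softmax_global(logits)
blocked = softmax_online(logits, block_size=3)
max_diff = max(abs(a - b) for a, b in zip(reference, blocked))
print(f"max difference: {max_diff:.2e}")
```

The two results agree mathematically but not bit-for-bit, because the blocked version performs extra multiplications by rescale factors. In FP16 those extra roundings are what the cross-backend tolerance has to absorb.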

2.6 Softmax Validation

Softmax normalizes logits to a probability distribution: \(\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}\). GPU kernels use the numerically stable variant that subtracts the maximum value before exponentiation to prevent overflow, but edge cases can still cause failures. Validation verifies three properties: output sums to 1.0, all values lie in \([0, 1]\), and results match the CPU FP32 reference. The stability test with large input values (e.g., 1000.0 in FP16) is critical – a naive implementation would overflow at \(e^{1000}\), while the stable version computes \(e^{1000 - 1000} = e^0 = 1\). Softmax appears in every attention layer and every classification head, making it one of the highest-frequency kernels in production AI workloads.

def validate_softmax(shape, dim=-1, dtype=torch.float16):
    """Validate softmax kernel correctness."""
    torch.manual_seed(42)
    x = torch.randn(*shape, dtype=torch.float32)

    # CPU reference
    y_ref = torch.softmax(x, dim=dim)

    # GPU
    y_gpu = torch.softmax(x.to('cuda', dtype), dim=dim).float().cpu()

    tol = TOLERANCES[str(dtype).split('.')[-1]]
    return validate_close(y_gpu, y_ref, **tol,
                          name=f"Softmax_{shape}_dim{dim}")


# Standard cases
validate_softmax((1, 50257), dim=-1)                    # Vocab logits
validate_softmax((8, 32, 2048, 2048), dim=-1)           # Attention weights
validate_softmax((1, 1000), dim=-1)                      # Classification

# Dtype variant: BF16
validate_softmax((1, 50257), dim=-1, dtype=torch.bfloat16)

# Numerical stability: large values should not overflow
def softmax_stability_test():
    """Ensure softmax handles large values (numerical stability)."""
    x = torch.tensor([[1000.0, 1001.0, 1002.0]], device='cuda', dtype=torch.float16)
    y = torch.softmax(x, dim=-1)
    assert not torch.isnan(y).any(), "Softmax produced NaN with large inputs"
    assert torch.allclose(y.sum(dim=-1), torch.ones(1, device='cuda', dtype=torch.float16),
                          atol=1e-3), "Softmax doesn't sum to 1"
    print(f"PASS [Softmax stability]: {y}")

softmax_stability_test()
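Why the max-subtraction matters can be shown in double precision with the standard library alone. The sketch below contrasts a naive softmax with the stable variant on the same inputs as the stability test; both function names are illustrative.

```python
import math

def softmax_naive(xs):
    exps = [math.exp(x) for x in xs]        # math.exp(1000.0) raises OverflowError
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]    # largest term is exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]

logits = [1000.0, 1001.0, 1002.0]

try:
    softmax_naive(logits)
    naive_ok = True
except OverflowError:
    naive_ok = False

probs = softmax_stable(logits)
print(f"naive succeeded: {naive_ok}")
print("stable:", [round(p, 4) for p in probs])
```

Even float64 cannot represent \(e^{1000}\), so the naive form fails long before reduced-precision dtypes enter the picture. The stable form depends only on differences between logits, so it gives the same answer as softmax of `[0, 1, 2]`.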

2.7 LayerNorm / RMSNorm Validation

LayerNorm

Layer Normalization computes \(\text{LN}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta\), where \(\mu\) and \(\sigma^2\) are the mean and variance across the last dimension (hidden size). It appears in every transformer block (typically twice: before attention and before FFN), making it one of the most frequently called kernels. Validation compares the GPU kernel against a CPU FP32 reference across realistic hidden dimensions: 768 (BERT-base), 4096 (Llama 7B), and 5120 (Llama 13B). The key numerical concern is the variance computation – when input values are nearly identical, \(\sigma^2 \approx 0\), and the division can amplify floating-point error. The eps parameter (typically \(10^{-5}\) or \(10^{-6}\)) prevents division by zero, but its interaction with reduced-precision dtypes must be validated.

def validate_layernorm(batch, seq_len, hidden, dtype=torch.float16):
    """Validate LayerNorm kernel."""
    torch.manual_seed(42)
    x = torch.randn(batch, seq_len, hidden, dtype=torch.float32)
    weight = torch.randn(hidden, dtype=torch.float32)
    bias = torch.randn(hidden, dtype=torch.float32)

    # CPU reference
    ln = torch.nn.LayerNorm(hidden)
    ln.weight.data = weight
    ln.bias.data = bias
    y_ref = ln(x)

    # GPU
    ln_gpu = torch.nn.LayerNorm(hidden).cuda()
    ln_gpu.weight.data = weight.cuda().to(dtype)
    ln_gpu.bias.data = bias.cuda().to(dtype)
    y_gpu = ln_gpu(x.to('cuda', dtype)).float().cpu()

    tol = TOLERANCES[str(dtype).split('.')[-1]]
    return validate_close(y_gpu, y_ref, **tol,
                          name=f"LayerNorm_{batch}x{seq_len}x{hidden}")

validate_layernorm(1, 2048, 4096)   # Llama 7B
validate_layernorm(8, 512, 768)     # BERT-base
validate_layernorm(1, 4096, 5120)   # Llama 13B
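The near-zero-variance concern can be illustrated with a small plain-Python LayerNorm. This sketch assumes the biased (divide by N) variance and a default eps of 1e-5; the function name and toy inputs are made up for illustration.

```python
import math

def layernorm(xs, gamma, beta, eps=1e-5):
    """LayerNorm over a single vector, using the biased variance."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(x - mean) * inv_std * g + b for x, g, b in zip(xs, gamma, beta)]

hidden = 4
gamma, beta = [1.0] * hidden, [0.5] * hidden

# Constant input: variance is exactly 0, so only eps keeps the division finite
# and every output collapses to beta.
constant = layernorm([3.0] * hidden, gamma, beta)

# Nearly constant input: variance (~1.9e-7) is smaller than eps, so eps
# dominates the denominator and damps the normalized values.
nearly = layernorm([3.0, 3.0, 3.0, 3.001], gamma, beta)

print(constant)
print([round(v, 3) for v in nearly])
```

For validation this suggests adding constant and near-constant rows to the random test data, since that is the regime where a kernel's variance algorithm and eps handling actually change the output.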

RMSNorm (Used in Llama, Mistral)

RMSNorm simplifies LayerNorm by removing the mean centering step: \(\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}} \cdot \gamma\). This eliminates one reduction pass (the mean), which typically makes it somewhat faster at the same hidden dimension; the exact gain depends on the kernel and hardware. Modern LLMs including Llama 2/3, Mistral, and Gemma use RMSNorm instead of LayerNorm. For hardware validation, RMSNorm is critical because the fused kernel implementation on each platform may differ from the reference Python implementation. The validation below runs the same torch.mean(x**2) reference on the GPU in the target dtype and compares it against the CPU FP32 result; to validate a fused kernel, substitute it for the GPU-side call.

def rmsnorm_reference(x, weight, eps=1e-6):
    """Reference RMSNorm implementation."""
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return (x / rms) * weight

def validate_rmsnorm(batch, seq_len, hidden, dtype=torch.float16):
    """Validate RMSNorm kernel against reference."""
    torch.manual_seed(42)
    x = torch.randn(batch, seq_len, hidden, dtype=torch.float32)
    weight = torch.randn(hidden, dtype=torch.float32)

    # Reference
    y_ref = rmsnorm_reference(x, weight)

    # GPU
    y_gpu = rmsnorm_reference(
        x.to('cuda', dtype), weight.to('cuda', dtype)
    ).float().cpu()

    tol = TOLERANCES[str(dtype).split('.')[-1]]
    return validate_close(y_gpu, y_ref, **tol,
                          name=f"RMSNorm_{batch}x{seq_len}x{hidden}")

validate_rmsnorm(1, 2048, 4096)
validate_rmsnorm(8, 512, 4096)
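Two properties of the formula make useful extra checks alongside the elementwise comparison: with unit weights the output has root-mean-square close to 1, and the output is (nearly) unchanged when the input is rescaled. The sketch below is a plain-Python illustration with made-up values, not a replacement for the tensor test.

```python
import math

def rmsnorm(xs, weight, eps=1e-6):
    """RMSNorm over a single vector: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(x * x for x in xs) / len(xs)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [x * inv_rms * w for x, w in zip(xs, weight)]

def rms(xs):
    return math.sqrt(sum(x * x for x in xs) / len(xs))

x = [0.5, -2.0, 3.0, 1.5]
ones = [1.0] * len(x)

y = rmsnorm(x, ones)
y_scaled = rmsnorm([10.0 * v for v in x], ones)

print(f"rms(y) = {rms(y):.6f}")                       # close to 1.0
print(f"mean(y) = {sum(y) / len(y):.3f}")             # not zero: no mean centering
print(max(abs(a - b) for a, b in zip(y, y_scaled)))   # tiny: scale invariance up to eps
```

Unlike LayerNorm, the output mean is not forced to zero, so a kernel that accidentally centers the input fails this check even when its RMS is right.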

2.8 Kernel Performance Profiling

Roofline Analysis

The roofline model characterizes kernel performance as either compute-bound or memory-bound based on arithmetic intensity: the ratio of floating-point operations to bytes accessed. For GEMM, arithmetic intensity is \(\frac{2MNK}{(MK + KN + MN) \times \text{element\_size}}\). Large square matrices (4096x4096) have high arithmetic intensity (~1365 FLOP/byte in FP16) and are compute-bound, meaning performance is limited by tensor core throughput. Skinny matrices like (1, 4096, 4096) used in single-token LLM decode have low arithmetic intensity (~1 FLOP/byte in FP16) and are memory-bound, limited by HBM bandwidth. Understanding whether a kernel is compute-bound or memory-bound is essential for hardware validation because it determines which hardware specification (TFLOPS vs GB/s) sets the performance ceiling.

def roofline_analysis(M, N, K, dtype=torch.float16):
    """Compute arithmetic intensity and compare to roofline."""
    import time

    # Compute FLOPS: 2*M*N*K for GEMM
    flops = 2 * M * N * K

    # Compute bytes: read A(M*K) + B(K*N) + write C(M*N)
    elem_size = torch.finfo(dtype).bits // 8  # 2 for FP16/BF16, 4 for FP32
    bytes_accessed = (M * K + K * N + M * N) * elem_size

    # Arithmetic intensity (FLOP/byte)
    arith_intensity = flops / bytes_accessed

    # Benchmark
    A = torch.randn(M, K, device='cuda', dtype=dtype)
    B = torch.randn(K, N, device='cuda', dtype=dtype)

    # Warmup
    for _ in range(10):
        torch.matmul(A, B)
    torch.cuda.synchronize()

    # Measure
    start = time.perf_counter()
    iterations = 100
    for _ in range(iterations):
        torch.matmul(A, B)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    achieved_tflops = (flops * iterations) / elapsed / 1e12
    print(f"GEMM {M}x{N}x{K} ({dtype}):")
    print(f"  Arithmetic intensity: {arith_intensity:.1f} FLOP/byte")
    print(f"  Achieved: {achieved_tflops:.1f} TFLOPS")
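    # NOTE: 100 FLOP/byte is a rough stand-in for the ridge point. The real
    # value is peak FLOP/s divided by memory bandwidth of the device under test.
    print(f"  {'Compute-bound' if arith_intensity > 100 else 'Memory-bound'}")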

roofline_analysis(4096, 4096, 4096)
roofline_analysis(1, 4096, 4096)      # Inference: memory-bound
roofline_analysis(8192, 8192, 8192)   # Large: compute-bound
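The compute-bound versus memory-bound decision depends on the device's ridge point (peak FLOP/s divided by memory bandwidth), not on a fixed threshold. The sketch below computes the roofline ceiling for the two shapes discussed above. The peak and bandwidth figures are hypothetical placeholders, not the specifications of any real accelerator, and should be replaced with the datasheet values of the device under test.

```python
def gemm_arithmetic_intensity(M, N, K, elem_size=2):
    """FLOP per byte for GEMM, counting one read of A and B and one write of C."""
    flops = 2 * M * N * K
    bytes_accessed = (M * K + K * N + M * N) * elem_size
    return flops / bytes_accessed

def attainable_tflops(intensity, peak_tflops, bandwidth_tbps):
    """Roofline ceiling: min(compute peak, intensity * memory bandwidth)."""
    return min(peak_tflops, intensity * bandwidth_tbps)

# Hypothetical accelerator (assumed numbers, NOT a real product's datasheet)
PEAK_TFLOPS = 400.0       # dense FP16 peak, TFLOP/s
BANDWIDTH_TBPS = 2.0      # memory bandwidth, TB/s
ridge = PEAK_TFLOPS / BANDWIDTH_TBPS   # FLOP/byte where the two ceilings meet

results = {}
for shape in [(4096, 4096, 4096), (1, 4096, 4096)]:
    ai = gemm_arithmetic_intensity(*shape)
    ceiling = attainable_tflops(ai, PEAK_TFLOPS, BANDWIDTH_TBPS)
    bound = "compute-bound" if ai > ridge else "memory-bound"
    results[shape] = (ai, ceiling, bound)
    print(f"{shape}: AI={ai:.1f} FLOP/byte, ceiling={ceiling:.1f} TFLOPS, {bound}")
```

Comparing the measured TFLOPS from roofline_analysis against this ceiling, not against the headline peak, tells you how efficient a memory-bound kernel really is.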

2.9 Exercises

  1. GEMM Sweep: Validate GEMM correctness for all combinations of {FP32, FP16, BF16} × {128, 1024, 4096, 8192} matrix sizes. Report pass/fail and maximum absolute error.

  2. Attention Sequence Length Scaling: Test attention correctness for sequence lengths {128, 512, 1024, 2048, 4096, 8192}. At what point does numerical error grow?

  3. Softmax Edge Cases: Test softmax with inputs of {all zeros, all same value, one very large value, mix of very large + very small}. Does your GPU handle all correctly?

  4. Kernel Comparison: Compare the execution time and numerical accuracy of FlashAttention vs standard attention for sequence length 4096. Which is faster? Which is more accurate?

  5. Custom Kernel Test Harness: Write a reusable KernelValidator class that tests any kernel with configurable shapes, dtypes, and tolerances, and produces a test report.

Key Takeaways

  • Kernel validation is about numerical correctness, not just “does it run”

  • Always compare against a high-precision reference (FP32 on CPU)

  • Tolerances must be dtype-aware: the FP8 tolerances in Section 2.2 are 1,000x or more looser than the FP32 ones

  • Edge cases (zeros, large values, non-square matrices) catch real bugs

  • Performance profiling (roofline) tells you if the kernel is efficient
