Lab 06: Framework Validation

PyTorch Backend · ONNX Export & Runtime · torch.compile · Operator Coverage

Role alignment: AMD Principal Staff – AI/ML Performance Validation
Reference: 03_framework_validation.ipynb

What you will do:

  1. Validate operator coverage – which ops run on the GPU vs fall back to the CPU

  2. Export a model to ONNX and validate numerical parity

  3. Test torch.compile correctness and speedup

  4. Validate precision consistency across eager / compiled / ONNX

  5. Build a compatibility matrix across PyTorch versions

This is exactly what AMD does when validating PyTorch/ROCm or ONNX Runtime/ROCm releases.

Setup

The setup detects the compute device, PyTorch version, CUDA/ROCm version, and checks for optional dependencies (ONNX, ONNX Runtime) that enable additional validation exercises. Framework validation requires knowing the exact software stack because op coverage and numerical behavior can change between PyTorch minor versions – for example, torch.compile behavior differs significantly between PyTorch 2.1 and 2.4.
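The ROCm/CUDA distinction matters here: on ROCm builds of PyTorch, `torch.version.hip` carries a version string and `torch.version.cuda` is `None`, while the GPU still reports as the `cuda` device. A small helper (hypothetical, not part of the lab code) that classifies the build from those two fields:

```python
# Hypothetical helper: classify a PyTorch build from its version fields.
# Assumption: ROCm builds set torch.version.hip; CUDA builds set torch.version.cuda.
def gpu_backend(cuda_version, hip_version):
    if hip_version is not None:
        return f'ROCm {hip_version}'
    if cuda_version is not None:
        return f'CUDA {cuda_version}'
    return 'CPU-only build'

# With torch installed:
#   import torch
#   print(gpu_backend(torch.version.cuda, getattr(torch.version, 'hip', None)))
```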

import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import json
from pathlib import Path
from datetime import datetime

if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

print(f'Device     : {DEVICE}')
print(f'PyTorch    : {torch.__version__}')
if DEVICE == 'cuda':
    print(f'GPU        : {torch.cuda.get_device_name(0)}')
    print(f'CUDA       : {torch.version.cuda}')

# Check optional deps
try:
    import onnx
    import onnxruntime as ort
    ONNX_AVAILABLE = True
    print(f'ONNX       : {onnx.__version__}')
    print(f'ONNXRuntime: {ort.__version__}')
except ImportError:
    ONNX_AVAILABLE = False
    print('ONNX not installed -- pip install onnx onnxruntime')
    print('Some cells will be skipped')

Test Model

The ValidationModel is designed to exercise the full spectrum of operator types found in production AI models: Conv1d (convolution), MultiheadAttention (scaled dot-product attention), LayerNorm (normalization), Linear (GEMM), GELU (activation), and mean pooling (reduction). By combining these diverse operators in a single model, the validation framework can detect op coverage gaps, incorrect dispatch, and numerical inconsistencies in a single forward pass rather than requiring separate tests for each operator. The model is deliberately small to enable fast iteration during validation.

class ValidationModel(nn.Module):
    """
    Model that exercises diverse operators:
    conv, matmul, layernorm, softmax, gelu, attention
    """

    def __init__(self, hidden=256, num_heads=4):
        super().__init__()
        self.conv = nn.Conv1d(hidden, hidden, kernel_size=3, padding=1)
        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)
        self.attn  = nn.MultiheadAttention(hidden, num_heads, batch_first=True)
        self.fc1   = nn.Linear(hidden, hidden * 4)
        self.fc2   = nn.Linear(hidden * 4, hidden)
        self.out   = nn.Linear(hidden, 16)

    def forward(self, x):  # x: (B, T, H)
        # Conv (transpose for Conv1d: B, H, T)
        x = x + self.conv(x.transpose(1, 2)).transpose(1, 2)
        # Attention
        residual = x
        x = self.norm1(x)
        x, _ = self.attn(x, x, x, need_weights=False)
        x = residual + x
        # FFN
        residual = x
        x = self.norm2(x)
        x = self.fc2(F.gelu(self.fc1(x)))
        x = residual + x
        # Pooling + output
        return self.out(x.mean(dim=1))  # (B, 16)


model = ValidationModel(hidden=128 if DEVICE == 'cpu' else 256)
model = model.to(DEVICE).eval()
print(f'Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M')

B, T, H = 2, 64, model.conv.in_channels
dummy_input = torch.randn(B, T, H, device=DEVICE)
with torch.no_grad():
    out = model(dummy_input)
print(f'Output shape: {out.shape}  (expected: ({B}, 16))')

Exercise 6.1 – Operator Coverage Check

Validate that every operator in the model runs on the GPU – no silent CPU fallbacks.

Why this matters for AMD: ROCm doesn’t support 100% of CUDA ops. Missing ops fall back to CPU and destroy performance.

def check_operator_coverage(model, sample_input):
    """
    Run model and intercept all ops.
    Check none of them involve CPU ↔ GPU transfers (which indicate fallback).
    """
    ops_executed = []
    data_transfers = []

    def make_hook(name):
        def hook(module, inp, out):
            # Check that inputs are on the expected device type
            for i, tensor in enumerate(inp):
                if isinstance(tensor, torch.Tensor) and tensor.device.type != DEVICE:
                    data_transfers.append(f'{name}: input {i} on {tensor.device} (expected {DEVICE})')
            ops_executed.append(name)
        return hook

    hooks = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            hooks.append(module.register_forward_hook(make_hook(name or type(module).__name__)))

    with torch.no_grad():
        _ = model(sample_input)

    for h in hooks:
        h.remove()

    print(f'Operator Coverage Report:')
    print(f'  Total ops executed : {len(ops_executed)}')
    print(f'  Device             : {DEVICE}')
    print(f'  CPU fallbacks      : {len(data_transfers)}')

    if data_transfers:
        print('  WARN: CPU fallbacks detected:')
        for t in data_transfers[:10]:
            print(f'    {t}')
        status = 'FAIL'
    else:
        print('  PASS: All operators run on', DEVICE)
        status = 'PASS'

    return status, ops_executed


status, ops = check_operator_coverage(model, dummy_input)
print(f'\nOps: {ops}')
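Module hooks only see tensors at module boundaries, so fallbacks inside a fused op can slip past them. A complementary check (a sketch, assuming `torch.profiler` is available; the `nn.Linear` below is a toy stand-in) is to profile one forward pass and flag any host-device copy ops, which usually betray a fallback on the hot path:

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

def find_device_copies(model, x):
    """Profile one forward pass and list ops whose names suggest
    host<->device copies (e.g. 'Memcpy' kernels, aten::copy_)."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with torch.no_grad(), profile(activities=activities) as prof:
        model(x)
    suspects = ('Memcpy', 'aten::copy_')
    return [evt.key for evt in prof.key_averages()
            if any(s in evt.key for s in suspects)]

# Usage: an all-CPU model on CPU input should report no copies
m = nn.Linear(8, 8).eval()
print(find_device_copies(m, torch.randn(2, 8)))
```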

Exercise 6.2 – ONNX Export & Numerical Parity

This exercise validates the complete ONNX pipeline: export (PyTorch to ONNX graph), structural validation (ONNX checker verifies graph integrity), and numerical parity (ONNX Runtime inference matches PyTorch output). The export uses opset 17 with dynamic axes for batch and sequence dimensions, matching production deployment patterns. If a CUDA or ROCm execution provider is available, the test additionally compares GPU-accelerated ONNX Runtime output against the CPU reference, verifying that the hardware vendor’s ONNX Runtime execution provider produces correct results. A max absolute error exceeding 1e-4 (CPU EP) or 1e-3 (GPU EP) indicates a bug in the export or execution provider.
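Those thresholds compare maximum absolute error; for outputs with large dynamic range it is often worth tracking relative error alongside it. A small illustrative helper (not part of the lab code; the 1e-4 default is the lab's CPU-EP tolerance):

```python
import numpy as np

def parity_report(ref, test, atol=1e-4):
    """Summarize numerical parity between two output arrays."""
    ref64, test64 = ref.astype(np.float64), test.astype(np.float64)
    diff = np.abs(ref64 - test64)
    max_abs = float(diff.max())
    # Guard the denominator so exact zeros don't blow up the relative error
    max_rel = float((diff / np.maximum(np.abs(ref64), 1e-12)).max())
    return {'max_abs_err': max_abs, 'max_rel_err': max_rel,
            'pass': max_abs < atol}

out = parity_report(np.array([1.0, 2.0, 3.0]),
                    np.array([1.0, 2.0, 3.0]) + 5e-5)
print(out)  # within the 1e-4 CPU-EP tolerance, so 'pass' is True
```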

if not ONNX_AVAILABLE:
    print('Skipping ONNX exercises -- install: pip install onnx onnxruntime')
else:
    import onnx
    import onnxruntime as ort
    import numpy as np

    ONNX_PATH = Path('lab06_model.onnx')

    # ---- Export ----
    cpu_model = model.cpu().eval()
    cpu_input = dummy_input.cpu()

    torch.onnx.export(
        cpu_model,
        cpu_input,
        ONNX_PATH,
        opset_version=17,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch', 1: 'seq'}, 'output': {0: 'batch'}},
        do_constant_folding=True,
    )
    print(f'ONNX model exported: {ONNX_PATH} ({ONNX_PATH.stat().st_size / 1024:.1f} KB)')

    # ---- Validate ONNX model ----
    onnx_model = onnx.load(str(ONNX_PATH))
    try:
        onnx.checker.check_model(onnx_model)
        print('ONNX graph check  : PASS')
    except onnx.checker.ValidationError as e:
        print(f'ONNX graph check  : FAIL β€” {e}')

    # ---- Numerical parity ----
    # PyTorch output
    with torch.no_grad():
        pt_out = cpu_model(cpu_input).numpy()

    # ONNX Runtime output
    sess = ort.InferenceSession(str(ONNX_PATH), providers=['CPUExecutionProvider'])
    ort_out = sess.run(None, {'input': cpu_input.numpy()})[0]

    max_err = np.abs(pt_out - ort_out).max()
    passed  = max_err < 1e-4
    print(f'ONNX parity check  : {"PASS" if passed else "FAIL"}  max_abs_err={max_err:.2e}')

    # ---- GPU Execution Provider (ROCm builds of PyTorch also report 'cuda') ----
    if DEVICE == 'cuda':
        available_providers = ort.get_available_providers()
        print(f'\nAvailable ORT providers: {available_providers}')

        if 'CUDAExecutionProvider' in available_providers:
            sess_gpu = ort.InferenceSession(
                str(ONNX_PATH),
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            ort_gpu_out = sess_gpu.run(None, {'input': cpu_input.numpy()})[0]
            gpu_err = np.abs(pt_out - ort_gpu_out).max()
            print(f'ONNX CUDA EP parity: {"PASS" if gpu_err < 1e-3 else "FAIL"}  max_err={gpu_err:.2e}')

        elif 'ROCMExecutionProvider' in available_providers:
            sess_gpu = ort.InferenceSession(
                str(ONNX_PATH),
                providers=['ROCMExecutionProvider', 'CPUExecutionProvider']
            )
            ort_gpu_out = sess_gpu.run(None, {'input': cpu_input.numpy()})[0]
            gpu_err = np.abs(pt_out - ort_gpu_out).max()
            print(f'ONNX ROCm EP parity: {"PASS" if gpu_err < 1e-3 else "FAIL"}  max_err={gpu_err:.2e}')

    # Move model back to device
    model = model.to(DEVICE)

Exercise 6.3 – torch.compile Validation

torch.compile (TorchDynamo + Inductor) can often speed models up by 2–4x.
AMD validates it via torch.compile(backend='inductor') or ROCm-specific backends.

Validation criteria:

  1. Compiled output matches eager output

  2. Compiled is faster than eager

  3. No silent correctness degradation
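Criterion 2's thresholds, as used by the printout in the code below, can be captured in a tiny helper (a sketch; the 1.2x/1.0x cutoffs are this lab's conventions, not universal pass bars):

```python
def speedup_verdict(eager_ms, compiled_ms):
    """Classify a compile speedup: >=1.2x 'Good', >=1.0x 'Marginal',
    below 1.0x a regression that should fail validation."""
    speedup = eager_ms / compiled_ms
    if speedup >= 1.2:
        return speedup, 'Good'
    if speedup >= 1.0:
        return speedup, 'Marginal'
    return speedup, 'REGRESSION'

print(speedup_verdict(2.4, 1.5))   # ~1.6x -> 'Good'
```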

def validate_torch_compile(model, sample_input, dtype=torch.float32):
    """
    Validate torch.compile correctness and speedup.
    """
    model.eval()
    x = sample_input.to(dtype=dtype)

    # Eager reference
    with torch.no_grad():
        eager_out = model(x)

    # Compile
    try:
        compiled_model = torch.compile(model, mode='default', fullgraph=False)

        # Warmup (compilation happens here)
        print('Compiling model (first call triggers compilation)...')
        t_compile_start = time.perf_counter()
        with torch.no_grad():
            for _ in range(3):
                compiled_out = compiled_model(x)
        if DEVICE == 'cuda':
            torch.cuda.synchronize()
        compile_time = time.perf_counter() - t_compile_start

        print(f'Compilation + 3 warmup iters: {compile_time:.2f}s')

        # Correctness check
        max_err = (compiled_out.float() - eager_out.float()).abs().max().item()
        passed  = max_err < 1e-3
        print(f'Correctness: {"PASS" if passed else "FAIL"}  max_err={max_err:.2e}')

        # Speed comparison
        iters = 50 if DEVICE != 'cpu' else 10

        def bench(fn, n=iters):
            with torch.no_grad():
                for _ in range(5): fn(x)
            if DEVICE == 'cuda': torch.cuda.synchronize()
            t0 = time.perf_counter()
            with torch.no_grad():
                for _ in range(n): fn(x)
            if DEVICE == 'cuda': torch.cuda.synchronize()
            return (time.perf_counter() - t0) / n * 1000  # ms

        eager_ms    = bench(model)
        compiled_ms = bench(compiled_model)
        speedup     = eager_ms / compiled_ms

        print(f'Eager    : {eager_ms:.3f} ms')
        print(f'Compiled : {compiled_ms:.3f} ms')
        print(f'Speedup  : {speedup:.2f}x  ({"Good" if speedup >= 1.2 else "Marginal" if speedup >= 1.0 else "REGRESSION"})')

        return {
            'correctness': 'PASS' if passed else 'FAIL',
            'max_err': round(max_err, 8),
            'eager_ms': round(eager_ms, 3),
            'compiled_ms': round(compiled_ms, 3),
            'speedup': round(speedup, 2),
        }

    except Exception as e:
        print(f'torch.compile failed: {e}')
        print('This may happen on older PyTorch or unsupported backends')
        return {'correctness': 'SKIP', 'error': str(e)}


model = model.to(DEVICE)
compile_result = validate_torch_compile(model, dummy_input)

Exercise 6.4 – Execution Mode Consistency Matrix

AMD validation ensures: Eager = Compiled = ONNX (within tolerance).
Any mode that produces different results is a bug.
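What "within tolerance" means differs per mode; the budgets used in the code below can be expressed as a lookup (illustrative values taken from this lab; real budgets are model- and workload-dependent):

```python
def mode_tolerance(mode):
    """Max-abs-error budget per execution mode (this lab's values)."""
    if 'FP16' in mode:
        return 1e-1   # half precision: large rounding error is expected
    if 'FP32' in mode:
        return 1e-4   # same-precision re-execution should agree tightly
    return 1e-3       # compiled / ONNX: fusion reorders floating-point math

for mode in ('Eager FP32', 'Eager FP16', 'torch.compile', 'ONNX Runtime'):
    print(f'{mode:15s} tol={mode_tolerance(mode):.0e}')
```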

def execution_mode_consistency(model, sample_input):
    """
    Compare: Eager FP32 vs Eager FP16 vs Compiled vs ONNX
    Reference: Eager FP32 (highest precision)
    """
    cpu_model = model.cpu().eval()
    x_cpu_f32 = sample_input.cpu().float()

    results = {}

    # 1. Eager FP32 (reference)
    with torch.no_grad():
        ref = cpu_model(x_cpu_f32)
    results['Eager FP32'] = {'output': ref.clone(), 'err': 0.0}
    print(f'[REF ] Eager FP32     : output shape {ref.shape}')

    # 2. Eager FP16 (deep-copied so the CPU FP32 reference model is untouched;
    # calling model.half() in place would corrupt the reference used below)
    if DEVICE == 'cuda':
        import copy
        gpu_model_f16 = copy.deepcopy(cpu_model).half().to(DEVICE).eval()
        x_f16 = x_cpu_f32.half().to(DEVICE)
        with torch.no_grad():
            out_f16 = gpu_model_f16(x_f16).float().cpu()
        err = (out_f16 - ref).abs().max().item()
        results['Eager FP16'] = {'output': out_f16, 'err': err}
        print(f'[{"PASS" if err < 0.1 else "FAIL"}] Eager FP16     : max_err={err:.2e}')

    # 3. torch.compile
    try:
        compiled = torch.compile(cpu_model.eval(), mode='default')
        with torch.no_grad():
            for _ in range(3): out_c = compiled(x_cpu_f32)  # warmup
        err = (out_c - ref).abs().max().item()
        results['torch.compile'] = {'output': out_c, 'err': err}
        print(f'[{"PASS" if err < 1e-3 else "FAIL"}] torch.compile  : max_err={err:.2e}')
    except Exception as e:
        print(f'[SKIP] torch.compile  : {e}')

    # 4. ONNX Runtime (requires the model exported in Exercise 6.2;
    # ONNX_PATH is only defined when that cell ran with onnx installed)
    if ONNX_AVAILABLE and 'ONNX_PATH' in globals() and ONNX_PATH.exists():
        sess = ort.InferenceSession(str(ONNX_PATH), providers=['CPUExecutionProvider'])
        ort_out_np = sess.run(None, {'input': x_cpu_f32.numpy()})[0]
        ort_out = torch.from_numpy(ort_out_np)
        err = (ort_out - ref).abs().max().item()
        results['ONNX Runtime'] = {'output': ort_out, 'err': err}
        print(f'[{"PASS" if err < 1e-3 else "FAIL"}] ONNX Runtime   : max_err={err:.2e}')
    else:
        print('[SKIP] ONNX Runtime   : onnx not installed or model not exported')

    print('\nConsistency summary:')
    print(f"  {'Mode':20s}  {'Max Error vs FP32':20s}  Pass")
    print('  ' + '-' * 50)
    for mode, r in results.items():
        tol  = 1e-4 if 'FP32' in mode else (0.1 if 'FP16' in mode else 1e-3)
        ok   = r['err'] <= tol
        print(f'  {mode:20s}  {r["err"]:20.2e}  {"YES" if ok else "NO"}')

    return results


model = model.float().to(DEVICE)
consistency_results = execution_mode_consistency(model, dummy_input)

Exercise 6.5 – Opset Compatibility (ONNX)

AMD validates ONNX model exports across multiple opset versions.
Some operators (e.g. LayerNormalization as a dedicated ONNX op) require opset 17+; older opsets force the exporter to decompose them into primitives.
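One reason exports differ across opsets: some operators only gained dedicated ONNX ops in later opsets and are decomposed into primitives before that. A non-exhaustive sketch (the two entries below reflect the ONNX operator sets as best I know them; verify against the ONNX operator changelog for your version):

```python
# Illustrative, non-exhaustive map of ops to the opset that introduced them
# as dedicated ONNX operators (earlier opsets export a decomposition instead).
MIN_OPSET = {
    'LayerNormalization': 17,  # before 17: ReduceMean/Sub/Div/Mul/Add chain
    'Gelu': 20,                # before 20: Erf-based decomposition
}

def exported_as_decomposition(op_name, opset):
    """True if the exporter must decompose this op at the given opset."""
    return opset < MIN_OPSET.get(op_name, 1)

print(exported_as_decomposition('LayerNormalization', 13))  # True
print(exported_as_decomposition('LayerNormalization', 17))  # False
```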

if ONNX_AVAILABLE:
    opset_results = []
    test_opsets = [11, 13, 15, 17, 18]

    print('ONNX Opset Compatibility Matrix:')
    print(f"{'Opset':>8}  {'Export':>8}  {'Validate':>10}  {'Parity':>8}  Notes")
    print('-' * 55)

    cpu_model = model.cpu().eval()
    x_np = dummy_input.cpu().numpy()

    with torch.no_grad():
        ref_out = cpu_model(dummy_input.cpu()).numpy()

    for opset in test_opsets:
        path = Path(f'lab06_opset{opset}.onnx')
        try:
            torch.onnx.export(
                cpu_model, dummy_input.cpu(), path,
                opset_version=opset,
                input_names=['input'], output_names=['output'],
            )
            export_ok = True
        except Exception as e:
            print(f'{opset:>8}  {"FAIL":>8}  {"N/A":>10}  {"N/A":>8}  {str(e)[:50]}')
            opset_results.append({'opset': opset, 'status': 'EXPORT_FAIL'})
            continue

        try:
            onnx.checker.check_model(onnx.load(str(path)))
            validate_ok = True
        except Exception:
            validate_ok = False

        try:
            import numpy as np
            sess = ort.InferenceSession(str(path), providers=['CPUExecutionProvider'])
            out  = sess.run(None, {'input': x_np})[0]
            err  = np.abs(out - ref_out).max()
            parity_ok  = err < 1e-4
            parity_str = f'{err:.1e}'
        except Exception:
            parity_ok  = False
            parity_str = 'ERR'

        status = 'PASS' if (export_ok and validate_ok and parity_ok) else 'WARN'
        print(f'{opset:>8}  {"OK" if export_ok else "FAIL":>8}  '
              f'{"OK" if validate_ok else "FAIL":>10}  {parity_str:>8}  [{status}]')
        opset_results.append({'opset': opset, 'status': status, 'parity_err': parity_str})

        # Clean up
        path.unlink(missing_ok=True)

    model = model.to(DEVICE)
else:
    print('Skipping ONNX opset test -- onnx not installed')

Exercise 6.6 – Framework Validation Report

The framework validation report consolidates all test results – operator coverage, torch.compile correctness and speedup, execution mode consistency, and ONNX opset compatibility – into a structured JSON document. This report format is designed for automated CI/CD consumption: a post-processing script can parse the JSON, check all status fields, and fail the build pipeline if any test shows a FAIL status. The report includes the full software environment (PyTorch version, ONNX availability, device type) for reproducibility and cross-version comparison.
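Such a post-processing gate can be sketched as follows (a hypothetical CI helper, not part of the lab: it walks the report's `tests` subtree and collects every `FAIL` status or `passed: false` flag):

```python
def failing_tests(report):
    """Collect dotted paths of failing entries in a framework-validation
    report; an empty list means the build may proceed."""
    failures = []

    def walk(node, path):
        if isinstance(node, dict):
            for key, value in node.items():
                walk(value, f'{path}.{key}' if path else key)
        elif isinstance(node, list):
            for i, item in enumerate(node):
                walk(item, f'{path}[{i}]')
        elif node == 'FAIL' or node is False:
            failures.append(path)

    walk(report.get('tests', {}), '')
    return failures

sample = {'tests': {'operator_coverage': 'PASS',
                    'torch_compile': {'correctness': 'FAIL'},
                    'consistency': {'Eager FP16': {'passed': False}}}}
print(failing_tests(sample))
```

In CI the script would `sys.exit(1)` when the returned list is non-empty, failing the pipeline.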

report = {
    'report_type': 'Framework Validation',
    'generated_at': datetime.now().isoformat(),
    'device': DEVICE,
    'pytorch_version': torch.__version__,
    'onnx_available': ONNX_AVAILABLE,
    'tests': {
        'operator_coverage': status,
        'torch_compile': compile_result,
        'execution_mode_consistency': {
            mode: {'max_err': r['err'],
                   'passed': r['err'] <= (0.1 if 'FP16' in mode
                                          else 1e-4 if 'FP32' in mode else 1e-3)}
            for mode, r in consistency_results.items()
        },
    }
}

if ONNX_AVAILABLE:
    report['tests']['onnx_opset_matrix'] = opset_results if 'opset_results' in globals() else 'not run'

path = Path('lab06_framework_report.json')
path.write_text(json.dumps(report, indent=2))
print(json.dumps(report, indent=2))
print(f'\nReport saved: {path}')

Summary

| Topic | AMD Validation Relevance |
| --- | --- |
| Operator coverage | ROCm has gaps vs CUDA – every fallback is a performance bug |
| ONNX export + parity | AMD validates the ONNX Runtime ROCm/CUDA EPs |
| torch.compile | AMD validates Inductor + ROCm backend correctness |
| Execution mode consistency | Eager = Compiled = ONNX is a release blocker |
| Opset matrix | Newer opsets needed for modern models (LLMs need >=17) |

Complete Lab Index

| Lab | Focus | AMD JD Mapping |
| --- | --- | --- |
| lab_01 | Power · Thermals · Memory · Stability | ROCm tools, hardware soak |
| lab_02 | GEMM · Attention · Tolerances | Kernel correctness, roofline |
| lab_03 | Profiling · Precision · LLM Phases | Model-level validation |
| lab_04 | Baselines · Version Matrix · CI/CD | System test plans, compatibility suites |
| lab_05 | NCCL/RCCL · AllReduce · Scaling | Datacenter/cluster validation |
| lab_06 | PyTorch · ONNX · torch.compile | Framework integration testing |
| lab_07 | GPGPU Backends · OpenCL · Vulkan · SYCL | Cross-backend validation |
| lab_08 | Benchmarking · TTFT · SLOs · Eval | Industry benchmarks, capacity planning |

Previous: lab_05_distributed_training.ipynb
Next: lab_07_gpgpu_backends.ipynb – GPGPU backends: CoreML, DirectML, Vulkan
Back to Overview: README.md