Section 3: Framework Validation

PyTorch, TensorFlow & ONNX Runtime on Hardware Backends

Duration: 5 hours
Difficulty: Intermediate

3.1 Why Framework Validation Matters

ML frameworks are the interface between models and hardware. Each hardware vendor must ensure their backend produces results that match the reference implementation within dtype-appropriate tolerances:

User's PyTorch Code
        ↓
PyTorch Frontend (ATen ops)
        ↓
Backend Dispatch (CUDA / ROCm / XLA / Neuron / QNN)
        ↓
Hardware-Specific Kernels
        ↓
GPU / NPU / TPU silicon

Framework validation verifies that the backend dispatch and kernel mapping produce correct results for every supported operation.
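
The dispatch layer in the diagram above can be pictured as an op-to-kernel registry keyed by backend. The sketch below is a toy model of that idea in plain Python; the registry, decorator, and backend names are all illustrative, not the real PyTorch dispatcher API.

```python
# Toy model of backend dispatch: each backend registers a kernel per op
# name, and the dispatcher routes calls by (op, backend). Illustrative only.
KERNEL_REGISTRY = {}

def register_kernel(op, backend):
    """Decorator registering a kernel implementation for (op, backend)."""
    def wrap(fn):
        KERNEL_REGISTRY[(op, backend)] = fn
        return fn
    return wrap

@register_kernel("add", "cpu")
def add_cpu(a, b):
    return [x + y for x, y in zip(a, b)]

@register_kernel("add", "npu")
def add_npu(a, b):
    # A vendor kernel could silently diverge from the reference;
    # framework validation exists to catch exactly that.
    return [x + y for x, y in zip(a, b)]

def dispatch(op, backend, *args):
    try:
        kernel = KERNEL_REGISTRY[(op, backend)]
    except KeyError:
        raise NotImplementedError(f"{op} has no kernel for backend {backend}")
    return kernel(*args)

ref = dispatch("add", "cpu", [1, 2], [3, 4])
dev = dispatch("add", "npu", [1, 2], [3, 4])
assert ref == dev == [4, 6]
```

Validation then reduces to running every registered op on both backends and comparing results, which is what the op-level harness in 3.2 does for real kernels.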

3.2 PyTorch Backend Validation

Op-Level Validation

PyTorch defines ~2000 operators (torch.* and torch.nn.functional.*). Hardware vendors must support the critical subset:

import torch

# Critical ops for AI workloads (priority order)
CRITICAL_OPS = [
    # Linear algebra
    ("matmul", lambda: torch.matmul(torch.randn(256, 256), torch.randn(256, 256))),
    ("bmm", lambda: torch.bmm(torch.randn(8, 256, 256), torch.randn(8, 256, 256))),
    ("addmm", lambda: torch.addmm(torch.randn(256), torch.randn(256, 256), torch.randn(256, 256))),

    # Activations
    ("relu", lambda: torch.relu(torch.randn(1024))),
    ("gelu", lambda: torch.nn.functional.gelu(torch.randn(1024))),
    ("silu", lambda: torch.nn.functional.silu(torch.randn(1024))),  # SwiGLU in Llama

    # Normalization
    ("layernorm", lambda: torch.nn.functional.layer_norm(torch.randn(8, 512, 768), [768])),
    ("softmax", lambda: torch.softmax(torch.randn(8, 32, 512, 512), dim=-1)),  # attention-shaped; kept small enough for a CPU reference

    # Reduction
    ("sum", lambda: torch.randn(1024, 1024).sum()),
    ("mean", lambda: torch.randn(1024, 1024).mean()),
    ("max", lambda: torch.randn(1024, 1024).max()),

    # Elementwise
    ("add", lambda: torch.randn(1024) + torch.randn(1024)),
    ("mul", lambda: torch.randn(1024) * torch.randn(1024)),
    ("exp", lambda: torch.exp(torch.randn(1024))),
    ("log", lambda: torch.log(torch.abs(torch.randn(1024)) + 1e-6)),

    # Embedding / Indexing
    ("embedding", lambda: torch.nn.functional.embedding(
        torch.randint(0, 50257, (8, 512)), torch.randn(50257, 768))),
    ("index_select", lambda: torch.index_select(
        torch.randn(1000, 768), 0, torch.randint(0, 1000, (64,)))),

    # Convolution
    ("conv2d", lambda: torch.nn.functional.conv2d(
        torch.randn(1, 3, 224, 224), torch.randn(64, 3, 7, 7), stride=2, padding=3)),
]


def _move_to_device(tensor, device, dtype):
    """Move a tensor to the target device and dtype."""
    if tensor.is_floating_point():
        return tensor.to(device=device, dtype=dtype)
    else:
        # Integer tensors (e.g. indices for embedding) keep their dtype
        return tensor.to(device=device)


def _run_op_on_device(op_fn, device, dtype):
    """Re-execute an op, drawing random inputs on CPU so their values match
    the CPU reference, then moving them to the target device and dtype."""
    original_randn = torch.randn
    original_randint = torch.randint

    def randn_on_device(*args, **kwargs):
        # Draw on CPU first: CPU and accelerator RNGs yield different
        # sequences for the same seed, so device-side draws would not
        # match the CPU reference values.
        return _move_to_device(original_randn(*args, **kwargs), device, dtype)

    def randint_on_device(*args, **kwargs):
        return _move_to_device(original_randint(*args, **kwargs), device, dtype)

    torch.randn = randn_on_device
    torch.randint = randint_on_device
    try:
        return op_fn()
    finally:
        torch.randn = original_randn
        torch.randint = original_randint


def validate_ops_on_device(device='cuda', dtype=torch.float16, atol=1e-2, rtol=1e-3):
    """Run every critical op on CPU (reference) and device, compare numerically."""
    results = []
    for name, op_fn in CRITICAL_OPS:
        try:
            # CPU reference (float32)
            torch.manual_seed(42)
            ref = op_fn()

            # Device under test
            torch.manual_seed(42)
            dev_result = _run_op_on_device(op_fn, device, dtype)

            # Verify the result actually lives on the target device
            if isinstance(dev_result, torch.Tensor):
                assert dev_result.device.type == torch.device(device).type, (
                    f"Result on {dev_result.device}, expected {device}"
                )
                # Compare numerically: bring both to float32 CPU
                dev_cpu = dev_result.float().cpu()
                ref_cpu = ref.float().cpu() if ref.is_floating_point() else ref.cpu()
                max_diff = (dev_cpu - ref_cpu).abs().max().item()
                passed = torch.allclose(dev_cpu, ref_cpu, atol=atol, rtol=rtol)
                result = {
                    "op": name,
                    "status": "PASS" if passed else "FAIL",
                    "max_diff": max_diff,
                }
                if not passed:
                    result["detail"] = f"max_diff {max_diff:.6f} exceeds atol={atol}, rtol={rtol}"
            elif isinstance(dev_result, tuple):
                # Ops like max() return (values, indices)
                dev_vals = dev_result[0].float().cpu()
                ref_vals = ref[0].float().cpu() if isinstance(ref, tuple) else ref.float().cpu()
                max_diff = (dev_vals - ref_vals).abs().max().item()
                passed = torch.allclose(dev_vals, ref_vals, atol=atol, rtol=rtol)
                result = {
                    "op": name,
                    "status": "PASS" if passed else "FAIL",
                    "max_diff": max_diff,
                }
            else:
                result = {"op": name, "status": "FAIL", "detail": "Unexpected return type"}

        except Exception as e:
            result = {"op": name, "status": "FAIL", "error": str(e)}

        results.append(result)
        status = result["status"]
        diff_str = f" max_diff={result['max_diff']:.6f}" if 'max_diff' in result else ""
        err_str = f" error={result.get('error', result.get('detail', ''))}" if status == "FAIL" else ""
        print(f"{status} [{name}]{diff_str}{err_str}")

    passed = sum(1 for r in results if r["status"] == "PASS")
    print(f"\n{passed}/{len(results)} ops passed on {device} ({dtype})")
    return results

PyTorch Model-Level Smoke Test

While op-level testing validates individual kernels, model-level smoke tests verify that the entire forward pass of a real model produces correct output on the target device. A model-level test catches integration issues that op-level tests miss: incorrect dispatch for fused operators, memory layout mismatches between layers, or unexpected dtype promotions in the compute graph. The tolerance is intentionally looser (0.1, versus the 1e-2 atol used for individual ops) because errors compound across dozens of layers. The test suite covers ResNet-50 (convolution-heavy), ViT-B/16 (attention-heavy), and EfficientNet-B0 (depthwise convolution + squeeze-excite), ensuring coverage across the major architectural patterns used in production CV models.

import torch
from torchvision import models

def validate_pytorch_model(model_name, input_shape, device='cuda',
                           dtype=torch.float16):
    """Validate a torchvision model produces consistent output."""
    model_fn = getattr(models, model_name)
    model = model_fn(weights=None).eval()

    torch.manual_seed(42)
    x = torch.randn(*input_shape)

    # CPU reference
    with torch.no_grad():
        y_ref = model(x)

    # Device under test
    model_dev = model.to(device, dtype)
    x_dev = x.to(device, dtype)
    with torch.no_grad():
        y_dev = model_dev(x_dev).float().cpu()

    max_diff = (y_dev - y_ref).abs().max().item()
    passed = max_diff < 0.1  # Model-level tolerance is looser
    print(f"{'PASS' if passed else 'FAIL'} [{model_name}] max_diff={max_diff:.6f}")
    return passed

# Smoke tests
validate_pytorch_model("resnet50", (1, 3, 224, 224))
validate_pytorch_model("vit_b_16", (1, 3, 224, 224))
validate_pytorch_model("efficientnet_b0", (1, 3, 224, 224))

PyTorch Autograd Validation (Training)

Inference (forward pass) correctness does not guarantee training correctness – the backward pass computes gradients via automatic differentiation, which exercises a different set of kernel implementations. Each forward op has a corresponding backward kernel (e.g., matmul_backward, gelu_backward, layernorm_backward), and hardware vendors must validate that gradients computed on the device match the CPU FP32 reference. The gradient tolerance is typically looser than forward tolerance because errors compound through the chain rule: each layer’s gradient depends on the product of all downstream Jacobians. A gradient error exceeding 0.05 absolute indicates a backward kernel bug that will cause training to diverge or converge to a suboptimal solution on the target hardware.

def validate_backward_pass(device='cuda', dtype=torch.float16):
    """Validate that gradients are computed correctly on device."""
    torch.manual_seed(42)

    # Simple model
    model = torch.nn.Sequential(
        torch.nn.Linear(768, 3072),
        torch.nn.GELU(),
        torch.nn.Linear(3072, 768),
        torch.nn.LayerNorm(768),
    )

    x = torch.randn(8, 512, 768, requires_grad=True)
    target = torch.randn(8, 512, 768)

    # CPU reference
    model_cpu = model
    y_cpu = model_cpu(x)
    loss_cpu = torch.nn.functional.mse_loss(y_cpu, target)
    loss_cpu.backward()
    grad_ref = x.grad.clone()

    # Device
    x.grad = None
    model_dev = model.to(device, dtype)
    x_dev = x.detach().to(device, dtype).requires_grad_(True)
    target_dev = target.to(device, dtype)
    y_dev = model_dev(x_dev)
    loss_dev = torch.nn.functional.mse_loss(y_dev, target_dev)
    loss_dev.backward()
    grad_dev = x_dev.grad.float().cpu()

    max_diff = (grad_dev - grad_ref).abs().max().item()
    print(f"Gradient max diff: {max_diff:.6f}")
    return max_diff < 0.05  # Gradient tolerance
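
The chain-rule reasoning above can also be cross-checked from first principles with a finite-difference gradient test, independent of any framework or backend. The NumPy sketch below compares an analytic gradient against central differences for an illustrative function f(x) = sum(tanh(Wx)); the shapes, seed, and epsilon are arbitrary choices.

```python
import numpy as np

# Framework-independent sanity check for a backward implementation:
# compare the analytic gradient of f(x) = sum(tanh(W @ x)) against
# central finite differences. Function and sizes are illustrative.
rng = np.random.default_rng(42)
W = rng.standard_normal((16, 8))
x = rng.standard_normal(8)

def f(v):
    return np.tanh(W @ v).sum()

# Analytic gradient: d/dx sum(tanh(Wx)) = W^T (1 - tanh(Wx)^2)
grad_analytic = W.T @ (1.0 - np.tanh(W @ x) ** 2)

# Central finite differences, one coordinate at a time
eps = 1e-6
grad_fd = np.zeros_like(x)
for i in range(x.size):
    e = np.zeros_like(x)
    e[i] = eps
    grad_fd[i] = (f(x + e) - f(x - e)) / (2 * eps)

max_diff = np.abs(grad_analytic - grad_fd).max()
print(f"finite-difference max diff: {max_diff:.2e}")
assert max_diff < 1e-6
```

The same pattern applied on-device (analytic device gradient versus finite differences of the device forward pass) separates backward-kernel bugs from forward-kernel bugs.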

3.3 TensorFlow Backend Validation

TF Op Validation

Hardware vendors that support TensorFlow must validate their XLA backend or custom op implementations against the CPU reference. The approach mirrors PyTorch validation: run each op on CPU (FP32 reference) and the target device (FP16), then compare outputs within dtype-appropriate tolerances. TensorFlow’s op dispatch mechanism differs from PyTorch – it uses a graph-based execution model where ops are compiled and optimized before execution – so bugs may appear during graph optimization (op fusion, constant folding) that are not present in eager-mode PyTorch. The test covers tf.matmul, tf.nn.softmax, tf.nn.relu, tf.nn.gelu, and tf.keras.layers.LayerNormalization – the core ops used in every modern transformer and CNN architecture.

import tensorflow as tf
import numpy as np

def validate_tf_op(name, op_fn, input_data, device='/GPU:0'):
    """Validate a TensorFlow op on GPU vs CPU."""
    # CPU reference
    with tf.device('/CPU:0'):
        ref = op_fn(tf.constant(input_data, dtype=tf.float32)).numpy()

    # GPU
    with tf.device(device):
        result = op_fn(tf.constant(input_data, dtype=tf.float16))
        result = tf.cast(result, tf.float32).numpy()

    max_diff = np.max(np.abs(result - ref))
    passed = max_diff < 0.01
    print(f"{'PASS' if passed else 'FAIL'} [TF {name}] max_diff={max_diff:.6f}")
    return passed

# Test common TF ops
np.random.seed(42)
data = np.random.randn(8, 512, 768).astype(np.float32)

validate_tf_op("matmul", lambda x: tf.matmul(x, tf.transpose(x, [0, 2, 1])), data)
validate_tf_op("softmax", lambda x: tf.nn.softmax(x, axis=-1), data)
validate_tf_op("relu", lambda x: tf.nn.relu(x), data)
validate_tf_op("gelu", lambda x: tf.nn.gelu(x), data)
validate_tf_op("layer_norm", lambda x: tf.keras.layers.LayerNormalization()(x), data)

TF SavedModel Validation

The SavedModel format is TensorFlow’s standard for model serialization and deployment. Validating a SavedModel on the target device ensures that the serialization-deserialization pipeline preserves numerical correctness, and that the loaded graph dispatches to GPU kernels rather than falling back to CPU. This test is particularly important for inference deployments where models are exported once (often on a different machine) and loaded repeatedly in production – any numerical divergence introduced during loading would silently affect every subsequent prediction.

def validate_savedmodel(model_path, test_input, device='/GPU:0'):
    """Load a SavedModel and validate output on GPU vs CPU."""
    model = tf.saved_model.load(model_path)

    with tf.device('/CPU:0'):
        ref = model(test_input).numpy()

    with tf.device(device):
        result = model(test_input).numpy()

    max_diff = np.max(np.abs(result - ref))
    print(f"SavedModel max diff: {max_diff:.6f}")
    return max_diff < 0.01

3.4 ONNX Runtime Validation

ONNX Runtime is critical for hardware vendors because many customers deploy models through ONNX export rather than calling the training framework directly, so the execution provider must be validated with the same rigor as a native backend.

ONNX Export and Validate

import torch
import onnx
import onnxruntime as ort
import numpy as np
import tempfile
import os

def export_and_validate_onnx(model, input_shape, model_name="model"):
    """Export PyTorch model to ONNX and validate numerical parity."""
    model.eval()
    torch.manual_seed(42)
    dummy_input = torch.randn(*input_shape)

    # Export to a portable temporary directory
    tmp_dir = tempfile.mkdtemp(prefix="onnx_validation_")
    onnx_path = os.path.join(tmp_dir, f"{model_name}.onnx")
    torch.onnx.export(
        model, dummy_input, onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        opset_version=17,
    )

    # Validate ONNX model structure
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    # PyTorch reference
    with torch.no_grad():
        ref = model(dummy_input).numpy()

    # ONNX Runtime inference
    session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    ort_result = session.run(None, {"input": dummy_input.numpy()})[0]

    max_diff = np.max(np.abs(ort_result - ref))
    passed = max_diff < 1e-4
    print(f"{'PASS' if passed else 'FAIL'} [{model_name}] "
          f"PyTorch vs ONNX max_diff={max_diff:.6f}")
    return passed


# Test with common models
model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
)
export_and_validate_onnx(model, (1, 768), "simple_mlp")
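
When a parity test like this fails, the next step is localizing the first layer that diverges. The sketch below is framework-agnostic: the activation dicts are stand-ins for values captured via PyTorch forward hooks or by fetching intermediate ONNX outputs, and the divergence here is injected synthetically.

```python
import numpy as np

def first_divergence(acts_ref, acts_dev, atol=1e-3):
    """Return the first layer (in capture order) whose activations differ
    beyond atol, along with the observed max difference."""
    for name in acts_ref:
        diff = np.abs(acts_ref[name] - acts_dev[name]).max()
        if diff > atol:
            return name, float(diff)
    return None, 0.0

# Stand-ins for activations captured from a reference and a device run
rng = np.random.default_rng(0)
acts_ref = {"fc1": rng.standard_normal(4), "gelu": rng.standard_normal(4)}
acts_dev = {"fc1": acts_ref["fc1"].copy(),
            "gelu": acts_ref["gelu"] + 0.01}   # synthetic divergence

layer, diff = first_divergence(acts_ref, acts_dev)
print(f"first divergence at {layer}: {diff:.4f}")
assert layer == "gelu"
```

Bisecting by layer this way turns a model-level FAIL into a specific kernel to hand to the op-level harness.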

ONNX Runtime Execution Provider Comparison

ONNX Runtime supports multiple execution providers (EPs): CPU, CUDA, ROCm, TensorRT, QNN (Qualcomm), and Neuron (AWS). Each EP implements the ONNX operator set using vendor-specific kernels, and numerical differences between providers indicate implementation-level bugs. This cross-provider comparison loads the same ONNX model on every available EP and compares outputs against the CPU reference. For hardware vendors, this is a critical validation step because customers frequently switch between providers (e.g., CUDA in development, TensorRT in production) and expect identical model behavior. A max difference exceeding 1e-3 between providers is typically a reportable issue in the vendor’s bug tracker.

def compare_onnx_providers(onnx_path, test_input):
    """Compare ONNX Runtime execution providers (CPU, CUDA, ROCm, TensorRT)."""
    results = {}

    providers_to_test = [
        ('CPUExecutionProvider', {}),
        ('CUDAExecutionProvider', {}),
        # ('ROCMExecutionProvider', {}),        # AMD
        # ('TensorrtExecutionProvider', {}),     # NVIDIA TensorRT
        # ('QNNExecutionProvider', {}),           # Qualcomm
        # ('NeuronExecutionProvider', {}),        # AWS Inferentia
    ]

    for provider_name, provider_opts in providers_to_test:
        try:
            session = ort.InferenceSession(
                onnx_path,
                providers=[(provider_name, provider_opts)]
            )
            input_name = session.get_inputs()[0].name
            result = session.run(None, {input_name: test_input})[0]
            results[provider_name] = result
            print(f"  {provider_name}: OK (shape={result.shape})")
        except Exception as e:
            print(f"  {provider_name}: UNAVAILABLE ({e})")

    # Cross-compare all available providers against CPU
    if 'CPUExecutionProvider' in results:
        ref = results['CPUExecutionProvider']
        for name, result in results.items():
            if name != 'CPUExecutionProvider':
                max_diff = np.max(np.abs(result - ref))
                passed = max_diff < 1e-3
                print(f"{'PASS' if passed else 'FAIL'} "
                      f"[{name} vs CPU] max_diff={max_diff:.6f}")

    return results

ONNX Opset Coverage Testing

Every ONNX model specifies an opset version that determines which operators and semantics are available. An execution provider must support all ops used by the model at the specified opset level – any unsupported op forces ONNX Runtime to fall back to the CPU provider for that node, fragmenting the graph and destroying performance. This utility lists all unique ONNX ops in a model, enabling validation engineers to cross-reference against the EP's supported op list. Modern transformer exports typically target opset 17 or later (LayerNormalization became a standard op in opset 17, GroupNormalization in opset 18), while older CV models may work with opset 11. The opset audit is a prerequisite for any hardware vendor claiming ONNX Runtime support.

def check_opset_coverage(onnx_path):
    """List the opset imports and ops a model uses, for cross-referencing
    against the execution provider's supported-op list."""
    model = onnx.load(onnx_path)

    opsets = {imp.domain or "ai.onnx": imp.version for imp in model.opset_import}
    print(f"Opset imports: {opsets}")

    ops_used = {node.op_type for node in model.graph.node}
    print(f"Ops used in model ({len(ops_used)}):")
    for op in sorted(ops_used):
        print(f"  - {op}")

    return ops_used
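
The cross-reference against the EP's supported-op list can be a plain set difference. In the sketch below, SUPPORTED_OPS is a hypothetical stand-in for the list a vendor publishes, and ops_used mimics the return value of check_opset_coverage.

```python
# Hypothetical EP op list; a real one comes from the vendor's documentation.
SUPPORTED_OPS = {"MatMul", "Add", "Relu", "Softmax", "LayerNormalization"}

# Stand-in for check_opset_coverage() output on some model.
ops_used = {"MatMul", "Add", "Gelu", "Softmax"}

unsupported = sorted(ops_used - SUPPORTED_OPS)
coverage = 1.0 - len(unsupported) / len(ops_used)
print(f"EP op coverage: {coverage:.0%}")
for op in unsupported:
    # Each of these nodes would fall back to the CPU provider.
    print(f"  CPU fallback required for: {op}")
```

Any nonempty `unsupported` list means the graph will be partitioned across providers, which usually shows up as a large performance regression before it shows up as a numerical one.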

3.5 Compiler & Graph Optimization Validation

torch.compile Validation

torch.compile uses the Dynamo frontend to trace PyTorch code into a graph, then the Inductor backend compiles it into optimized GPU kernels (Triton for NVIDIA/AMD). Compilation can fuse multiple ops into single kernels, reorder operations, and apply algebraic simplifications – all of which can change numerical behavior. Validation compares the compiled model’s output against the eager-mode reference to ensure the compiler optimizations preserve correctness within tolerance (typically 1e-4). For AMD, torch.compile with the ROCm/Triton backend is a key differentiator, and any correctness regression blocks the ROCm release.

def validate_torch_compile(model, input_shape, backend="inductor", device='cuda'):
    """Validate that torch.compile produces correct results."""
    model = model.to(device).eval()
    torch.manual_seed(42)
    x = torch.randn(*input_shape, device=device)

    # Eager reference
    with torch.no_grad():
        y_eager = model(x)

    # Compiled
    compiled_model = torch.compile(model, backend=backend)
    with torch.no_grad():
        y_compiled = compiled_model(x)

    max_diff = (y_compiled - y_eager).abs().max().item()
    passed = max_diff < 1e-4
    print(f"{'PASS' if passed else 'FAIL'} "
          f"[torch.compile/{backend}] max_diff={max_diff:.6f}")
    return passed

TensorRT Optimization Validation

TensorRT is NVIDIA’s inference optimization engine that converts ONNX models into highly optimized GPU execution plans with operator fusion, kernel auto-tuning, and optional quantization (FP16/INT8). The optimization process can introduce numerical differences because TensorRT may substitute mathematically equivalent but numerically different implementations (e.g., fusing BatchNorm into convolution weights). Validation compares the TensorRT-optimized output against the unoptimized ONNX Runtime CPU baseline. The acceptable tolerance depends on whether TensorRT was configured for FP32, FP16, or INT8 precision mode – INT8 tolerances are significantly looser because of quantization error.

def validate_tensorrt_optimization(onnx_path, test_input):
    """Validate TensorRT-optimized model against ONNX baseline."""
    # Baseline: ONNX Runtime CPU
    cpu_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    input_name = cpu_session.get_inputs()[0].name
    ref = cpu_session.run(None, {input_name: test_input})[0]

    # TensorRT-optimized
    trt_session = ort.InferenceSession(
        onnx_path,
        providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider']
    )
    trt_result = trt_session.run(None, {input_name: test_input})[0]

    max_diff = np.max(np.abs(trt_result - ref))
    print(f"TensorRT vs CPU: max_diff={max_diff:.6f}")
    return max_diff
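
One way to make the precision-dependent tolerance explicit is a small lookup keyed by the engine's precision mode. The thresholds below are illustrative starting points a validation team would calibrate per model, not TensorRT-documented limits.

```python
# Illustrative tolerance schedule per precision mode; values are assumptions
# to be calibrated per model, not vendor-documented guarantees.
TOLERANCE_BY_PRECISION = {
    "fp32": 1e-5,   # near bit-exact; differences come only from fusion/reordering
    "fp16": 1e-2,   # half-precision rounding accumulates across layers
    "int8": 1e-1,   # quantization error dominates
}

def tolerance_for(precision_mode):
    """Look up the comparison tolerance, failing loudly on unknown modes."""
    try:
        return TOLERANCE_BY_PRECISION[precision_mode]
    except KeyError:
        raise ValueError(f"unknown precision mode: {precision_mode}")

assert tolerance_for("fp16") == 1e-2
```

A driver would then compare `max_diff` from validate_tensorrt_optimization against `tolerance_for(mode)` for whichever mode the engine was built with.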

3.6 Framework-Specific Validation Checklists

PyTorch on ROCm (AMD)

  • All ATen ops dispatch to HIP kernels correctly

  • torch.cuda.* APIs work on ROCm (HIP-ified)

  • torch.compile with Triton backend generates valid GPU code

  • Mixed precision (torch.amp) works correctly

  • Autograd backward pass matches CPU reference

  • torch.distributed works with RCCL backend

PyTorch on CUDA (NVIDIA)

  • cuDNN dispatches the fastest algorithm for each conv config

  • cuBLAS GEMM produces correct results across all dtypes

  • FlashAttention integration is numerically correct

  • TensorFloat32 (torch.backends.cuda.matmul.allow_tf32) behavior is correct

  • CUDA graphs capture and replay correctly

ONNX Runtime on Custom Hardware

  • All required ONNX ops are implemented in the execution provider

  • Opset version compatibility (opset 15–19)

  • Dynamic shapes are handled correctly

  • Quantized model support (INT8, FP16)

  • Graph optimization passes don’t change semantics
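
In practice, checklists like these are encoded as an automated suite whose aggregate result gates a release. A minimal sketch of such a results tracker follows; the backend name, item names, and failure detail are all invented for illustration.

```python
from dataclasses import dataclass, field

@dataclass
class ChecklistItem:
    name: str
    passed: bool
    detail: str = ""

@dataclass
class ValidationChecklist:
    backend: str
    items: list = field(default_factory=list)

    def record(self, name, passed, detail=""):
        self.items.append(ChecklistItem(name, passed, detail))

    def report(self):
        """Print one line per item and return True only if all passed."""
        failed = [i for i in self.items if not i.passed]
        for item in self.items:
            mark = "PASS" if item.passed else "FAIL"
            print(f"{mark} [{self.backend}] {item.name} {item.detail}".rstrip())
        return len(failed) == 0

# Illustrative usage: any single FAIL blocks the release
checklist = ValidationChecklist("rocm")
checklist.record("aten_op_dispatch", True)
checklist.record("torch_compile_triton", False, "kernel compile error")
release_ready = checklist.report()
assert release_ready is False
```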

3.7 Exercises

  1. Op Coverage Audit: Write a script that iterates through the top 50 PyTorch ops (by usage frequency in common models) and tests each on CPU vs GPU. Report a coverage matrix.

  2. ONNX Round-Trip: Export ResNet-50 to ONNX, run inference with ONNX Runtime on CPU and GPU execution providers, and compare outputs to PyTorch. What is the maximum difference?

  3. torch.compile Benchmark: Compare eager vs torch.compile execution for a transformer block. Verify correctness AND measure speedup.

  4. Mixed Precision Validation: Run a training step in FP32 and AMP (FP16/BF16) and compare gradients. How much do they diverge?

  5. Framework Parity: Run the same model (e.g., BERT-base) in PyTorch and TensorFlow with identical weights. Compare outputs — are they identical?

Key Takeaways

  • Framework validation ensures the software stack doesn’t introduce errors

  • Op-level tests catch individual kernel bugs; model-level tests catch integration issues

  • ONNX is the portable interchange format — validate your EP thoroughly

  • Compiler optimizations (torch.compile, TensorRT) can change numerical behavior

  • Always test both inference (forward) and training (forward + backward)

Previous: 02_kernel_validation.ipynb
Next: 04_model_performance_validation.ipynb
Back to Overview: README.md