Utilizing Low-Precision Datatypes in PyTorch and Beyond

The video features Dus from Meta and Jeff from AMD discussing the adoption of low-precision floating-point formats such as FP8, FP6, and FP4 in PyTorch. They highlight scaling techniques, including per-tensor, per-row, and per-block scaling, that preserve accuracy when converting from high-precision tensors, and they cover hardware support on AMD GPUs, developer tools such as Composable Kernel and Triton, and AMD's commitment to open-source collaboration to advance low-precision computing in the AI ecosystem.

The video begins with Dus from Meta providing an overview of utilizing new low-precision floating-point formats in PyTorch, particularly focusing on the FP8, FP6, and FP4 types supported on modern hardware like MI300 GPUs. He explains that directly converting high-precision FP32 tensors to very low-precision formats like FP4 results in significant information loss due to the drastic reduction in representable values. To mitigate this, a critical technique called scaling is employed, where the high-precision tensor is represented as a low-precision tensor combined with scale factors. These scale factors map the dynamic range of the original tensor onto the limited range of the low-precision format, preserving accuracy.
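The effect of scaling can be illustrated with a small NumPy experiment. FP4 here is simulated by rounding to the E2M1 value set, and the helper names are purely illustrative, not a PyTorch API:

```python
import numpy as np

# Representable magnitudes of FP4 (E2M1); actual values are +/- these.
FP4_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_fp4(x):
    """Simulate FP4 by rounding each element to the nearest representable value."""
    mags = np.abs(x)
    idx = np.argmin(np.abs(mags[..., None] - FP4_VALUES), axis=-1)
    return np.sign(x) * FP4_VALUES[idx]

rng = np.random.default_rng(0)
x = rng.normal(scale=20.0, size=1000)

# Naive cast: everything beyond +/-6 collapses onto the FP4 maximum.
naive = quantize_fp4(x)

# Scaled cast: map the tensor's dynamic range onto [-6, 6] first,
# then fold the scale factor back in when reconstructing.
scale = np.abs(x).max() / 6.0
scaled = quantize_fp4(x / scale) * scale

assert np.abs(x - scaled).mean() < np.abs(x - naive).mean()
```

The assertion holds because without the scale factor most of this tensor's mass lies outside FP4's representable range and is clipped, while the scaled version only pays the (much smaller) rounding cost.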

Dus then describes popular scaling strategies such as per-tensor scaling and per-row scaling. Per-tensor scaling applies a single scale factor across the entire tensor by calculating the maximum absolute value within the tensor and normalizing accordingly. Per-row scaling, on the other hand, calculates individual scale factors for each row, reducing the impact of outliers and lowering reconstruction loss. He also introduces more granular per-block scaling used in MXFP formats, where rows are split into smaller blocks (size 32), and scale factors are stored in a new data type called E8M0, which represents powers of two. These advanced scaling methods further improve accuracy and are supported in PyTorch on AMD hardware.
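The three strategies differ only in how many elements share one scale factor. A minimal NumPy sketch of computing them (function names are hypothetical; 448 is the largest finite value of OCP FP8 E4M3 and 6 the largest FP4 E2M1 magnitude):

```python
import numpy as np

def per_tensor_scale(x, max_repr=448.0):
    # One scale for the whole tensor, derived from the global amax.
    return np.abs(x).max() / max_repr

def per_row_scales(x, max_repr=448.0):
    # One scale per row: an outlier in one row no longer squashes
    # the dynamic range of every other row.
    return np.abs(x).max(axis=1, keepdims=True) / max_repr

def per_block_e8m0_scales(x, block=32, max_repr=6.0):
    # MXFP-style: split each row into blocks of 32 elements and keep one
    # scale per block. E8M0 stores only an 8-bit exponent, so each scale
    # is rounded up to a power of two. (Assumes no all-zero blocks.)
    rows, cols = x.shape
    amax = np.abs(x.reshape(rows, cols // block, block)).max(axis=-1)
    return 2.0 ** np.ceil(np.log2(amax / max_repr))
```

For example, `per_block_e8m0_scales(np.ones((2, 64)))` yields a `(2, 2)` array of scales, one per 32-element block, each rounded up to a power of two.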

Jeff from AMD then takes over to discuss the ecosystem and hardware support for these low-precision formats. He highlights the evolution of FP8 standards, noting that the Open Compute Project (OCP) FP8 format has become the industry standard, supported on AMD's latest GPUs like RDNA4 and CDNA4 (MI355). Jeff emphasizes the availability of comprehensive documentation and tools for developers, including intrinsics and helper functions for writing HIP kernels that leverage these low-precision types. He also introduces Composable Kernel (CK), a powerful framework for writing high-performance GPU kernels with many examples and tuning capabilities, which has been integrated into PyTorch to optimize models like LLaMA.

Jeff further discusses Triton, a Python-like language for writing GPU kernels that simplifies kernel development and offers competitive performance compared to native HIP kernels. He shares success stories such as the GPT-OSS model, which uses MXFP4 weights and BF16 activations, demonstrating that even without dedicated FP4 matrix cores, significant benefits in space savings and performance can be achieved. He also mentions the MI355 launch, where FP4 matrix core multiplication is used extensively, with CK and Triton implementations providing optimized kernels depending on whether the workload is compute-bound or memory-bound.
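The GPT-OSS recipe described here, low-precision weights paired with high-precision activations, amounts to weight-only quantization. A sketch under stated assumptions (hypothetical helpers, MXFP4 simulated by the E2M1 value set with one scale per 32-element block):

```python
import numpy as np

FP4_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_weights_mxfp4(w, block=32):
    """Store weights as simulated-FP4 values plus one scale per block."""
    rows, cols = w.shape
    blocks = w.reshape(rows, cols // block, block)
    scales = np.abs(blocks).max(axis=-1, keepdims=True) / FP4_VALUES[-1]
    scaled = blocks / scales  # assumes no all-zero blocks
    idx = np.argmin(np.abs(np.abs(scaled)[..., None] - FP4_VALUES), axis=-1)
    codes = np.sign(scaled) * FP4_VALUES[idx]
    return codes, scales

def dequant_matmul(x, codes, scales):
    # Without dedicated FP4 matrix cores, the weights are expanded back to
    # full precision before the multiply. The benefit is the ~4x smaller
    # weight footprint and the bandwidth savings, not a faster matmul.
    w = (codes * scales).reshape(codes.shape[0], -1)
    return x @ w.T
```

This mirrors why the memory-bound case still wins: only the 4-bit codes and per-block scales are read from memory, and dequantization happens on the fly.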

Finally, Jeff encourages the community to engage with these new tools and APIs, noting that the scaled matrix multiplication API in PyTorch is evolving and will be included in upcoming releases. He stresses AMD’s commitment to open source and collaboration, inviting developers to contribute feedback and help improve documentation and support. By embracing this collaborative ethos, AMD aims to advance the ecosystem together with the developer community, ensuring broad access to cutting-edge low-precision computing capabilities in PyTorch and beyond.
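The scaled matrix multiplication API is indeed still evolving (in current PyTorch releases it appears as the private `torch._scaled_mm`), but its numerical contract can be emulated in a few lines. A minimal sketch, assuming scalar per-tensor scales:

```python
import numpy as np

def scaled_mm(a_lp, b_lp, scale_a, scale_b):
    # Contract of a scaled matmul: the operands arrive as low-precision
    # tensors plus their scale factors, and both scales are folded back
    # into the higher-precision accumulator after the multiply.
    acc = a_lp.astype(np.float64) @ b_lp.astype(np.float64)
    return acc * (scale_a * scale_b)
```

If `a = a_lp * scale_a` and `b = b_lp * scale_b`, this reproduces `a @ b` exactly, which is why the scales can be deferred past the matrix multiply.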