We fine-tuned Llama 405B on AMD GPUs

Project Overview & Goals

  • The team fine-tuned Llama 3.1 405B on 8× AMD MI300X GPUs using JAX instead of PyTorch.
  • Strategy: port models to JAX and rely on XLA’s hardware-agnostic HLO graphs to target non-NVIDIA backends (TPUs, AMD, Trainium) with minimal code changes; a portability sketch follows this list.
  • The same JAX Llama 3 implementation reportedly runs on both TPUs and AMD GPUs without modification.
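  • A minimal sketch of the portability argument (illustrative function, not the project’s actual code): the same jitted Python lowers through XLA on whichever backend JAX finds, so nothing in the source names a vendor.

      import jax
      import jax.numpy as jnp

      # XLA picks the installed backend (CPU, TPU, or ROCm GPU) automatically;
      # the Python source is identical across hardware.
      print(jax.devices())

      @jax.jit
      def attention_scores(q, k):
          # Generic compute: jit traces this to HLO, and the backend-specific
          # XLA compiler takes it from there.
          return jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]))

      q = jnp.ones((8, 64))
      k = jnp.ones((8, 64))
      print(attention_scores(q, k).shape)  # (8, 8) on any backend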

JAX vs PyTorch on Non-NVIDIA Hardware

  • Claim: PyTorch is tightly coupled to NVIDIA (e.g., CUDA-specific paths, opaque performance differences), making non-NVIDIA support painful.
  • Counterpoints:
    • Multiple users report PyTorch+ROCm working “out of the box” on AMD, including torch.compile, and existing LLM training stacks (axolotl, torchtune, LLaMA-Factory).
    • Some note that functions like scaled_dot_product_attention have non-CUDA implementations and also work on TPUs.
  • Debate on “transparency”: commenters contest whether PyTorch’s explicit CUDA-optimized kernels or XLA’s pattern-matching to vendor libraries is the more transparent and portable approach; see the HLO-inspection sketch after this list.
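  • One concrete handle on that debate: JAX can print the backend-neutral HLO/StableHLO it hands to XLA, while the subsequent pattern-matching to vendor kernels happens inside the compiler. A minimal sketch:

      import jax
      import jax.numpy as jnp

      def f(x, w):
          return jnp.tanh(x @ w)

      x = jnp.ones((4, 16))
      w = jnp.ones((16, 8))

      # lower() stops before backend-specific compilation, so this text is the
      # hardware-agnostic graph; what XLA rewrites it into per vendor is opaque.
      print(jax.jit(f).lower(x, w).as_text())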

Performance & MFU

  • Initial run: ~35 tokens/s at batch size 16 and sequence length 64, in eager JAX with no gradient accumulation.
  • Several commenters calculate ~0.5–0.8% Model FLOPs Utilization (MFU), calling this “quite bad” and far from the 30–40% typical of well-tuned training; the arithmetic is sketched after this list.
  • The author attributes this to the lack of JIT compilation and other optimizations, and plans to re-run with JIT, gradient accumulation, activation checkpointing, and better sharding (second sketch below).
  • Some mention significantly better MI300X performance elsewhere, including claims of MI300X being competitive with or faster than H100, but details are sparse.
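  • The arithmetic behind the ~0.5–0.8% figure, as a sketch (assuming the standard ≈6·N training FLOPs per token and an MI300X dense bf16 peak of roughly 1.3 PFLOPS per GPU; both are assumptions in the commenters’ style, not measured values):

      params = 405e9                        # Llama 3.1 405B
      tokens_per_s = 35                     # reported throughput
      flops_per_token = 6 * params          # standard fwd+bwd training estimate
      achieved = tokens_per_s * flops_per_token    # ~8.5e13 FLOP/s
      peak = 8 * 1.3e15                     # 8x MI300X at ~1.3 PFLOPS each (assumed)
      print(f"MFU ~ {achieved / peak:.2%}")        # ~0.82%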
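  • A second sketch, of the planned fixes in JAX terms (illustrative, not the author’s code): jax.jit replaces eager execution, jax.checkpoint rematerializes activations, and a lax.scan over microbatches accumulates gradients.

      import jax
      import jax.numpy as jnp

      @jax.checkpoint              # activation checkpointing (rematerialization)
      def layer(w, x):
          return jnp.tanh(x @ w)

      def loss(w, batch):
          return jnp.mean((layer(w, batch) - 1.0) ** 2)

      @jax.jit                     # compiled, not eager
      def accumulated_grad(w, microbatches):
          # Gradient accumulation: sum grads over microbatches so the
          # effective batch grows without extra activation memory.
          def step(acc, mb):
              return acc + jax.grad(loss)(w, mb), None
          total, _ = jax.lax.scan(step, jnp.zeros_like(w), microbatches)
          return total / microbatches.shape[0]

      w = jnp.ones((16, 16))
      mbs = jnp.ones((4, 8, 16))   # 4 microbatches of 8 examples each
      print(accumulated_grad(w, mbs).shape)   # (16, 16)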

Correctness & Porting Concerns

  • Concern: framework-level differences and floating-point nuances can cause subtle accuracy drift when porting from PyTorch to JAX.
  • Response: current validation uses an AI-based comparison tool; more robust testing infrastructure is planned (a numerical check is sketched after this list).
  • Note that cross-framework model translations (e.g., Gemma JAX→PyTorch) are common, but debugging mismatches can be very hard.
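  • A tolerance-based numerical check is the usual complement to such tools: run both ports on identical inputs and weights and compare logits within a floating-point budget. A hedged sketch (names are placeholders):

      import numpy as np

      def check_parity(torch_logits, jax_logits, rtol=1e-3, atol=1e-5):
          # Tolerances are deliberately loose: bf16 matmuls and different
          # op-fusion orders legitimately produce small divergences.
          a = np.asarray(torch_logits, dtype=np.float32)
          b = np.asarray(jax_logits, dtype=np.float32)
          print("max abs diff:", np.abs(a - b).max())
          np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)

      x = np.random.default_rng(0).normal(size=(4, 32)).astype(np.float32)
      check_parity(x, x + 1e-6)    # passes: within tolerance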

AMD Ecosystem: ROCm, Vulkan, Consumer Cards

  • Mixed experiences with ROCm:
    • Some report smooth installs via the official ROCm PyTorch/JAX containers (a smoke-test sketch follows this list).
    • Others struggle with distro/kernel compatibility, driver versions, and short support windows for consumer cards.
  • For inference, several suggest Vulkan backends (e.g., llama.cpp, IREE) as easier to set up; however, benchmark results show ROCm can be 40–200% faster than Vulkan on a 7900 XTX.
  • Consumer AMD GPUs (e.g., the 7900 XTX) can reach roughly RTX 3090-level performance in some LLM benchmarks; the RTX 4090 is often significantly faster, with cost and VRAM tradeoffs.
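  • For the smooth-install path, a quick in-container smoke test looks like the sketch below (assuming an official ROCm PyTorch build; torch.version.hip is set only on ROCm, which reuses the torch.cuda namespace for HIP devices).

      import torch

      print(torch.version.hip)                  # ROCm/HIP version string, or None
      print(torch.cuda.is_available())          # True if the AMD GPU is visible
      if torch.cuda.is_available():
          print(torch.cuda.get_device_name(0))  # e.g. an Instinct or Radeon name
          x = torch.ones(1024, 1024, device="cuda")
          print((x @ x).sum().item())           # 1024**3 = 1073741824.0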

Cost, Hardware Choice & Availability

  • Opinions on performance-per-dollar diverge: one view ranks TPU > AMD > NVIDIA; others dispute that TPUs are “slow,” pointing to Google’s own training workloads.
  • Some cloud providers and smaller vendors rent 8× MI300X nodes; bare-metal ownership requires buying 8-GPU servers with substantial power and cooling needs.
  • There is interest in multi-node MI300X scaling; reports suggest it is challenging but improving, though concrete public benchmarks remain limited.

Tooling & Alternatives

  • JAX’s Pallas kernel subsystem and its example Flash Attention implementation are mentioned as promising for performance (minimal kernel sketch at the end of this list).
  • Tinygrad is seen as interesting but immature compared to JAX.
  • Some suggest existing JAX Llama implementations like MaxText as baselines for better MFU.
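  • To give a flavor of Pallas (an elementwise toy, nothing like the real Flash Attention example): kernels are written against mutable refs and launched via pl.pallas_call; this targets GPU/TPU backends, and interpret=True can be passed to pl.pallas_call to debug on CPU.

      import jax
      import jax.numpy as jnp
      from jax.experimental import pallas as pl

      def add_kernel(x_ref, y_ref, o_ref):
          # Refs are mutable views into device memory; kernels operate on
          # blocks rather than whole arrays.
          o_ref[...] = x_ref[...] + y_ref[...]

      @jax.jit
      def add(x, y):
          return pl.pallas_call(
              add_kernel,
              out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
          )(x, y)

      x = jnp.arange(8.0)
      print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]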