We fine-tuned Llama 405B on AMD GPUs
Project Overview & Goals
- Team fine-tuned Llama 3.1 405B on 8× AMD MI300X using JAX instead of PyTorch.
- Strategy: port models to JAX, rely on XLA’s hardware-agnostic HLO graphs to target non-NVIDIA backends (TPUs, AMD, Trainium) with minimal code changes.
- Same JAX Llama3 implementation reportedly runs on both TPUs and AMD GPUs without modification.
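The portability claim above rests on JAX's compilation model: model code is traced to XLA HLO, and the backend is selected from the devices JAX finds at runtime, not from anything in the model code itself. A minimal sketch (toy function, not the actual Llama implementation):

```python
# The same jitted function runs unchanged on CPU, ROCm/CUDA GPU, or TPU:
# jax.jit lowers it to an XLA HLO graph, and XLA compiles that graph for
# whatever backend is available.
import jax
import jax.numpy as jnp

@jax.jit
def mlp(x, w):
    # Stand-in for a model forward pass; nothing here names a backend.
    return jnp.tanh(x @ w)

x = jnp.ones((4, 8))
w = jnp.ones((8, 2))
out = mlp(x, w)                  # executes on whichever backend JAX selected
backend = jax.default_backend()  # e.g. "cpu", "gpu" (CUDA or ROCm), "tpu"
```

This is the mechanism behind "minimal code changes": switching hardware means installing the right jaxlib build, not editing the model.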
JAX vs PyTorch on Non-NVIDIA Hardware
- Claim: PyTorch is tightly coupled to NVIDIA (e.g., CUDA-specific paths, opaque performance differences), making non-NVIDIA support painful.
- Counterpoints:
  - Multiple users report PyTorch+ROCm working “out of the box” on AMD, including torch.compile, and existing LLM training stacks (axolotl, torchtune, LLaMA-Factory).
  - Some note that functions like scaled_dot_product_attention have non-CUDA implementations and also work on TPUs.
- Debate on “transparency”: it remains contested whether PyTorch’s explicit CUDA-optimized kernels or XLA’s pattern-matching to vendor libraries is the more transparent and portable approach.
Performance & MFU
- Initial run: ~35 tokens/s, batch size 16, seq length 64, eager JAX, no gradient accumulation.
- Several commenters calculate ~0.5–0.8% Model FLOPs Utilization (MFU), calling this “quite bad” and far from typical 30–40% for well-tuned training.
- Author attributes this to lack of JIT compilation and other optimizations; plans to re-run with JIT, gradient accumulation, activation checkpointing, and better sharding.
- Some mention significantly better MI300X performance elsewhere, including claims of MI300X being competitive with or faster than H100, but details are sparse.
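The commenters' MFU estimate can be reproduced with the standard ~6 FLOPs per parameter per trained token heuristic. The per-GPU peak figure below is an assumption (roughly MI300X dense-BF16 peak), not a number from the thread:

```python
# Back-of-the-envelope MFU for the reported run: ~35 tokens/s on 8x MI300X.
# Assumptions (not from the thread): ~1.3 PFLOPS dense-BF16 peak per MI300X,
# and the usual ~6 * params FLOPs per token for training (fwd + bwd).
PARAMS = 405e9            # Llama 3.1 405B
TOKENS_PER_SEC = 35       # reported throughput
PEAK_FLOPS = 8 * 1.3e15   # 8 GPUs x ~1.3 PFLOPS (assumed peak)

achieved = 6 * PARAMS * TOKENS_PER_SEC  # training FLOP/s actually sustained
mfu = achieved / PEAK_FLOPS
print(f"MFU = {mfu:.2%}")  # ~0.8%, consistent with the commenters' estimate
```

The exact figure shifts with the assumed peak (and whether sparsity-boosted peaks are counted), which is why estimates in the thread span ~0.5–0.8%.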
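The planned optimizations can be illustrated with a toy JAX sketch: jit compilation (instead of eager execution), activation checkpointing via jax.checkpoint, and gradient accumulation over micro-batches. The model here is a hypothetical two-layer MLP, not the author's Llama code:

```python
# Sketch of the planned fixes: jit, activation checkpointing (remat),
# and gradient accumulation. Toy model for illustration only.
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Tiny MLP stand-in for a transformer block.
    h = jnp.tanh(x @ params["w1"])
    pred = h @ params["w2"]
    return jnp.mean((pred - y) ** 2)

# Activation checkpointing: recompute activations during the backward pass
# instead of storing them, trading compute for memory.
loss_fn_remat = jax.checkpoint(loss_fn)

@jax.jit  # compile the whole accumulation loop with XLA, not eager ops
def accumulated_grads(params, xs, ys):
    # xs, ys have a leading micro-batch axis; scan sums grads across it.
    def step(acc, batch):
        x, y = batch
        g = jax.grad(loss_fn_remat)(params, x, y)
        return jax.tree_util.tree_map(jnp.add, acc, g), None
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, (xs, ys))
    n = xs.shape[0]  # number of micro-batches (static under jit)
    return jax.tree_util.tree_map(lambda g: g / n, total)
```

Sharding across the 8 GPUs (the remaining item on the author's list) would layer on top of this via jax.sharding, which jit also handles; it is omitted here to keep the sketch single-device.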
Correctness & Porting Concerns
- Concern: framework-level differences and floating-point nuances can cause subtle accuracy drift when porting from PyTorch to JAX.
- Response: current validation uses an AI-based comparison tool; more robust testing infra is planned.
- Note that cross-framework model translations (e.g., Gemma JAX→PyTorch) are common, but debugging mismatches can be very hard.
AMD Ecosystem: ROCm, Vulkan, Consumer Cards
- Mixed experiences with ROCm:
  - Some report smooth installs via official ROCm PyTorch/JAX containers.
  - Others struggle with distro/kernel compatibility, driver versions, and short support windows for consumer cards.
- For inference, several suggest Vulkan backends (e.g., llama.cpp, IREE) as easier to set up; however, benchmarked results show ROCm can be 40–200% faster than Vulkan on a 7900 XTX.
- Consumer AMD GPUs (e.g., 7900 XTX) can reach roughly RTX 3090-like performance in some LLM benchmarks; 4090 often significantly faster but with cost/VRAM tradeoffs.
Cost, Hardware Choice & Availability
- Opinions on performance-per-dollar: one view ranks TPU > AMD > NVIDIA; others dispute that TPUs are “slow,” pointing to Google’s own training workloads.
- Some cloud providers and smaller vendors rent 8× MI300X nodes; bare-metal ownership requires buying 8-GPU servers with substantial power and cooling needs.
- There is interest in multi-node MI300X scaling; reports suggest it is challenging but improving, though concrete public benchmarks remain limited.
Tooling & Alternatives
- JAX’s Pallas subsystem and example Flash Attention implementation are mentioned as promising for performance.
- Tinygrad is seen as interesting but immature compared to JAX.
- Some suggest existing JAX Llama implementations like MaxText as baselines for better MFU.