How 4 Bits Are Enough to Run Models Like ChatGPT
This post is inspired by Paulius Micikevicius’s talk on Numerics and AI at GPU Mode. Paulius leads deep learning research at NVIDIA and authored the foundational papers on mixed precision training and FP8 formats. If you want the full depth, watch the lecture.
TLDR
Neural networks spend most of their time multiplying matrices. Smaller number formats (FP16, bfloat16, FP8) mean less memory, less bandwidth, and faster math units. The challenge: gradients can be 10^6x smaller than weights, so narrow formats cause underflow. Loss scaling fixes this by multiplying gradients before they shrink to zero. bfloat16 works without scaling because it keeps FP32’s range. FP8 needs per tensor scaling and two formats (E4M3 for weights, E5M2 for gradients). For inference, INT8 and FP4 quantization map trained weights to smaller formats using calibration data to find optimal scale factors. The trend: lower precision with smarter scaling. Neural networks need dynamic range more than precision.
I spent years training models without thinking about how numbers are stored. FP32, FP16, bfloat16, FP8. These felt like implementation details that PyTorch or TensorFlow handled for me. Then I started working on inference optimization and realized: the way we represent numbers is often the actual bottleneck. Not the algorithm. Not the model architecture. The numbers themselves.
Why does precision matter for training?
A neural network is just matrix multiplications, additions, and nonlinearities applied billions of times. Every single one of those operations uses numbers. If your number format is 32 bits, you need 32 bits of memory per value. You need 32 bits of bandwidth to move it. You need 32 bit math units to compute with it.
Cut that to 16 bits and you halve memory. Halve bandwidth. And modern GPUs have dedicated 16 bit math units (Tensor Cores) that run 2x faster than their 32 bit counterparts. Cut to 8 bits and you halve again.
The catch: smaller numbers mean less precision and less range. Use too small a format and your gradients disappear. Your activations overflow. Training diverges.
The history of deep learning numerics is the history of figuring out how small we can go without breaking things.
What is a floating point number anyway?
A 32 bit float (FP32) has three parts:
- Sign bit: 1 bit that says positive or negative
- Exponent: 8 bits that determine the magnitude (the “scale”)
- Mantissa: 23 bits that determine the precision (the “details”)
For normal values, the value equals: (−1)^sign × 1.mantissa × 2^(exponent − bias), where the bias is 127 for FP32.
More exponent bits give you wider range. More mantissa bits give you finer precision. With only so many bits to work with, every format makes a tradeoff.
FP32 has enormous range (roughly 10^−38 to 10^38) and high precision. Overkill for most neural network operations. The question became: where can we cut?
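You can see these tradeoffs directly from PyTorch. A quick sketch using torch.finfo (any recent PyTorch version should work):

import torch

# Compare the range and precision of the formats discussed in this post.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} min normal={info.tiny:.3e}  max={info.max:.3e}  eps={info.eps:.3e}")

The eps column is the spacing between 1.0 and the next representable value, a rough measure of precision; tiny and max bound the normal range.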
Why did FP16 training fail initially?
IEEE FP16 has 5 exponent bits and 10 mantissa bits. Its range spans roughly 6×10^−5 to 65,504. That range is too narrow.
Here’s the problem: neural network gradients vary wildly in magnitude. Early in training, loss values might be in the hundreds. Gradients flowing back through dozens of layers can shrink by factors of 10^−6 or more. Values smaller than 6×10^−5 become zero in FP16. This is underflow.
When gradients underflow to zero, weights stop updating. Training stalls. The network learns nothing.
Early attempts at FP16 training failed for exactly this reason. Gradients vanished.
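You can reproduce the underflow in two lines. The 1e-8 value is just an illustrative gradient magnitude:

import torch

g = torch.tensor(1e-8)        # a small gradient value, easily representable in FP32
print(g.to(torch.float16))    # prints tensor(0., dtype=torch.float16): the value underflowed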
How does loss scaling fix underflow?
The fix is elegant. Before backpropagation, multiply the loss by a large number (say 1024 or 65536). This scales up all the gradients by that same factor. Small gradients that would have underflowed now stay in the representable range.
After backpropagation, divide the gradients by the same scale factor before updating weights. The weight update sees the correct gradient values.
This is loss scaling. The 2017 Mixed Precision Training paper by Micikevicius et al. showed this works across CNNs, RNNs, GANs, and language models.
But picking the right scale factor is tricky. Too small and gradients still underflow. Too large and gradients overflow (become infinity). The solution: dynamic loss scaling.
Dynamic loss scaling starts with a large scale (like 2^24). Every iteration, it checks for infinities or NaNs in the gradients. If found, it skips the weight update and reduces the scale. If not found, it periodically tries increasing the scale. The algorithm automatically finds the largest usable scale.
What is the master weight copy?
There’s another problem with FP16 training. Weight updates are tiny. A typical learning rate is 0.001 or smaller. Multiply that by a gradient and you get a very small number to add to each weight.
If your weight is 1.0 and your update is 0.00001, adding them in FP16 gives you 1.0 again: the gap between adjacent FP16 values near 1.0 is about 0.001, so anything smaller vanishes. This is the weight update problem.
The solution: keep a master copy of weights in FP32. Every iteration:
- Copy weights from FP32 to FP16
- Run forward and backward passes in FP16
- Accumulate the FP16 gradients into the FP32 master weights
The FP32 master weights can accumulate tiny updates over many iterations. The FP16 copy is just for fast computation.
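In pseudocode terms, the loop looks roughly like this. forward_backward and num_steps are stand-ins, not a real API:

import torch

master_w = torch.randn(1024, 1024)        # FP32 master copy of the weights
lr = 1e-3

for step in range(num_steps):             # num_steps is a placeholder
    w16 = master_w.half()                 # 1. copy FP32 -> FP16 for fast compute
    grad16 = forward_backward(w16)        # 2. forward/backward in FP16 (your own function)
    master_w -= lr * grad16.float()       # 3. accumulate the tiny update into FP32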
Why did bfloat16 appear?
Google introduced bfloat16 for TPUs. It has 8 exponent bits (same as FP32) and 7 mantissa bits.
The key insight: range matters more than precision for neural networks. With 8 exponent bits, bfloat16 has the same range as FP32 (roughly 10^−38 to 10^38). Gradients that would underflow in FP16 stay representable in bfloat16.
The tradeoff is precision. 7 mantissa bits versus FP16’s 10 means coarser values. But neural networks are remarkably tolerant of noise. The slight precision loss rarely hurts accuracy.
bfloat16 often works without loss scaling. The range is wide enough that most gradients stay representable. This simplifies the training loop.
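The range-versus-precision tradeoff is easy to see in PyTorch:

import torch

x = torch.tensor(1.001)
print(x.to(torch.float16))     # about 1.0010: FP16's 10 mantissa bits resolve the difference
print(x.to(torch.bfloat16))    # tensor(1.): with 7 mantissa bits, 1.001 rounds back to exactly 1
print(torch.tensor(1e-8).to(torch.bfloat16))  # roughly 1e-08: the FP32-sized range is retained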
What is FP8 and why do we need two formats?
FP8 cuts to 8 bits total. NVIDIA, ARM, and Intel jointly proposed two FP8 formats in 2022:
- E4M3: 4 exponent bits, 3 mantissa bits. Range up to ±448.
- E5M2: 5 exponent bits, 2 mantissa bits. Range up to ±57,344.
Why two formats? Different tensors have different needs.
Weights and activations in the forward pass benefit from precision. Their values cluster in a relatively narrow range. E4M3’s extra mantissa bit helps here.
Gradients in the backward pass have extreme dynamic range. Some gradients are 10^6 times larger than others. E5M2’s wider range prevents underflow.
The recommended approach: use E4M3 for weights and activations, E5M2 for gradients. Some networks can use E4M3 everywhere. Some require both formats.
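You can inspect both formats directly; PyTorch 2.1 and later expose them as storage dtypes. A quick sketch:

import torch

for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(dtype, "max:", info.max, "smallest normal:", info.tiny)
# prints roughly: max 448 and smallest normal 0.0156 for E4M3,
# max 57344 and smallest normal 6.1e-05 for E5M2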
How does per tensor scaling work in FP8?
FP8’s narrow range means most tensors need scaling. The general approach:
- Find the maximum absolute value in the tensor
- Choose a scale factor that maps this maximum to the FP8 representable range
- Multiply all values by this scale before converting to FP8
- Store the scale factor alongside the tensor
- After FP8 computation, divide results by the scale to recover true magnitudes
This per tensor scaling is more involved than FP16’s single loss scale. Each tensor might need its own scale factor. The scaling factors themselves are stored in higher precision (typically FP32).
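Here is a minimal sketch of those steps. quantize_fp8 and dequantize are illustrative helpers, not a library API; real implementations track amax history and fuse the scaling into the matmul:

import torch

def quantize_fp8(x, fp8_dtype=torch.float8_e4m3fn):
    """Per-tensor scaling: map the tensor's largest magnitude to the FP8 maximum."""
    amax = x.abs().max().clamp(min=1e-12)
    scale = torch.finfo(fp8_dtype).max / amax     # chosen so amax lands at FP8's limit
    return (x * scale).to(fp8_dtype), scale       # keep the scale (FP32) alongside the tensor

def dequantize(x_fp8, scale):
    return x_fp8.to(torch.float32) / scale        # divide by the scale to recover magnitudes

x = torch.randn(4, 4) * 1e-3
x_fp8, scale = quantize_fp8(x)
print((x - dequantize(x_fp8, scale)).abs().max()) # small round-trip error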
Modern hardware and libraries like NVIDIA’s Transformer Engine handle this automatically. But understanding it matters for debugging and optimization.
What about INT8?
INT8 (8 bit integers) has been popular for inference quantization. Why not use it for training?
Integers have no exponent. Apart from one sign bit, every bit encodes magnitude: 7 value bits for INT8. This gives high precision within a fixed range but zero flexibility in that range.
Neural network gradients vary by factors of 10^6 or more. INT8 cannot represent this variation with a single scale factor. You would need per layer or even per channel scaling, dramatically increasing complexity.
Floating point formats handle dynamic range naturally. Each value carries its own scale in the exponent bits. This is why FP8 works for training while INT8 remains primarily an inference format.
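A toy illustration of the problem, with values chosen to span several orders of magnitude:

import torch

x = torch.tensor([3.0e-6, 5.0e-4, 2.0e-1, 3.0])
scale = 127.0 / x.abs().max()              # a single per-tensor INT8 scale
x_int8 = (x * scale).round().clamp(-127, 127).to(torch.int8)
print(x_int8)                              # tensor([  0,   0,   8, 127], dtype=torch.int8)
print(x_int8.float() / scale)              # the two smallest values came back as exactly 0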
What do Tensor Cores require?
NVIDIA Tensor Cores accelerate matrix multiplications in reduced precision. To trigger them, your tensor dimensions must meet alignment requirements.
For Hopper architecture with FP8:
- Matrix dimensions M, N, K should be multiples of 16
- Tensors should be contiguous in memory
- Specific memory layouts may be required
If your tensors don’t meet these requirements, the computation falls back to slower CUDA cores. A 768×768 matrix multiplication runs on Tensor Cores. A 765×765 might not.
This is why padding to nice dimensions often improves performance dramatically. The math is the same but the hardware utilization is completely different.
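A small sketch of the padding trick. pad_to_multiple is an illustrative helper; in practice you pad whichever dimensions feed the matmul:

import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple=16):
    """Zero-pad the last dimension up to the next multiple."""
    remainder = x.shape[-1] % multiple
    if remainder == 0:
        return x
    return F.pad(x, (0, multiple - remainder))   # pad on the right

x = torch.randn(765, 765)
print(pad_to_multiple(x).shape)                  # torch.Size([765, 768])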
What about the accumulator precision?
When multiplying two FP8 numbers, the intermediate results are computed in higher precision. On Hopper GPUs, FP8 matrix multiplications accumulate into FP32.
This matters because matrix multiplication involves many additions. Adding two FP8 numbers would lose precision rapidly. By accumulating in FP32, we maintain accuracy even while inputs and outputs are FP8.
The outputs are then cast back to FP8 (or FP16 or bfloat16) for storage. This pattern (low precision for compute and storage, high precision for accumulation) is central to mixed precision training.
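A one-line illustration of why: once a running sum grows large, FP16 can no longer register small additions.

import torch

total = torch.tensor(2048.0, dtype=torch.float16)
print(total + 1.0)           # tensor(2048.): FP16 spacing near 2048 is 2, so the +1 is lost
print(total.float() + 1.0)   # tensor(2049.): accumulating in FP32 keeps it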
How much speedup does lower precision provide?
On NVIDIA H100 GPUs:
- TF32 Tensor Core throughput: 989 TFLOPS
- FP16/bfloat16 Tensor Core throughput: 1,979 TFLOPS (2x)
- FP8 Tensor Core throughput: 3,958 TFLOPS (4x)
(These are NVIDIA's datasheet figures with structured sparsity; dense throughput is half of each, but the ratios hold.)
Memory bandwidth follows similar ratios. Moving FP8 tensors requires half the bandwidth of FP16 and a quarter of the bandwidth of FP32.
For memory bound operations (which dominate inference), precision reduction translates almost directly to speedup. For compute bound operations (like large matrix multiplications in training), the Tensor Core throughput increase is the limiting factor.
What is MXFP8?
Microscaling (MX) formats take scaling further. Instead of one scale per tensor, MXFP8 uses one scale per block of 32 values.
Each block of 32 FP8 values shares a single 8 bit E8M0 scale factor (exponent only, so every scale is a power of two). This allows finer grained adaptation to local value distributions within a tensor.
The benefit: MXFP8 can use E4M3 everywhere, including for gradients. The per block scaling captures dynamic range that would otherwise require E5M2.
The cost: more scale factors to store and manage. But the overhead is modest (one 8 bit scale per 32 values) and the hardware support on Blackwell architecture makes it efficient.
How do I actually use mixed precision training?
In PyTorch with AMP (Automatic Mixed Precision):
import torch
from torch.amp import autocast, GradScaler

scaler = GradScaler('cuda')

for input, target in data:
    optimizer.zero_grad()
    with autocast(device_type='cuda', dtype=torch.float16):
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()   # scale the loss, then backpropagate
    scaler.step(optimizer)          # unscale gradients, skip the step on inf/NaN
    scaler.update()                 # grow or shrink the scale for the next iteration
autocast automatically runs operations in FP16 where safe and FP32 where necessary. GradScaler handles dynamic loss scaling. The model's weights simply stay in FP32 and act as the master copy; autocast casts them to FP16 on the fly for the operations that benefit.
For FP8 training, NVIDIA’s Transformer Engine provides similar abstractions:
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# HYBRID means E4M3 for the forward pass and E5M2 for gradients in the backward pass
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)

# the model's layers must come from Transformer Engine (e.g. te.Linear) to run in FP8
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = model(input)
What breaks with lower precision?
Some operations are numerically sensitive:
- Softmax: exponentials can overflow. Compute in FP32.
- Layer normalization: variance calculation accumulates squared values. Accumulate in FP32.
- Loss functions: cross entropy involves log of small probabilities. Compute in FP32.
- Optimizer state: Adam’s moving averages must stay in FP32.
AMP and Transformer Engine maintain lists of which operations run in which precision. You can customize these lists if your model has unusual numerical requirements.
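One common escape hatch is to disable autocast locally and upcast the inputs for a sensitive step. A sketch (assumes a CUDA device; AMP's default lists already cover common cases like softmax):

import torch
import torch.nn as nn

model = nn.Linear(512, 512).cuda()
x = torch.randn(8, 512, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.float16):
    hidden = model(x)                                   # matmul runs in FP16 on Tensor Cores
    with torch.autocast(device_type='cuda', enabled=False):
        probs = torch.softmax(hidden.float(), dim=-1)   # force this step into FP32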
What does the future look like?
The trend is clear: lower precision with smarter scaling. FP4 (4 bit floating point) is already appearing for inference. NVIDIA’s Blackwell architecture supports NVFP4 with per block scaling.
For training, the constraint is gradient representation. As scaling techniques improve (microscaling, per layer adaptive scaling), we may see FP4 training become viable.
The fundamental insight remains: neural networks need dynamic range more than precision. Every generation of hardware and algorithms exploits this more aggressively.
Understanding numerics lets you make informed tradeoffs. When your training diverges, you’ll know to check for gradient underflow. When your inference is slow, you’ll know which precision reductions are safe. The numbers matter.
How are models converted to INT8 or FP4 for inference?
Training happens in FP32 or FP16. Inference deployment often needs INT8 or FP4. The conversion process is called quantization. There are two main approaches.
Post Training Quantization (PTQ)
PTQ converts a trained model without retraining. The process:
- Run a small calibration dataset through the network
- Record the range of values at each layer (min, max, or percentiles)
- Compute scale factors that map these ranges to the target format
- Convert weights and set up activation quantization
For INT8, you need a scale factor and zero point per tensor (or per channel). The scale maps your observed range to the 256 integer values. Values outside the range get clipped.
The calibration dataset matters. It should represent real inference inputs. If calibration sees values from 0 to 100 but inference sees 0 to 200, half your values clip to the maximum and accuracy collapses.
Percentile calibration helps. Instead of using the absolute min/max, use the 99.9th percentile. This clips rare outliers but preserves the bulk of the distribution.
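A sketch of PTQ calibration with a percentile clip. calibrate_scale and quantize_int8 are illustrative helpers, and the random calibration tensors stand in for real inference data:

import torch

def calibrate_scale(activations, percentile=99.9):
    """Pick a symmetric INT8 scale from calibration data, clipping rare outliers."""
    flat = torch.cat([a.abs().flatten() for a in activations])
    clip = torch.quantile(flat, percentile / 100.0)     # e.g. the 99.9th percentile
    return 127.0 / clip                                  # maps [-clip, clip] onto [-127, 127]

def quantize_int8(x, scale):
    return (x * scale).round().clamp(-127, 127).to(torch.int8)

calib = [torch.randn(64, 256) for _ in range(16)]        # placeholder calibration batches
scale = calibrate_scale(calib)
x_int8 = quantize_int8(torch.randn(64, 256), scale)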
Quantization Aware Training (QAT)
Some models resist PTQ. Accuracy drops too much. QAT fixes this by simulating quantization during training.
The forward pass inserts fake quantization operations. Weights are quantized and dequantized before each layer. Activations get the same treatment. The network sees quantized values during training.
The backward pass uses the straight through estimator. Gradients flow through the fake quantization as if it were an identity function. The network learns weights that are robust to quantization noise.
QAT typically recovers accuracy that PTQ loses. The cost is retraining (or fine tuning) the model.
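A minimal sketch of fake quantization with the straight through estimator. fake_quantize is illustrative; quantization libraries such as torch.ao.quantization wrap this up for you:

import torch

def fake_quantize(x, num_bits=8):
    """Simulate INT8 quantization in the forward pass, pass gradients straight through."""
    qmax = 2 ** (num_bits - 1) - 1                        # 127 for 8 bits
    scale = x.abs().max().clamp(min=1e-8) / qmax
    x_q = (x / scale).round().clamp(-qmax, qmax) * scale  # quantize, then dequantize
    # Straight-through estimator: forward uses x_q, backward treats the op as identity.
    return x + (x_q - x).detach()

w = torch.randn(128, 128, requires_grad=True)
fake_quantize(w).sum().backward()
print(w.grad.unique())                                    # all ones: gradients flowed straight through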
Why FP8 simplifies inference
If you train in FP8, inference is straightforward. The model already uses FP8 weights and activations. No conversion needed. No calibration needed. Deploy directly.
This is a major advantage of FP8 training over FP32/FP16 training with INT8 inference. The formats match. No accuracy surprises at deployment.
FP4 quantization specifics
FP4 has only 16 possible values (4 bits). Representing weights accurately requires block scaling. NVIDIA’s NVFP4 uses one FP8 scale factor per block of 16 values.
The process:
- Group weights into blocks of 16
- Find the maximum absolute value in each block
- Compute a scale that maps this maximum to FP4’s range
- Quantize each value in the block using that scale
Per block scaling captures local variation. One block might have values near zero, another might have values near one. Each gets its own scale.
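Here is a numerical emulation of that procedure. Real NVFP4 stores the per block scale in FP8 E4M3; in this sketch the scale stays in FP32 and the FP4 grid is just the E2M1 value set, so it models the arithmetic, not the storage format:

import torch

# The eight non-negative values representable in FP4 E2M1 (maximum magnitude 6).
FP4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_fp4_blocks(w, block_size=16):
    """Emulated NVFP4-style block quantization (illustrative helper)."""
    blocks = w.reshape(-1, block_size)
    scales = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 6.0   # one scale per block
    scaled = blocks / scales                                  # every value now fits in [-6, 6]
    idx = (scaled.abs().unsqueeze(-1) - FP4_GRID).abs().argmin(dim=-1)
    snapped = FP4_GRID[idx] * scaled.sign()                   # nearest representable FP4 value
    return (snapped * scales).reshape(w.shape)                # dequantize for comparison

w = torch.randn(1024)
print((w - quantize_fp4_blocks(w)).abs().mean())              # modest error despite 4 bit values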
For activations, the scaling must happen at runtime. The calibration approach records typical activation ranges, then applies appropriate scales during inference.
When does quantization fail?
Some layers are sensitive:
- First and last layers: They directly touch input/output. Errors here propagate everywhere.
- Attention layers: Small numerical differences in attention scores get amplified by softmax.
- Narrow layers: Layers with few channels have less redundancy to absorb quantization noise.
The fix: mixed precision quantization. Keep sensitive layers in FP16. Quantize the rest to INT8 or FP4. Tools like TensorRT and vLLM do this automatically based on sensitivity analysis.