Thinking aloud: Can we speed up model training by using binary weights?

When I was on Amazon’s LLM pre-training team, our pre-training jobs used to run for weeks. Impatient as I am, I found it frustrating to wait a whole week to see whether the changes we made (mostly to the data) worked. The magic number was 1 trillion tokens, and even a smallish model (e.g., 7 billion parameters) would take a few days to reach that point, even with the number of GPUs we had access to.

Now imagine you want to pre-train a model, and all you have is a single GPU. Say you want to train it on all of SlimPajama, which is roughly 600 billion tokens, and be done with pre-training in a day. Currently, this is a pipe dream. To get through 600B tokens in a day, your training throughput needs to be about 7 million tokens a second. The typical training speeds you see are below 10k tokens per second per GPU:

| Model | Size | Hardware | Training speed |
|---|---|---|---|
| Gemma 3 270M | 270M | Apple M3 Pro | 4,000 tokens/s |
| TinyLlama | 1.1B | 16x A100 | 1,500 tokens/s per GPU |
| Llama 1 65B | 65B | 2048x A100 | 380 tokens/s per GPU |
| MPT 7B | 7B | 440x A100 | 2,800 tokens/s per GPU |
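To make the gap concrete, here is the arithmetic as a small Python sketch. The 600B token count is the rough SlimPajama figure used above, the per-GPU speeds are copied from the table, and the function name is just for illustration.

```python
# Back-of-the-envelope throughput target: finish ~600B tokens in one day.
SECONDS_PER_DAY = 24 * 60 * 60  # 86,400 seconds


def required_tokens_per_second(total_tokens: float, days: float = 1.0) -> float:
    """Tokens/second needed to process `total_tokens` within `days` days."""
    return total_tokens / (days * SECONDS_PER_DAY)


# Per-GPU training speeds from the table above (tokens/s).
observed = {"Gemma 3 270M": 4000, "TinyLlama": 1500, "Llama 1 65B": 380, "MPT 7B": 2800}

target = required_tokens_per_second(600e9)
print(f"target: {target / 1e6:.2f}M tokens/s")  # target: 6.94M tokens/s
for model, speed in observed.items():
    # How many times faster than the reported per-GPU speed a single GPU would need to be.
    print(f"{model}: would need a {target / speed:,.0f}x speedup")
```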

And we are talking about 7 million tokens per second. Even if we assume 100% utilisation of the hardware (which is rare), we’ll end up with some pretty small theoretical limits to how large our model can be. Let’s try to work that out.

Target 7 million tokens/sec – how large can our model be?

To process 7 million tokens a second, the model will have to do quite a lot of floating point operations per second:

FloatcomputationsPerSecond = 7 million \times (FLOPs_{forward} + FLOPs_{backward})

But there are theoretical upper limits for the number of floating point operations the GPU can do per second.

| GPU | Peak TFLOPS (fp16) |
|---|---|
| RTX 5090 | 209 |
| Apple M4 | 2.9–4.3 (fp32); ~8 for fp16 |
| A100 | 312 |
| H100 | 950 (without sparsity) |

Plugging this into our equation:

(FLOPs_{forward} + FLOPs_{backward}) = \frac{FloatcomputationsPerSecond}{7 million}

Let’s assume that the backward pass costs 2.5x the FLOPs of the forward pass. This is the assumption made in the FlashAttention-2 paper, and it likely holds only for self-attention, but it sounds reasonable. Let’s also assume that every parameter in the model results in 2 FLOPs per token in the forward pass: every parameter in a weight matrix is involved in one multiplication and one addition as part of a matmul. Let’s ignore other layers (e.g., softmax, layer norm) for the time being:

(FLOPs_{forward} + 2.5 \times FLOPs_{forward}) = \frac{FloatcomputationsPerSecond}{7 million}
3.5 \times FLOPs_{forward} = \frac{FloatcomputationsPerSecond}{7 million}
3.5 \times 2 \times parameters = \frac{FloatcomputationsPerSecond}{7 million}
parameters = \frac{FloatcomputationsPerSecond}{7 million \times 7}

If we substitute FloatcomputationsPerSecond with 950 TFLOPS (950 × 10^12 operations per second), which is the theoretical best we can get on an H100, we’ll have to limit our model to about 19.4 million parameters if training is to proceed at 7 million tokens per second. If we assume a more realistic utilisation of roughly 37% (350 TFLOPS), we’ll be limited to about 7 million parameters.
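Here is the same derivation as a small Python helper, under exactly the assumptions above (2 FLOPs per parameter per token in the forward pass, a backward pass costing 2.5x the forward pass, all other layers ignored). The helper's name is just for illustration.

```python
def max_params(peak_flops_per_s: float, tokens_per_s: float = 7e6) -> float:
    """Largest parameter count that sustains `tokens_per_s` at `peak_flops_per_s`.

    Per token: forward = 2 FLOPs/parameter, backward = 2.5x forward,
    so 2 * (1 + 2.5) = 7 FLOPs per parameter. Other layers are ignored.
    """
    flops_per_param_per_token = 2 * (1 + 2.5)
    return peak_flops_per_s / (tokens_per_s * flops_per_param_per_token)


print(f"H100 peak (950 TFLOPS): {max_params(950e12) / 1e6:.1f}M")  # 19.4M
print(f"H100 at 350 TFLOPS:     {max_params(350e12) / 1e6:.1f}M")  # 7.1M
```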

During inference, we can quantise a model to speed it up several-fold while losing only a fraction of the accuracy. Why can’t we do that during training? Better yet, why not go all the way and make all our weights binary? Yes, we will lose model performance, but it is not obvious that model training will be much faster.

But is computation even the bottleneck?

Even if your model fits on a single GPU, large distributed pre-training runs (i.e., data-parallel runs) are often bottlenecked by communication overhead. Each GPU holds a copy of your model and computes its own gradients. At the end of the backward pass, you must average the gradients across all GPUs. This all_reduce step can absolutely be a bottleneck.
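To get a feel for why the all_reduce step can dominate, here is the textbook cost model for a bandwidth-optimal ring all-reduce, in which each GPU sends 2(N − 1)/N times the gradient buffer per step. The model size, gradient dtype and GPU count below are hypothetical, and real step times also depend on interconnect bandwidth and on how much communication overlaps with the backward pass, none of which this sketch models.

```python
def ring_allreduce_bytes_per_gpu(num_params: int, bytes_per_grad: int, world_size: int) -> float:
    """Bytes each GPU sends in a bandwidth-optimal ring all-reduce.

    Reduce-scatter and all-gather each move (N - 1) / N of the buffer per rank,
    so the total is 2 * (N - 1) / N * buffer_size.
    """
    buffer_bytes = num_params * bytes_per_grad
    return 2 * (world_size - 1) / world_size * buffer_bytes


# Hypothetical example: 7B-parameter model, bf16 gradients (2 bytes each), 8 GPUs.
sent = ring_allreduce_bytes_per_gpu(7_000_000_000, 2, 8)
print(f"{sent / 1e9:.1f} GB sent per GPU per optimizer step")  # 24.5 GB
```

Note that the per-GPU traffic barely shrinks as you add GPUs (it approaches 2x the gradient size), which is why adding more data-parallel workers does not make this step go away.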

If your model is too big to fit on a single GPU, you must use model parallelism to shard it across multiple GPUs. This introduces yet another bandwidth-heavy network operation (all_gather).

You also have to read data from the disk, and in case of data-parallel runs, probably from a network-mounted drive.

So no, computation is usually not the bottleneck.

But, if your model fits in a single GPU and you are not using data-parallelism, then you have zero communication overhead and training might indeed be limited by compute.

Can’t you just use a lower precision during training?

Yes, you can. The newer H100 GPUs introduced fp8 tensor cores, and Nvidia’s Transformer Engine reports a 1.46x speedup when training Llama 3 8B with fp8. It is harder to achieve stable training with lower-precision floats or integers; the first successful fp4 pre-training is very recent.

Has anyone tried using binary weights? Not really. There’s BitNet and its variants, but they don’t report training throughput numbers. Moreover, I wouldn’t expect training throughput to improve, since BitNet maintains latent weights in full precision and binarises them on the fly during the forward pass. There’s also Training Binary Neural Networks in a Binary Weight Space by Shibuya et al., the closest work I could find that trains a binary neural network without keeping full-precision latent weights. However, the authors do not mention how they implemented the binary weights; for all I know, the weights were still fp32 but artificially confined to values of ±1. I suspect this is why Table 2 in the paper reports analytical rather than experimental memory usage.

It’s a pity that the paper was rejected from ICLR 2024. All the reviewers gave high “soundness” and “contribution” scores and then proceeded to criticise the experiments section. One reviewer mentioned that “the experimental results in Tab.3 seem extremely bad on medium-sized datasets such as CIFAR-10 and Tiny-ImageNet”, but that’s not the point! The results were proof that the method worked, and that should have been sufficient to justify acceptance. Also, the “extremely bad” error rate was 2.5% for the binary network vs 1.46% for the full-precision baseline. The circus of academic publishing. But I digress.

Why binary weights might increase computation throughput

Let’s not even talk about memory, beyond noting that binary weights take 16x less memory than fp16.

But why would computation be sped up? Imagine taking the dot product of two float vectors of size n. This requires n multiplications and n additions, for a total of 2n floating point operations. Now imagine your vectors were binary, with values confined to -1 or +1. You can then pack each vector into a single integer, one bit per element (a packed integer), and do an XNOR + bitcount operation to get the same result as a naive dot product. That’s just 2 operations, compared to 2n for a dot product, which in theory is a speedup by a factor of n (as long as each vector fits in one machine word). See the figures below.

Figure 1: Naive dot product resulting in n multiplications and n additions

The intuition is as follows: when every element of both vectors is either -1 or 1, the product of two corresponding elements is 1 if they are the same (i.e., both are 1 or both are -1) and -1 if they differ. The dot product is just the sum of these element-wise products. So the result of the dot product is “the number of 1 results minus the number of -1 results in the element-wise multiplication.”

This is exactly what an XNOR (i.e., an exclusive NOR) computes once we encode +1 as bit 1 and -1 as bit 0: it outputs 1 if both operands are 1s or both are 0s, and 0 if they differ (i.e., 1 and 0, or 0 and 1). GPUs (and CPUs) provide a single instruction to count the number of set bits in an integer, called population count. In CUDA, that’s the __popc() intrinsic for 32-bit integers and __popcll() for 64-bit ones. The XNOR-Net paper popularised this trick. Thus, writing \odot for element-wise XNOR, we can rewrite the dot product for binarised vectors a and b as:

a \cdot b = numberOfOnes(a \odot b) - numberOfZeroes(a \odot b)
a \cdot b = numberOfOnes(a \odot b) - (length(a) - numberOfOnes(a \odot b))
a \cdot b = 2 \times numberOfOnes(a \odot b) - length(a)
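Before packing anything into integers, the identity itself can be checked directly. This Python sketch uses plain lists, assumes the encoding +1 → bit 1 and -1 → bit 0, and compares 2 × numberOfOnes(a XNOR b) − length(a) against an ordinary dot product on random vectors.

```python
import random


def naive_dot(a: list[int], b: list[int]) -> int:
    """Ordinary dot product: n multiplications and n additions."""
    return sum(x * y for x, y in zip(a, b))


def xnor_count_dot(a: list[int], b: list[int]) -> int:
    """2 * numberOfOnes(a XNOR b) - length(a), with +1 -> bit 1 and -1 -> bit 0."""
    bits_a = [1 if x == 1 else 0 for x in a]
    bits_b = [1 if y == 1 else 0 for y in b]
    # XNOR of two bits is 1 exactly when they are equal.
    ones = sum(1 for p, q in zip(bits_a, bits_b) if p == q)
    return 2 * ones - len(a)


random.seed(0)
for _ in range(1_000):
    n = random.randint(1, 200)
    a = [random.choice((-1, 1)) for _ in range(n)]
    b = [random.choice((-1, 1)) for _ in range(n)]
    assert naive_dot(a, b) == xnor_count_dot(a, b)
print("identity holds on 1,000 random pairs")
```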

If a and b are binary vectors of length 64, each can be represented by a single 64-bit integer, and CUDA’s 64-bit population count, __popcll(), returns the number of set bits (i.e., 1s) in a single operation. Calling the common length of a and b n:

dot(a, b) = 2 × __popcll(~(a ^ b)) − n

That is, the dot product of two vectors of length 64 takes 3 operations: the XNOR (~(a ^ b)), the __popcll(), and a single integer multiply-add for 2 × count − n. See Figure 2 below for an example. The naive dot product would have taken 128 operations: 64 multiplications and 64 additions.

Figure 2: XNOR and population count (i.e., bitcount) give the same result as the dot product, but with just two operations. See the XNOR-Net paper.
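Here is the same formula on a packed word, again as a Python sketch rather than CUDA. Python has no fixed-width integers, so the 64-bit register is emulated with an explicit mask, and `bin(x).count("1")` stands in for `__popcll()`. The mask also covers a detail the formula glosses over: when n < 64, the unused high bits are 0 in both words, so their XNOR is 1 and would inflate the count unless masked off.

```python
WORD_BITS = 64  # emulating one 64-bit register; Python ints are unbounded


def pack(v: list[int]) -> int:
    """Pack up to 64 values in {-1, +1} into one integer: +1 -> bit 1, -1 -> bit 0."""
    if len(v) > WORD_BITS:
        raise ValueError("vector does not fit in one 64-bit word")
    word = 0
    for i, x in enumerate(v):
        if x == 1:
            word |= 1 << i
    return word


def binary_dot(a_word: int, b_word: int, n: int) -> int:
    """dot(a, b) = 2 * popcount(~(a ^ b)) - n, looking only at the low n bits."""
    valid = (1 << n) - 1                  # mask off unused high bits when n < 64
    agree = ~(a_word ^ b_word) & valid    # XNOR: 1 where the signs match
    return 2 * bin(agree).count("1") - n  # bin().count stands in for __popcll()


a = [1, -1, -1, 1, 1]
b = [1, 1, -1, -1, 1]
print(binary_dot(pack(a), pack(b), len(a)))  # 1, same as sum(x * y for x, y in zip(a, b))
```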

Ok, so dot products are now theoretically 40x faster. So?

Matrix multiplication is just a series of dot products. A linear layer in the forward pass simply multiplies an activation matrix by a weight matrix, so, assuming the activations are binarised too (as in XNOR-Net), that operation is now also around 40x faster in theory. We are of course ignoring the fact that instead of floating point weights you just have 1s and -1s, and it’s unlikely that a neural network can learn anything with binary weights; we’ll look into that later. Assuming there’s a magic algorithm that lets us train performant binary neural networks, having purely binarised weights during training will likely improve training throughput.
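Extending the one-word sketch to a full (m, k) @ (k, n) product gives a toy binary matmul: pack each row of the activations and each column of the weights into 64-bit chunks, then take XNOR + popcount dot products. It assumes, as the XNOR trick requires, that both operands are ±1. This pure-Python version only demonstrates correctness; it says nothing about speed on a GPU, where the packing and popcount would live in a CUDA kernel.

```python
import random

WORD_BITS = 64


def pack_bits(v: list[int]) -> list[int]:
    """Pack a {-1, +1} vector into 64-bit chunks (+1 -> bit 1, -1 -> bit 0)."""
    words = []
    for start in range(0, len(v), WORD_BITS):
        word = 0
        for i, x in enumerate(v[start:start + WORD_BITS]):
            if x == 1:
                word |= 1 << i
        words.append(word)
    return words


def packed_dot(a_words: list[int], b_words: list[int], n: int) -> int:
    """XNOR + popcount dot product over as many 64-bit words as needed."""
    agree = 0
    for w, (x, y) in enumerate(zip(a_words, b_words)):
        bits = min(WORD_BITS, n - w * WORD_BITS)  # the last word may be partial
        agree += bin(~(x ^ y) & ((1 << bits) - 1)).count("1")
    return 2 * agree - n


def binary_matmul(x: list[list[int]], w: list[list[int]]) -> list[list[int]]:
    """(m, k) @ (k, n) where BOTH activations x and weights w are in {-1, +1}."""
    k = len(w)
    rows = [pack_bits(r) for r in x]  # rows of the activation matrix
    cols = [pack_bits([w[i][j] for i in range(k)]) for j in range(len(w[0]))]  # weight columns
    return [[packed_dot(r, c, k) for c in cols] for r in rows]


# Usage: compare against a naive matmul on random +-1 matrices.
random.seed(1)
m, k, n = 3, 150, 4  # k > 64, so each dot product spans three words
x = [[random.choice((-1, 1)) for _ in range(k)] for _ in range(m)]
w = [[random.choice((-1, 1)) for _ in range(n)] for _ in range(k)]
naive = [[sum(x[r][i] * w[i][c] for i in range(k)) for c in range(n)] for r in range(m)]
assert binary_matmul(x, w) == naive
```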

But how much improvement? Not that much, unless gradients are binary too

The rule of thumb in computing FLOPs is that each parameter results in 2 floating point operations per token. If you think of a neural network as one big matrix multiplication between an activation/data matrix of shape (1, k) and a weight matrix of shape (k, n), this makes sense: the total number of parameters is k × n, and a vanilla matrix multiplication takes 2 × 1 × k × n operations. But with the XNOR and popcount trick, we incur only 3 operations per 64 parameters!


We had also assumed that the backward pass costs 2.5x the FLOPs of the forward pass. But using the XNOR + popcount trick we discussed above, we can dramatically reduce the number of operations in the forward pass. Notice how I did not say FLOPs. That’s because XNOR and popcount are not floating point operations. Let’s not dwell on that for now. I also explicitly mentioned the forward pass: the backward pass will need full-precision gradients, as mentioned in Shibuya et al. So no XNOR + popcount in the backward pass. This also means that the backward pass will still cost 2.5 × 2 × parameters operations, just like in a non-binarised neural network. If you recall, the relationship between the operations per token in our model (well, OPs, not necessarily floating point) and the operations per second the hardware must deliver to hit 7 million tokens per second was:

(FLOPs_{forward} + FLOPs_{backward}) = \frac{FloatcomputationsPerSecond}{7 million}

We had already expressed the FLOPs for the backward pass as 2.5 times the FLOPs for forward pass. It’s not a very accurate conversion but it will serve our purposes well. With binary neural networks, we know that 64 parameters will contribute only 3 OPs in the forward pass. So we can rewrite the above equation as:

(\frac{3}{64} \times parameters  + FLOPs_{backward}) = \frac{FloatcomputationsPerSecond}{7 million}

The backward FLOPs remain 2.5 times the original forward FLOPs (i.e., without the XNOR trick), and the original forward FLOPs equal twice the number of parameters in the model:

(\frac{3}{64} \times parameters  + 2.5 \times 2 \times parameters) = \frac{FloatcomputationsPerSecond}{7 million}

5.046875 \times parameters = \frac{FloatcomputationsPerSecond}{7 million}

parameters = \frac{FloatcomputationsPerSecond}{7 million \times 5.046875}

Substituting our (still optimistic) 350 TFLOPS estimate for the H100 for FloatcomputationsPerSecond, the number of binary parameters our model can have is around 9.9 million. This is only around 40% larger than what the model could have been if we had trained it in fp16. That’s really not much, considering how much precision we are giving up.
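The binary-forward budget can be checked the same way. This sketch plugs in the assumptions above (3 ops per 64 parameters in the forward pass, an unchanged full-precision backward pass of 2.5 × 2 × parameters) and compares it against the fp16 budget at the same 350 TFLOPS; the function name is illustrative only.

```python
def max_params_binary_forward(peak_ops_per_s: float, tokens_per_s: float = 7e6) -> float:
    """Parameter budget when only the forward pass uses XNOR + popcount.

    Forward: 3 ops per 64 parameters (per token).
    Backward: unchanged, 2.5 * (2 * parameters), as in a full-precision network.
    """
    ops_per_param = 3 / 64 + 2.5 * 2  # = 5.046875
    return peak_ops_per_s / (tokens_per_s * ops_per_param)


binary = max_params_binary_forward(350e12)
fp16 = 350e12 / (7e6 * 7)  # 7 FLOPs per parameter per token, from the fp16 derivation
print(f"binary: {binary / 1e6:.1f}M, fp16: {fp16 / 1e6:.1f}M, ratio: {binary / fp16:.2f}x")
# binary: 9.9M, fp16: 7.1M, ratio: 1.39x
```

The backward term (5 ops per parameter) swamps the forward term (about 0.05), which is the whole point of this section.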

But what if the backward pass could also use XNOR + popcount? Then, keeping the same 2.5x backward-to-forward ratio, our binary network could have around 300 million parameters and still train at 7 million tokens per second. 300 million parameters is useful territory. But no one has figured out a way to do that yet.
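The 300 million figure follows from the same arithmetic if the backward pass is assumed to cost 2.5x a binary forward pass, i.e. 3.5 × 3/64 operations per parameter per token. This is a hypothetical scenario, sketched here only to show where the number comes from.

```python
def max_params_fully_binary(peak_ops_per_s: float, tokens_per_s: float = 7e6) -> float:
    """Hypothetical: the backward pass also uses XNOR + popcount.

    Assumes the backward pass still costs 2.5x the (now binary) forward pass.
    """
    forward_ops_per_param = 3 / 64
    ops_per_param = forward_ops_per_param * (1 + 2.5)  # = 0.1640625
    return peak_ops_per_s / (tokens_per_s * ops_per_param)


print(f"{max_params_fully_binary(350e12) / 1e6:.0f}M parameters")  # 305M parameters
```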

What now?

There are a few things we can do, even though the elephant in the room is the full precision gradient computations.

First, is xnor + popcount actually faster than a vanilla bf16 matrix multiplication? Even though it uses fewer operations, the CUDA tensor cores can do bf16 matrix multiplications very fast. The 40x speedup we imagined (because we were thinking in terms of the number of operations) might fizzle out to a paltry 20-30% increase, and we’d be giving up a lot of accuracy. I already have some benchmarks showing that this is indeed the case. I’ll publish them here when I get the time.

There should also be a way to make backpropagation more efficient. Or, you know, we can ditch backpropagation altogether. Launay et al. have made direct feedback alignment (DFA) work for transformers. It’s definitely not going to get us the same results as backpropagation, but hey, our weights are binary; we weren’t going to be the next Mistral 7B anyway. Maybe binary neural nets + DFA is the way to build the binformer.