On training binary neural networks

I’ve been thinking about binary neural networks lately, and specifically how to train them without computing full precision gradients.

Broadly, there are two camps in binary neural networks. The most prominent (and effective) way to train binary neural networks is to keep the weights in full precision and binarise them during the forward pass. I’d say BitNet and ReActNet fall under that camp. The problem is of course that you are not taking advantage of the binary weights during the backward pass. Most binary NNs are motivated by the need for fast inference, not training, so full precision backward passes are not generally considered a problem. In fact, keeping latent full-precision weights (i.e. non-binary weights) in memory during training is a tried-and-tested technique to train binary neural networks that don’t suck.

But Shibuya et al. have been doing research on training binary neural networks without keeping a copy of the weights in full precision in memory. They had an ICLR 2024 submission that unfortunately didn’t get accepted, and then a more recent paper titled Binary Stochastic Flip Optimization for Training Binary Neural Networks. They still use real-valued gradients, but they keep the weights binary throughout training. The core idea is very intuitive:

Imagine a neural network with binary weights (0 and 1)

The activation function is a sign() function. This is of course not differentiable, so you use a straight-through estimator (STE) as an approximation. Essentially, you pretend that you were using a hard-tanh as the activation function and compute the gradients as if you were using this hard-tanh instead of sign().
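
To make that concrete, here is a minimal sketch of what a sign STE could look like in PyTorch. This is my own toy version, not code from any of the papers mentioned here:

import torch

class SignSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Pretend the forward pass was a hard-tanh: pass the gradient through
        # unchanged where |x| <= 1, and zero it out everywhere else
        return grad_output * (x.abs() <= 1).float()

binarise = SignSTE.apply  # used in place of torch.sign() inside the model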

The gradient computation proceeds normally – for a weight matrix W_l from layer l, you will have real-valued gradients. Let’s call this gradient matrix G_l

If we were using standard real-valued SGD, the weight would be updated as

W_{l}^{t} =  W_{l}^{t-1} - \eta G_{l}^{t}

Where

\eta is the learning rate, and is usually a small value like 0.001

But you want to keep the weights W_l binary. If you apply real-valued gradient updates to a binary W_l, the updated weights will no longer be binary.

The authors get around this problem by simply binarising -G_l: wherever the gradient is negative (i.e. -G_l is positive), the binarised value is 1, and wherever the gradient is positive, it is 0.

But you can’t just add the binarised gradients to the weight matrix. That would be conceptually equivalent to having a learning rate of 1! That’s too high! And since the gradient matrix is now binary, there’s no concept of a “learning rate”.

Remember, G_l is now a binary matrix, and tells you what the updated weight matrix would look like. Conceptually, if our learning rate (whatever that means – bear with me) was 1, we would just add the gradient matrix to the weight matrix. In our case, instead of adding the gradient matrix to the weights, we would just replace the weight matrix with the gradient matrix. This is the ugly side effect of using binary weights.

Let that sink in. When your gradients are binarised, you can’t “add” the gradient to the weight. The gradient value is going to be either a 1 or a 0. Your weight too is going to be either a 1 or a 0. Your decision is to either match the gradient value or not to match it. For example, if the weight is 0 and the gradient is 1, you have no option but to set the weight to 1 – because that’s what the gradient is telling you to do in order to minimise the loss!

But if we do exactly what the binarised gradient tells us to do, we’ll be flipping our binary weights all the time. In other words, our learning rate is too high! But what does it mean to have a “low” learning rate when the gradient matrix just contains 1s and 0s?

The authors of the Binary Stochastic Flip paper suggest that we use a binary mask and do an element-wise multiplication with the gradient matrix. This way, you can select a subset of the gradient matrix to be applied. In other words, instead of setting all the weights according to the gradient matrix, we update only a fraction (e.g. 10%) of the weights to match the gradient matrix. When the learning rate is “1” (or when it is “maximum”), the mask is a tensor with every element set to 1. To use a lower learning rate, the mask matrix is created in such a way that it only has a few ones:

W_{l}^{t} =  \neg M \odot W_{l}^{t-1} + M \odot G_{l}^{t}

\odot represents element-wise multiplication. Multiplying the existing weights W_{l}^{t-1} with \neg M selects the elements from the original weights that we want to keep; everything else is zeroed out. Adding M \odot G_{l}^{t} to this sets those zeroed-out elements to the corresponding values in the gradient matrix G_{l}^{t}.

The paper even mentions that setting the mask randomly works rather well.

In PyTorch, I think it will look something like this:

import torch
from torch.optim import Optimizer


class HyperMaskOptimizer(Optimizer):
    def __init__(self, params, delta=1e-3):
        """
        Args:
            params: Model parameters (binary weights in {0,1})
            delta: Probability for random mask (δ_t in paper)
        """
        defaults = dict(delta=delta)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            delta = group['delta']
            for param in group['params']:
                if param.grad is None:
                    continue
                grad = param.grad

                current_weights = param.data.clone()  # w_{t-1} in {0,1}
                target_weights = (-grad >= 0).float()  # w*_t in {0,1}, i.e. the binarised -gradient

                # each mask element is 1 with probability delta (the "learning rate")
                mask = torch.bernoulli(torch.full_like(grad, delta))
                mask_not = 1 - mask  # m̄_t (NOT operation)

                # keep the unmasked weights, overwrite the masked ones with the target
                new_weights = (mask_not * current_weights) + (mask * target_weights)

                param.data.copy_(new_weights)
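
For completeness, here is a toy sketch of how the optimizer could be wired up – a made-up two-layer MLP whose weights are initialised to random 0/1 values, not my actual MNIST script:

import torch
import torch.nn as nn
import torch.nn.functional as F

def binarise_params_(model):
    # Re-initialise every parameter to a random value in {0, 1}
    with torch.no_grad():
        for p in model.parameters():
            p.copy_(torch.bernoulli(torch.full_like(p, 0.5)))

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
binarise_params_(model)
optimizer = HyperMaskOptimizer(model.parameters(), delta=1e-3)

x, y = torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))
optimizer.zero_grad()
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()  # flips roughly a delta fraction of the bits towards the binarised targets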

I wasn’t able to reproduce the numbers that the authors reported. Here’s my implementation. My 4-layer perceptron achieved only an unimpressive 33% test accuracy on MNIST after training for 200 epochs. Unfortunately the paper does not mention a reference implementation for me to cross-check with. I’ve reached out to someone on LinkedIn whose name matches that of the first author. If I’m lucky, they’ll open source their implementation.

Meanwhile, if you can spot bugs in my snippet, please comment on the Github gist (or here) 😀

Thinking aloud: Can we speed up model training by using binary weights?

When I was at Amazon’s LLM pre-training team, our pre-training jobs used to run for weeks. Impatient as I am, it was frustrating to wait for a whole week to see if the changes we made (mostly to the data) worked. The magic number was 1 trillion tokens, and even a smallish model (e.g. 7 billion parameters) would take a few days to reach that point, even with the number of GPUs we had access to.

Now imagine you want to pre-train a model and all you have is a single GPU. Let’s say you want to train it on all of SlimPajama – roughly 600 billion tokens – and be done with pre-training in a day. Currently, this is a pipe dream. To get through 600B tokens in a day (86,400 seconds), your training throughput needs to be around 7 million tokens a second. The typical training speeds you see are sub 10k tokens per second per GPU:

Model           Size    Hardware       Training speed
Gemma 3 270M    270M    Apple M3 Pro   4000 toks/second
TinyLlama       1.1B    16x A100       1500 toks/second per GPU
Llama 1 65B     65B     2048x A100     380 toks/second per GPU
MPT 7B          7B      440x A100      2800 toks/second per GPU

And we are talking about 7 million tokens per second. Even if we assume 100% utilisation of the hardware (which is rare), we’ll end up with some pretty small theoretical limits to how large our model can be. Let’s try to work that out.

Target 7 million tokens/sec – how large can our model be?

To process 7 million tokens a second, the model will have to do quite a lot of floating point operations per second:

FloatcomputationsPerSecond = 7 million \times (FLOPs_{forward} + FLOPs_{backward})

But there are theoretical upper limits for the number of floating point operations the GPU can do per second.

GPU         Tera FLOPs per second (fp16)
RTX 5090    209
Apple M4    2.9 – 4.3 (fp32), ~8 TFLOPs for fp16
A100        312
H100        950 (without sparsity)

Plugging this into our equation:

(FLOPs_{forward} + FLOPs_{backward}) = \frac{FloatcomputationsPerSecond}{7 million}

Let’s assume that the FLOPs for the backward pass are 2.5x those of the forward pass. This is the assumption made in the FlashAttention-2 paper and likely holds only for self-attention, but it sounds like a reasonable assumption. Let’s also assume that every parameter in the model results in 2 FLOPs in the forward pass – every parameter in a weight matrix will be involved in one multiplication and one addition as part of a matmul. Let’s ignore other layers (e.g. softmax, layer norm) for the time being:

(FLOPs_{forward} + 2.5 \times FLOPs_{forward}) = \frac{FloatcomputationsPerSecond}{7 million}
3.5 \times FLOPs_{forward} = \frac{FloatcomputationsPerSecond}{7 million}
3.5 \times 2 \times parameters = \frac{FloatcomputationsPerSecond}{7 million}
parameters = \frac{FloatcomputationsPerSecond}{7 million \times 7}

If we substitute FloatComputationsPerSecond with 950 TFLOPS, the theoretical best we can get (on an H100), we’ll have to limit our model to 19.3 million parameters if the training is to proceed at 7 million tokens per second. If we assume a more realistic 40% utilisation of the GPU (350 TFLOPS), we’ll be limited to around 7 million parameters for our model.
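
Here is the same back-of-the-envelope calculation as a few lines of Python, so you can plug in your own GPU numbers:

def max_params(flops_per_second, tokens_per_second=7e6,
               fwd_flops_per_param=2, bwd_to_fwd_ratio=2.5):
    # FLOPs we can spend on each token, divided by FLOPs needed per parameter per token
    flops_per_token = flops_per_second / tokens_per_second
    return flops_per_token / (fwd_flops_per_param * (1 + bwd_to_fwd_ratio))

print(max_params(950e12))  # H100 theoretical peak -> ~19 million parameters
print(max_params(350e12))  # ~40% utilisation      -> ~7 million parameters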

During inference, we can quantise the model to improve inference speed by multiples while losing only a fraction of the accuracy. Why can’t we do that during training? Better yet, why can’t we go all the way and make all our weights binary? We will certainly lose model performance, but it is also not obvious whether training would actually get much faster.

But is computation even the bottleneck?

Even if your model fits in a single GPU, large distributed pre-training runs (i.e data-parallel runs) are often bottlenecked by the communication overhead. Each GPU has a copy of your model and computes its gradients. At the end of the backward pass, you must average the gradients across all GPUs. This all_reduce step can absolutely be a bottleneck.
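
For intuition, the gradient averaging step looks conceptually like this (DDP does it for you and overlaps it with the backward pass, but the network traffic is the same):

import torch.distributed as dist

def average_gradients(model):
    # Assumes torch.distributed has already been initialised (e.g. via init_process_group)
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)  # sum gradients across GPUs
            p.grad /= world_size                           # then average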

If your model is too big to fit in a single GPU, then you must use model-parallelism to shard the model across multiple GPUs. This introduces yet another network bandwidth-heavy operation (all_gather)

You also have to read data from the disk, and in case of data-parallel runs, probably from a network-mounted drive.

So no, computation is usually not the bottleneck.

But, if your model fits in a single GPU and you are not using data-parallelism, then you have zero communication overhead and training might indeed be limited by compute.

Can’t you just use a lower precision during training?

Yes you can. The newer H100 GPUs introduced fp8 tensor cores and Nvidia’s Transformer Engine reports a 1.46x speedup when training Llama3 8B with fp8. It is harder to achieve stable training with lower precision floats or integers – the first successful fp4 pre-training is very recent.

Has anyone tried using binary weights? Not really. There’s BitNet and its variants, but they don’t report training throughput numbers. Moreover, I wouldn’t expect training throughput to improve, since BitNet maintains latent weights in full precision and binarises them on the fly during the forward pass. There’s Training Binary Neural Networks in a Binary Weight Space by Shibuya et al., and that’s the closest work I could find that trains a binary neural network without having to keep latent weights in full precision. However, the authors do not mention how they implemented the binary weights – for all I know the weights were still fp32 but artificially confined to values of ±1. I suspect this is why table 2 in the paper reports analytical memory usage rather than experimental memory usage.

It’s a pity that the paper was rejected from ICLR 2024. All the reviewers gave high “soundness” and “contribution” scores and then proceeded to criticise the experiments section. One reviewer mentioned that “the experimental results in Tab.3 seem extremely bad on medium-sized datasets such as CIFAR-10 and Tiny-ImageNet” – but that’s not the point! The results were proof that the method worked, and that should have been sufficient to justify an acceptance. Also, the “extremely bad” error rate was 2.5% for the binary network vs 1.46% for the full-precision baseline. The circus of academic publishing. But I digress.

Why binary weights might increase computation throughput

Let’s not even talk about memory, apart from just stating that compared to fp16, binary weights will take 16x less memory.

But why would computation be sped up? Imagine taking the dot product of two float vectors of size n. This requires n multiplications and n additions, for a total of 2n floating point operations. Now imagine your vectors were binary, with values confined to -1 or +1. Instead of doing a dot product, you can represent each vector as a packed integer and do an XNOR + bitcount to get the same result as the naive dot product. That’s just 2 operations, compared to 2n operations for a dot product – in theory, a speedup by a factor of n. See the figures below.

Figure 1: Naive dot product resulting in n multiplications and n additions

The intuition is as follows – when all the elements of the vectors are either -1 or 1, the element-wise product of two elements (as computed inside a dot product) is 1 if they are the same (i.e. both are 1 or both are -1). If the elements are not the same, the product is -1. The dot product is just the sum of these element-wise products, so the final result can be seen as “the number of 1 results minus the number of -1 results in the element-wise multiplication”.

This is exactly what an XNOR (i.e. an exclusive NOR) captures – it outputs 1 if both operands are the same (both 1s or both 0s). If the operands are not the same (i.e. 1 and 0, or 0 and 1), the output of XNOR is 0. GPUs (and CPUs) also provide a single operation to count the number of set bits in an integer, called population count – in CUDA, it’s the __popc() function. The XNOR-Net paper popularised this trick. Thus we can rewrite the dot product for binarised vectors a and b as:

a \cdot b = numberOfOnes(xnor(a, b)) - numberOfZeroes(xnor(a, b))
a \cdot b = numberOfOnes(xnor(a, b)) - (length(a) - numberOfOnes(xnor(a, b)))
a \cdot b = 2 \times numberOfOnes(xnor(a, b)) - length(a)

If a and b are binary vectors of length 64, they can each be represented by a single 64-bit integer. Since the lengths of a and b are the same – call it n – and numberOfOnes is given by CUDA’s __popc():

dot(a, b) = 2 x __popc(~(a ^ b)) - n

That is, the dot product of two vectors of length 64 takes 3 operations – a multiplication by 2, the __popc() and the addition with -n. See figure 2 below for an example. The naive dot product would have taken 128 operations – 64 multiplications and 64 additions.

Figure 2: XNOR and population count (i.e bitcount) gives the same result as the dot product, but with just two operations. See the XNOR-Net paper.
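
Here is a quick self-contained check in plain Python (no CUDA – bin(x).count() stands in for __popc()) that the identity holds:

import random

n = 64
a = [random.choice([-1, 1]) for _ in range(n)]
b = [random.choice([-1, 1]) for _ in range(n)]

# Pack each vector into a 64-bit integer: +1 -> bit set, -1 -> bit cleared
pack = lambda v: sum(1 << i for i, x in enumerate(v) if x == 1)
pa, pb = pack(a), pack(b)

mask = (1 << n) - 1                # keep only the lowest 64 bits after the NOT
xnor = ~(pa ^ pb) & mask           # bit is 1 wherever the two vectors agree
matches = bin(xnor).count("1")     # population count

assert 2 * matches - n == sum(x * y for x, y in zip(a, b))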

Ok, so dot products are now theoretically 40x faster. So?

Matrix multiplication is just a series of dot products. A linear layer during the forward pass simply multiplies an activation matrix with a weight matrix – that operation is now 20x faster. We are of course ignoring the fact that instead of floating point weights you just have 1s and -1s and it’s unlikely that a neural network can learn anything with binary weights – we’ll look into that later. Assuming that there’s a magic algorithm that will let us train performant binary neural networks, having pure binarised weights during training will likely give us an improvement in training throughput.

But how much improvement? Not that much, unless gradients are binary too

The rule of thumb in computing FLOPs is that each parameter results in 2 floating point operations. If you think of a neural network as one big matrix multiplication between an activation/data matrix of shape (1, k) and a weights matrix of shape (k, n), this makes sense – the total number of parameters is k x n and the number of operations in a vanilla matrix multiplication will be 2 * 1 * k * n. But with the XNOR and popcount() trick, we’ll incur only 3 operations per 64 parameters!
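
To put numbers on that, for a made-up 4096 x 4096 weight matrix:

k, n = 4096, 4096               # made-up layer dimensions
dense_ops = 2 * k * n           # one multiply + one add per weight
binary_ops = 3 * k * n / 64     # xnor + popcount: ~3 ops per 64 weights
print(dense_ops / binary_ops)   # ~42.7x fewer forward-pass operations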


We had also assumed that the backward pass costs 2.5x the FLOPs of the forward pass. But using the XNOR + popcount trick we discussed above, we can dramatically reduce the number of operations in the forward pass. Notice how I did not say FLOPs – XNOR and popcount are not floating point operations. Let’s not dwell on that for now. I also explicitly said the forward pass: the backward pass still needs full precision gradients, as mentioned in Shibuya et al., so no XNOR + popcount there. This also means that the backward pass costs just as much as in a non-binarised neural network. If you recall, the relationship between the per-token operations of the model (well, OPs – not necessarily floats) and the computations per second required to hit 7 million tokens per second was:

(FLOPs_{forward} + FLOPs_{backward}) = \frac{FloatcomputationsPerSecond}{7 million}

We had already expressed the FLOPs for the backward pass as 2.5 times the FLOPs for forward pass. It’s not a very accurate conversion but it will serve our purposes well. With binary neural networks, we know that 64 parameters will contribute only 3 OPs in the forward pass. So we can rewrite the above equation as:

(\frac{3}{64} \times parameters  + FLOPs_{backward}) = \frac{FloatcomputationsPerSecond}{7 million}

And the backward flops will be 2.5 times the original (i.e without the XNOR trick) forward FLOPs. But the original forward FLOPs is equal to twice the number of parameters in the model.

(\frac{3}{64} \times parameters  + 2.5 \times 2 \times parameters) = \frac{FloatcomputationsPerSecond}{7 million}

5.0468 \times parameters = \frac{FloatcomputationsPerSecond}{7 million}

parameters = \frac{FloatcomputationsPerSecond}{7 million \times 5.0468}

Replacing FloatComputationsPerSecond with our very optimistic 350 TFLOPS estimate for the H100, the number of binary parameters our model can have is around 9.9 million. This is only around 40% larger than the model we could have trained in fp16. That is really not much, considering how much precision we are giving up.

But what if the backward pass could also use XNOR + popcount? Then our binary network could have around 300 million parameters and still be trained at 7 million tokens per second. 300 million parameters is useful territory. But no one has figured out a way to do that yet.
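
Continuing the back-of-the-envelope script from earlier, with the same 7 million tokens/second target and the 350 TFLOPS (well, tera-OPS) budget:

def max_binary_params(ops_per_second, tokens_per_second=7e6, binary_backward=False):
    ops_per_token = ops_per_second / tokens_per_second
    fwd_ops_per_param = 3 / 64                          # xnor + popcount forward pass
    if binary_backward:
        bwd_ops_per_param = 2.5 * fwd_ops_per_param     # hypothetical binary backward pass
    else:
        bwd_ops_per_param = 2.5 * 2                     # full-precision backward pass
    return ops_per_token / (fwd_ops_per_param + bwd_ops_per_param)

print(max_binary_params(350e12))                        # ~9.9 million parameters
print(max_binary_params(350e12, binary_backward=True))  # ~300 million parameters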

What now?

There are a few things we can do, even though the elephant in the room is the full precision gradient computations.

First, is xnor + popcount actually faster than a vanilla bf16 matrix multiplication? Even though it uses fewer operations, the CUDA tensor cores can do bf16 matrix multiplications very fast. The 40x speedup we imagined (because we were thinking in terms of the number of operations) might fizzle out to a paltry 20-30% increase, and we’d be giving up a lot of accuracy. I already have some benchmarks showing that this is indeed the case. I’ll publish them here when I get the time.

There should also be a way to make the backpropagation more efficient. Or, you know, we can ditch backpropagation altogether. Launay et al. have made direct feedback alignment work for transformers. It’s definitely not going to get us the same results as backpropagation – but hey, our weights are binary, we weren’t going to be the next Mistral 7B anyway. Maybe binary neural nets + DFA is the way to build the binformer.

matmul() using PyTorch’s MPS backend is faster than Apple’s MLX

Disclaimer: I do not know why PyTorch + MPS is faster (yet)

I recently came across Apple’s MLX framework. It is a NumPy-like interface that lets you run computations on Apple’s own GPUs if you have a Mac with an M-series chip. It uses Apple’s Metal API under the hood. I wanted to see how much faster it would be to do matrix multiplications on Apple silicon using MLX compared to NumPy + CPU, and more importantly, PyTorch. PyTorch has been supporting device="mps" for a while now.

Results

MLX + GPU (M3 Pro) is faster than Numpy + CPU. No surprises there, so I’ll omit CPU from the plots below. But MLX + GPU was surprisingly slower than PyTorch 2.6 + GPU on my MacBook with the M3 Pro chip.

I am going to ignore the more interesting phenomenon for now – that the matrix init times are slower for PyTorch when the matrix dimensions start to become huge. For the subplot titled “Matrix multiplication times”, this is not what I expected – I had expected PyTorch + MPS to be on par with or slightly slower than MLX. I still do not know why MLX is slower, but these are the hypotheses I have considered:

  • Very likely there’s something wrong with my benchmarking approach. This is the most likely hypothesis. Here’s the full notebook that I used to benchmark MLX.
  • Dtypes – Both mlx and torch arrays were explicitly set to float32, so it can’t be this.
  • GPU utilisation – I used asitop to monitor the GPU usage during the timeit runs. Both MLX and PyTorch used 100% GPU at 1380Mhz. So at least we are sure that both frameworks were using the GPU.
  • Compiled vs non-compiled MLX: Doing mx.compile(mx.matmul(a,b)) did not make a material difference in the runtimes.

Can we reproduce the same trend on a single 128×128 matrix?

Yes. Here’s a very simple test:

import timeit
import mlx.core as mx

mlx_a = mx.random.uniform(0, 1, (128, 128), dtype=mx.float32)
mlx_b = mx.random.uniform(0, 1, (128, 128), dtype=mx.float32)

def mlx_single_matmul():
    # mx.eval forces MLX's lazy computation graph to actually execute
    return mx.eval(mx.matmul(mlx_a, mlx_b))

timeit.timeit(mlx_single_matmul, number=10000)

That results in 1.15 seconds. I double checked that mx.default_device() prints Device(gpu, 0) – so MLX is indeed using the GPU. For PyTorch:

import torch

torch_a = torch.rand((128, 128), device="mps", dtype=torch.float32)
torch_b = torch.rand((128, 128), device="mps", dtype=torch.float32)

def torch_single_matmul():
    # Note: MPS kernels are dispatched asynchronously; adding torch.mps.synchronize()
    # here would make the timing directly comparable to mx.eval() above
    return torch.matmul(torch_a, torch_b)

timeit.timeit(torch_single_matmul, number=10000)

0.21 seconds.

I changed the matrices in the above snippet to be of shape 128×30 and 30×128. Similar results. So this is not some optimisation that PyTorch has just for square matrices.

If you have ideas on why MLX is slower, or if you spot a problem with my script that would explain the discrepancy, let me know!

Back with some brains!

It’s been quite some time since I posted, I admit. The good news is, I’m back. More good news: I found something cool!

Gone are the days when the phrase artificial intelligence immediately brought neural networks to mind. Actually, it seems the days of artificial intelligence themselves are gone. Cognitive computing is going to be the norm, or at least I hope so. Simply put, cognitive computing is mimicking the way we humans think and making computers do the same.

“But I thought that was what artificial intelligence was all about”

Yes and no.

Though cognitive computing might actually qualify as a way of implementing intelligent machines, conventional artificial intelligence was problem specific. There was usually a separate “learning phase” where we had to feed tons of data to the supposedly intelligent machine. Cognitive computing is a significant improvement on this, considering that these machines can learn online. That is, there is no separate learning phase. The machine learns as it works, just like we humans do.

Numenta is a company that deals with the above-said “stuff”. They have built a platform (or is it a software?) called nupic (Numenta Platform for Intelligent Computing), which implements something known as hierarchical temporal memory (HTM). It uses a cortical learning algorithm (CLA) to mimic the human brain. Basically, nupic functions more or less the way we do.

Enough boring theory.

Visit numenta and nupic here : http://numenta.org/

The nupic is open source and you can get the source on github :

https://github.com/numenta/nupic

A warning though. nupic has a steep learning curve, so get your hands dirty only if you have some time and patience. I couldn’t successfully run the tests on the build (to check if nupic installed correctly) and am still asking around for solutions (the mailing list is great).

And they have some awesome videos of previous hackathons and example programs that clearly demonstrate the power of nupic. Here’s a link.