matmul() using PyTorch’s MPS backend is faster than Apple’s MLX

Disclaimer: I do not know why PyTorch + MPS is faster (yet)

I recently came across Apple’s MLX framework. It is a NumPy-like interface that lets you run computations on Apple’s own GPUs if you have a Mac with an M-series chip. It uses Apple’s Metal API under the hood. I wanted to see how much faster it would be to do matrix multiplications on Apple silicon using MLX compared to NumPy + CPU and, more importantly, PyTorch, which has supported device="mps" for a while now.
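As a quick sanity check before benchmarking (this snippet is illustrative and not part of the benchmark notebook), both frameworks can report whether they see the Apple GPU:

import mlx.core as mx
import torch

# MLX defaults to the Apple GPU on an M-series Mac
print(mx.default_device())                # Device(gpu, 0)

# PyTorch exposes the Apple GPU through the MPS backend
print(torch.backends.mps.is_available())  # True if MPS is usable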

Results

MLX + GPU (M3 Pro) is faster than NumPy + CPU. No surprises there, so I’ll omit the CPU from the plots below. But MLX + GPU was surprisingly slower than PyTorch 2.6 + GPU on my MacBook with the M3 Pro chip.

I am going to ignore the more interesting phenomenon for now – that matrix init times are slower for PyTorch once the matrix dimensions become huge. The subplot titled “Matrix multiplication times” is not what I expected – I had expected PyTorch + MPS to be on par with or slightly slower than MLX. I still do not know why MLX is slower, but here are the hypotheses I considered:

  • Benchmarking approach – most likely, there’s something wrong with how I measured. Here’s the full notebook that I used to benchmark MLX.
  • Dtypes – Both mlx and torch arrays were explicitly set to float32, so it can’t be this.
  • GPU utilisation – I used asitop to monitor GPU usage during the timeit runs. Both MLX and PyTorch used 100% of the GPU at 1380 MHz, so at least we can be sure that both frameworks were using the GPU.
  • Compiled vs non-compiled MLX – compiling the matmul with mx.compile did not make a material difference in the runtimes (see the sketch after this list).
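For that last hypothesis, here is a minimal sketch of how the compiled variant can be timed (it assumes the same 128×128 setup used later in this post; it is not a verbatim excerpt from my notebook):

import timeit
import mlx.core as mx

a = mx.random.uniform(0, 1, (128, 128), dtype=mx.float32)
b = mx.random.uniform(0, 1, (128, 128), dtype=mx.float32)
mx.eval(a, b)  # materialise the inputs before timing

# mx.compile wraps a function and returns a compiled version of it
compiled_matmul = mx.compile(mx.matmul)

def compiled_single_matmul():
    return mx.eval(compiled_matmul(a, b))

timeit.timeit(compiled_single_matmul, number=10000)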

Can we reproduce the same trend on a single 128×128 matrix?

Yes. Here’s a very simple test:

import timeit
import mlx.core as mx

mlx_a = mx.random.uniform(0, 1, (128, 128), dtype=mx.float32)
mlx_b = mx.random.uniform(0, 1, (128, 128), dtype=mx.float32)

def mlx_single_matmul():
    # mx.eval forces MLX's lazy graph to actually compute the product
    return mx.eval(mx.matmul(mlx_a, mlx_b))

timeit.timeit(mlx_single_matmul, number=10000)

That results in 1.15 seconds. I double-checked that mx.default_device() prints Device(gpu, 0) – so MLX is indeed using the GPU. For PyTorch:

import torch

torch_a = torch.rand((128, 128), device="mps", dtype=torch.float32)
torch_b = torch.rand((128, 128), device="mps", dtype=torch.float32)

def torch_single_matmul():
    # run the matmul on the MPS device
    return torch.matmul(torch_a, torch_b)

timeit.timeit(torch_single_matmul, number=10000)

0.21 seconds.
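In other words, over the 10,000 iterations that works out to roughly 115 µs per call for MLX versus about 21 µs for PyTorch.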

I changed the matrices in the above snippet to be of shape 128×30 and 30×128. Similar results. So this is not some optimisation that PyTorch has just for square matrices.
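For reference, the non-square variant only changes the shapes in the snippets above (the variable names match the earlier code):

# non-square shapes: (128, 30) @ (30, 128)
mlx_a = mx.random.uniform(0, 1, (128, 30), dtype=mx.float32)
mlx_b = mx.random.uniform(0, 1, (30, 128), dtype=mx.float32)

torch_a = torch.rand((128, 30), device="mps", dtype=torch.float32)
torch_b = torch.rand((30, 128), device="mps", dtype=torch.float32)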

If you have ideas on why MLX is slower, or if you spot a problem with my script that would explain the discrepancy, let me know!
