Disclaimer: I do not know why PyTorch + MPS is faster (yet)
I recently came across Apple’s MLX framework. It is a NumPy-like interface that lets you run computations on Apple’s own GPUs if you have a Mac with an M-series chip, and it uses Apple’s Metal API under the hood. I wanted to see how much faster matrix multiplications would be on Apple silicon using MLX compared to NumPy + CPU, and more importantly, compared to PyTorch, which has supported device="mps" for a while now.
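For context, here is a minimal sketch of what the two APIs look like side by side; the shapes and values are arbitrary and just for illustration.

```python
import mlx.core as mx
import torch

# MLX: NumPy-like API with lazy evaluation; arrays live on the Apple GPU by default
x = mx.random.normal((4, 4))
y = x @ x          # builds a lazy computation graph
mx.eval(y)         # forces the matmul to actually run on the GPU

# PyTorch: the same matmul on the Metal backend via the "mps" device
a = torch.randn(4, 4, device="mps")
b = a @ a          # dispatched to the GPU
```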
Results
MLX + GPU (M3 Pro) is faster than NumPy + CPU. No surprises there, so I’ll omit the CPU from the plots below. But MLX + GPU was surprisingly slower than PyTorch 2.6 + GPU on my MacBook with the M3 Pro chip.

I am going to ignore the more interesting phenomenon for now – that the matrix init times are slower for PyTorch when the matrix dimensions start to become huge. The subplot titled “Matrix multiplication times” is not what I expected – I had expected PyTorch + MPS to be on par with or slightly slower than MLX. I still do not know why MLX is slower, but here are the hypotheses I have considered:
- Something is wrong with my benchmarking approach. This is the most likely hypothesis. Here’s the full notebook that I used to benchmark MLX.
- Dtypes – Both mlx and torch arrays were explicitly set to float32, so it can’t be this.
- GPU utilisation – I used asitop to monitor the GPU usage during the timeit runs. Both MLX and PyTorch used 100% GPU at 1380 MHz, so at least we are sure that both frameworks were using the GPU.
- Compiled vs non-compiled MLX – doing mx.compile(mx.matmul(a, b)) did not make a material difference in the runtimes. (Rough sanity checks for these hypotheses are sketched below.)
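To make the dtype, device, and compilation hypotheses concrete, here is a rough sketch of how they can be checked. This is illustrative code, not the exact snippets from my notebook, and the variable names are made up.

```python
import mlx.core as mx
import torch

a_mx = mx.random.uniform(0, 1, (1024, 1024), dtype=mx.float32)
b_mx = mx.random.uniform(0, 1, (1024, 1024), dtype=mx.float32)
a_t = torch.rand((1024, 1024), device="mps", dtype=torch.float32)

# Dtypes: both should report float32
print(a_mx.dtype, a_t.dtype)

# Devices: MLX should default to the GPU, and the torch tensor should be on mps
print(mx.default_device(), a_t.device)

# Compiled MLX: mx.compile takes a function, so wrap the matmul and evaluate it
compiled_matmul = mx.compile(lambda x, y: mx.matmul(x, y))
mx.eval(compiled_matmul(a_mx, b_mx))
```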
Can we reproduce the same trend with a single 128×128 matrix multiplication?
Yes. Here’s a very simple test:
```python
import timeit
import mlx.core as mx

mlx_a = mx.random.uniform(0, 1, (128, 128), dtype=mx.float32)
mlx_b = mx.random.uniform(0, 1, (128, 128), dtype=mx.float32)

def mlx_single_matmul():
    # mx.eval forces the lazy matmul to actually execute
    return mx.eval(mx.matmul(mlx_a, mlx_b))

timeit.timeit(mlx_single_matmul, number=10000)
```
That results in 1.15 seconds. I double-checked that mx.default_device() prints Device(gpu, 0) – so MLX is indeed using the GPU. For PyTorch:
```python
import timeit
import torch

torch_a = torch.rand((128, 128), device="mps", dtype=torch.float32)
torch_b = torch.rand((128, 128), device="mps", dtype=torch.float32)

def torch_single_matmul():
    return torch.matmul(torch_a, torch_b)

timeit.timeit(torch_single_matmul, number=10000)
```
0.21 seconds.
I changed the matrices in the snippets above to shapes 128×30 and 30×128 and got similar results, so this is not some optimisation that PyTorch has only for square matrices.
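For reference, the rectangular variant only changes the shapes in the snippets above; this is a sketch of that change rather than a copy of the code I actually ran.

```python
import timeit
import mlx.core as mx
import torch

mlx_a = mx.random.uniform(0, 1, (128, 30), dtype=mx.float32)
mlx_b = mx.random.uniform(0, 1, (30, 128), dtype=mx.float32)
torch_a = torch.rand((128, 30), device="mps", dtype=torch.float32)
torch_b = torch.rand((30, 128), device="mps", dtype=torch.float32)

# Same 10,000-iteration timing as before, now with non-square matrices
print(timeit.timeit(lambda: mx.eval(mx.matmul(mlx_a, mlx_b)), number=10000))
print(timeit.timeit(lambda: torch.matmul(torch_a, torch_b), number=10000))
```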
If you have ideas on why MLX is slower, or if you spot a problem with my script that would explain the discrepancy, let me know!