King Yuen Chiu, Anthony

Notes on PyTorch Matrix-Matrix Multiplication


To perform matrix-matrix multiplication between two tensors, we can use the @ operator or torch.matmul in PyTorch. For example:

Matrix-Vector & Matrix-Matrix Multiplication

matrix_vector_multiplication
import torch

a = torch.randn(3, 4)
b = torch.randn(4)
print((a @ b).shape) # torch.Size([3])
print((torch.matmul(a, b)).shape) # torch.Size([3])
matrix_matrix_multiplication
a = torch.randn(3, 4)
b = torch.randn(4, 5)
print((a @ b).shape) # torch.Size([3, 5])
print((torch.matmul(a, b)).shape) # torch.Size([3, 5])

This also aligns with the mathematics. If we have two matrices $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$, then the matrix-matrix product is $C = AB \in \mathbb{R}^{m \times p}$.
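As a quick sanity check, we can verify that the @ operator agrees with the textbook definition $C_{ij} = \sum_k A_{ik} B_{kj}$ by computing the product with explicit loops (a minimal sketch, not how you'd ever compute it in practice):

```python
import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)

# Compute C[i, j] = sum_k A[i, k] * B[k, j] by hand.
c_manual = torch.zeros(3, 5)
for i in range(3):
    for j in range(5):
        for k in range(4):
            c_manual[i, j] += a[i, k] * b[k, j]

# The looped result matches the operator up to floating-point tolerance.
assert torch.allclose(a @ b, c_manual, atol=1e-6)
```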

However, when the participating tensors have more than two dimensions, it becomes ambiguous how the multiplication should be performed.

Batched Matrix-Matrix Multiplication

Let's say we now work with 3D tensors $A \in \mathbb{R}^{b \times m \times n}$ and $B \in \mathbb{R}^{b \times n \times p}$. This is the batched version of the previous example. The batch dimension $b$ doesn't participate in the multiplication. The result is a 3D tensor $C \in \mathbb{R}^{b \times m \times p}$.

batched_matrix_matrix_multiplication
import torch

a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
print((a @ b).shape) # torch.Size([10, 3, 5])
print((torch.bmm(a, b)).shape) # torch.Size([10, 3, 5])
print((torch.matmul(a, b)).shape) # torch.Size([10, 3, 5])
assert (a @ b == torch.bmm(a, b)).all() # True
assert (a @ b == torch.matmul(a, b)).all() # True

torch.bmm is specifically for batched matrix-matrix multiplication and expects both inputs to be 3D. @ and torch.matmul are more flexible: they accept tensors of arbitrary dimensions, but that flexibility also makes their behavior harder to predict.
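A small sketch of that difference: torch.bmm rejects anything that isn't a pair of 3D tensors, while torch.matmul broadcasts a 2D matrix across the batch dimension.

```python
import torch

a = torch.randn(10, 3, 4)
b = torch.randn(4, 5)  # 2D: no batch dimension

# torch.bmm requires two 3D tensors, so a 3D x 2D pairing fails.
try:
    torch.bmm(a, b)
except RuntimeError as e:
    print("bmm failed:", e)

# torch.matmul broadcasts b across the batch dimension instead.
print(torch.matmul(a, b).shape)  # torch.Size([10, 3, 5])
```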

Higher Dimensional Matrix-Matrix Multiplication

When the participating tensors have more than two dimensions, only the last two dimensions participate in the multiplication; the leading dimensions are treated as batch dimensions.

higher_dimensional_matrix_matrix_multiplication
import torch

a = torch.randn(64, 10, 3, 4)
b = torch.randn(64, 10, 4, 5)
print((a @ b).shape) # torch.Size([64, 10, 3, 5])
print((torch.matmul(a, b)).shape) # torch.Size([64, 10, 3, 5])
assert (a @ b == torch.matmul(a, b)).all() # True
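The leading dimensions don't even have to match exactly: torch.matmul broadcasts them the same way elementwise operations do. A short sketch of that behavior:

```python
import torch

a = torch.randn(64, 10, 3, 4)

# A 2D matrix is broadcast over both leading dimensions.
b = torch.randn(4, 5)
print(torch.matmul(a, b).shape)  # torch.Size([64, 10, 3, 5])

# A 3D tensor with batch shape (10,) is broadcast over the first
# dimension only: (64, 10) vs (10,) -> (64, 10).
c = torch.randn(10, 4, 5)
print(torch.matmul(a, c).shape)  # torch.Size([64, 10, 3, 5])
```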