
nitorch_fastmath.batched

Overview

I found that some PyTorch functions (e.g., torch.inverse() and torch.det()) were not very efficient when applied to large batches of small matrices, especially on the GPU (the effect is smaller on the CPU). I reimplemented them in TorchScript for 2x2 and 3x3 matrices, and the reimplementations are much faster.
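To check the speed claim on your own hardware, you could time a TorchScript-compiled closed-form 2x2 determinant against torch.linalg.det. This is a minimal sketch: the names det2 and compare, the batch size, and the repeat count are illustrative assumptions, not part of nitorch_fastmath, and no timings are claimed here.

```python
import torch
from torch.utils import benchmark


@torch.jit.script
def det2(a: torch.Tensor) -> torch.Tensor:
    # Closed-form 2x2 determinant; a[..., i, j] broadcasts over all batch dims.
    return a[..., 0, 0] * a[..., 1, 1] - a[..., 0, 1] * a[..., 1, 0]


def compare(batch: int = 100_000, device: str = "cpu", number: int = 20) -> dict:
    """Return the median time (seconds) per call for each implementation."""
    a = torch.randn(batch, 2, 2, device=device, dtype=torch.float64)
    # Sanity check: both routes agree up to round-off.
    assert torch.allclose(det2(a), torch.linalg.det(a))
    stmts = {"torch.linalg.det": "torch.linalg.det(a)", "det2": "det2(a)"}
    results = {}
    for label, stmt in stmts.items():
        timer = benchmark.Timer(
            stmt=stmt, globals={"torch": torch, "det2": det2, "a": a}
        )
        results[label] = timer.timeit(number).median
    return results


if __name__ == "__main__":
    for label, seconds in compare().items():
        print(f"{label}: {seconds * 1e6:.1f} us per call")
```

Pass device="cuda" to run the comparison on a GPU; torch.utils.benchmark.Timer synchronizes CUDA when timing, so the numbers reflect the work done rather than just kernel launches.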

I used to provide a batchmatmul function too, but it was not consistently faster than torch.matmul() (its speed depended heavily on the stride layout of the inputs), so I removed it.


batchdet

batchdet(a)

Efficient batched determinant for large batches of small matrices

Note

A batched implementation is used for 1x1, 2x2 and 3x3 matrices. Other sizes fall back to torch.det.

Parameters:

Name  Type                Description    Default
a     (..., n, n) tensor  Input matrix.  required

Returns:

Name  Type          Description
d     (...) tensor  Determinant.
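The closed-form formulas behind a batched determinant are short enough to write out directly. The sketch below is a hypothetical re-implementation (batchdet_sketch is our name, not the library's source), assuming a recent PyTorch. It reads entries with a[..., i, j], so every expression broadcasts over arbitrary leading batch dimensions, and it falls back to torch.det for n > 3, as the note describes.

```python
import torch


def batchdet_sketch(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical batched determinant of (..., n, n) matrices (not the library's code)."""
    if a.dim() < 2 or a.shape[-2] != a.shape[-1]:
        raise ValueError("expected a tensor of shape (..., n, n)")
    n = a.shape[-1]

    def e(i: int, j: int) -> torch.Tensor:
        # Entry (i, j) of every matrix in the batch, shape (...).
        return a[..., i, j]

    if n == 1:
        return e(0, 0)
    if n == 2:
        return e(0, 0) * e(1, 1) - e(0, 1) * e(1, 0)
    if n == 3:
        # Cofactor expansion along the first row.
        return (
            e(0, 0) * (e(1, 1) * e(2, 2) - e(1, 2) * e(2, 1))
            - e(0, 1) * (e(1, 0) * e(2, 2) - e(1, 2) * e(2, 0))
            + e(0, 2) * (e(1, 0) * e(2, 1) - e(1, 1) * e(2, 0))
        )
    # Larger matrices: defer to PyTorch's generic routine.
    return torch.det(a)
```

Each branch is a handful of element-wise multiplies and subtractions over the whole batch, which is why this pattern suits GPUs: no per-matrix factorization is launched.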

batchinv

batchinv(a)

Efficient batched inversion for large batches of small matrices

Note

A batched implementation is used for 1x1, 2x2 and 3x3 matrices. Other sizes fall back to torch.linalg.inv.

Parameters:

Name  Type                Description    Default
a     (..., n, n) tensor  Input matrix.  required

Returns:

Name  Type                Description
a     (..., n, n) tensor  Inverse matrix.
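For the inverse, the same idea applies through the adjugate: inv(A) = adj(A) / det(A), where adj(A) is the transpose of the cofactor matrix. The hypothetical batchinv_sketch below (our name, not nitorch_fastmath's source) builds the cofactors element-wise over the batch dimensions and falls back to torch.linalg.inv for n > 3. It performs no singularity check, so a singular matrix yields inf or nan entries.

```python
import torch


def batchinv_sketch(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical batched inverse of (..., n, n) matrices via the adjugate."""
    n = a.shape[-1]

    def e(i: int, j: int) -> torch.Tensor:
        # Entry (i, j) of every matrix in the batch, shape (...).
        return a[..., i, j]

    if n == 1:
        return a.reciprocal()
    if n == 2:
        det = e(0, 0) * e(1, 1) - e(0, 1) * e(1, 0)
        adj = torch.stack([
            torch.stack([e(1, 1), -e(0, 1)], dim=-1),   # row 0
            torch.stack([-e(1, 0), e(0, 0)], dim=-1),   # row 1
        ], dim=-2)
        return adj / det[..., None, None]
    if n == 3:
        # Cofactors c_ij = (-1)^(i+j) * minor_ij.
        c00 = e(1, 1) * e(2, 2) - e(1, 2) * e(2, 1)
        c01 = e(1, 2) * e(2, 0) - e(1, 0) * e(2, 2)
        c02 = e(1, 0) * e(2, 1) - e(1, 1) * e(2, 0)
        c10 = e(0, 2) * e(2, 1) - e(0, 1) * e(2, 2)
        c11 = e(0, 0) * e(2, 2) - e(0, 2) * e(2, 0)
        c12 = e(0, 1) * e(2, 0) - e(0, 0) * e(2, 1)
        c20 = e(0, 1) * e(1, 2) - e(0, 2) * e(1, 1)
        c21 = e(0, 2) * e(1, 0) - e(0, 0) * e(1, 2)
        c22 = e(0, 0) * e(1, 1) - e(0, 1) * e(1, 0)
        # Expansion along the first row reuses the row-0 cofactors.
        det = e(0, 0) * c00 + e(0, 1) * c01 + e(0, 2) * c02
        # Adjugate = transpose of the cofactor matrix.
        adj = torch.stack([
            torch.stack([c00, c10, c20], dim=-1),
            torch.stack([c01, c11, c21], dim=-1),
            torch.stack([c02, c12, c22], dim=-1),
        ], dim=-2)
        return adj / det[..., None, None]
    # Larger matrices: defer to PyTorch.
    return torch.linalg.inv(a)
```

Reusing the first-row cofactors for the determinant avoids recomputing the 2x2 minors, which is one reason the adjugate route is cheap for 3x3 matrices.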

batchmatvec

batchmatvec(mat, vec)

Efficient batched matrix-vector product for large batches of small matrices

Note

A batched implementation is used for 1x1, 2x2 and 3x3 matrices. Other sizes fall back to matvec.

Parameters:

Name  Type                Description    Default
mat   (..., m, n) tensor  Input matrix.  required
vec   (..., n) tensor     Input vector.  required

Returns:

Name    Type             Description
matvec  (..., m) tensor  Matrix-vector product.
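A batched matrix-vector product for small matrices can likewise be written as explicit sums of products. In the sketch below, batchmatvec_sketch is a hypothetical name, and the fallback uses torch.matmul with an unsqueezed vector as a stand-in for the library's matvec, whose implementation is not shown here. Batch dimensions of mat and vec broadcast against each other.

```python
import torch


def batchmatvec_sketch(mat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Hypothetical batched product mat @ vec for (..., m, n) and (..., n) inputs."""
    m, n = mat.shape[-2], mat.shape[-1]
    if vec.shape[-1] != n:
        raise ValueError("mat and vec have incompatible sizes")
    if m == n and n <= 3:
        rows = []
        for i in range(m):
            # Output entry i = sum_j mat[..., i, j] * vec[..., j], written out term by term.
            acc = mat[..., i, 0] * vec[..., 0]
            for j in range(1, n):
                acc = acc + mat[..., i, j] * vec[..., j]
            rows.append(acc)
        # Batch dims of mat and vec have already broadcast inside the products.
        return torch.stack(rows, dim=-1)
    # Stand-in fallback (assumption): a matrix times a column vector.
    return torch.matmul(mat, vec.unsqueeze(-1)).squeeze(-1)
```

The Python loops run at most nine times per call and iterate over matrix entries, not over the batch, so the batch itself is still processed by vectorized tensor operations.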