nitorch_fastmath.batched
Overview
I found that some torch functions (e.g., inverse() or det()) are
inefficient when applied to large batches of small matrices,
especially on the GPU (much less so on the CPU). I reimplemented
them in TorchScript for 2x2 and 3x3 matrices, and the
reimplementations are much faster.
I used to have a batchmatmul too, but its speed was not always better
than torch.matmul() (it depended a lot on the striding layout),
so I removed it.
batchdet
Efficient batched determinant for large batches of small matrices
Note
A batched implementation is used for 1x1, 2x2 and 3x3 matrices.
Other sizes fall back to torch.det.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| a | (..., n, n) tensor | Input matrix. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| d | (...) tensor | Determinant. |
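To illustrate what a batched closed-form determinant can look like, here is a minimal sketch. It is not the library's implementation: the function name `det_small_sketch` is hypothetical, and the sketch only assumes that small sizes are handled with explicit element-wise formulas while other sizes fall back to `torch.linalg.det`.

```python
import torch


def det_small_sketch(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical closed-form determinant for (..., n, n) tensors."""
    n = a.shape[-1]
    if a.shape[-2] != n:
        raise ValueError("expected square matrices of shape (..., n, n)")
    if n == 1:
        return a[..., 0, 0]
    if n == 2:
        # ad - bc, evaluated element-wise over all batch dimensions
        return a[..., 0, 0] * a[..., 1, 1] - a[..., 0, 1] * a[..., 1, 0]
    if n == 3:
        # Cofactor expansion along the first row
        return (
            a[..., 0, 0] * (a[..., 1, 1] * a[..., 2, 2] - a[..., 1, 2] * a[..., 2, 1])
            - a[..., 0, 1] * (a[..., 1, 0] * a[..., 2, 2] - a[..., 1, 2] * a[..., 2, 0])
            + a[..., 0, 2] * (a[..., 1, 0] * a[..., 2, 1] - a[..., 1, 1] * a[..., 2, 0])
        )
    # Larger matrices: defer to the generic routine
    return torch.linalg.det(a)


torch.manual_seed(0)
batch = torch.randn(1000, 3, 3, dtype=torch.float64)
d = det_small_sketch(batch)           # one determinant per matrix in the batch
reference = torch.linalg.det(batch)   # generic routine, for comparison
```

Because every operation is element-wise over the leading dimensions, the whole batch is processed with a handful of tensor operations rather than one factorization per matrix.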
batchinv
Efficient batched inversion for large batches of small matrices
Note
A batched implementation is used for 1x1, 2x2 and 3x3 matrices.
Other sizes fall back to torch.linalg.inv.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| a | (..., n, n) tensor | Input matrix. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| a | (..., n, n) tensor | Inverse matrix. |
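As a rough illustration of the closed-form approach for the 2x2 case, the sketch below inverts a batch through the adjugate divided by the determinant. The name `inv_2x2_sketch` is hypothetical and the code is an assumption about how such a routine could be written, not the library's actual code.

```python
import torch


def inv_2x2_sketch(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical closed-form inverse for (..., 2, 2) tensors."""
    a00, a01 = a[..., 0, 0], a[..., 0, 1]
    a10, a11 = a[..., 1, 0], a[..., 1, 1]
    det = a00 * a11 - a01 * a10
    # Adjugate of [[a00, a01], [a10, a11]] is [[a11, -a01], [-a10, a00]]
    row0 = torch.stack([a11, -a01], dim=-1)
    row1 = torch.stack([-a10, a00], dim=-1)
    adj = torch.stack([row0, row1], dim=-2)
    # No singularity check: a zero determinant yields inf/nan entries
    return adj / det[..., None, None]


torch.manual_seed(0)
eye = torch.eye(2, dtype=torch.float64)
# Diagonally dominant matrices, so every determinant is safely nonzero
batch = 4 * eye + 0.5 * torch.rand(4, 5, 2, 2, dtype=torch.float64)
inverse = inv_2x2_sketch(batch)
```

The result keeps the input's batch shape, so `inverse @ batch` should be close to the identity for every matrix in the batch.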
batchmatvec
Efficient batched matrix-vector product for large batches of small matrices
Note
A batched implementation is used for 1x1, 2x2 and 3x3 matrices.
Other sizes fall back to matvec.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mat | (..., m, n) tensor | Input matrix. | required |
| vec | (..., n) tensor | Input vector. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| matvec | (..., m) tensor | Matrix-vector product. |
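The shape contract above can be illustrated with a short sketch that unrolls the product over the columns of the matrix. The name `matvec_unrolled_sketch` is hypothetical, and unrolling is only one plausible way to write such a routine; the library's actual implementation may differ.

```python
import torch


def matvec_unrolled_sketch(mat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Hypothetical unrolled product: (..., m, n) times (..., n) gives (..., m)."""
    n = mat.shape[-1]
    if vec.shape[-1] != n:
        raise ValueError("last dimension of vec must match columns of mat")
    # Accumulate column j of mat scaled by component j of vec
    out = mat[..., :, 0] * vec[..., 0:1]
    for j in range(1, n):
        out = out + mat[..., :, j] * vec[..., j:j + 1]
    return out


torch.manual_seed(0)
mat = torch.randn(1000, 3, 3, dtype=torch.float64)
vec = torch.randn(1000, 3, dtype=torch.float64)
product = matvec_unrolled_sketch(mat, vec)
# Generic equivalent: promote vec to a column, multiply, drop the extra axis
reference = torch.matmul(mat, vec.unsqueeze(-1)).squeeze(-1)
```

For small `n` the loop runs only a few times, and each iteration is a single element-wise operation across the whole batch.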