nitorch_fastmath.simplex
Overview
This module implements functions that operate on data lying on the simplex,
i.e., discrete probabilities.
Specifically, we implement softmax, log_softmax, logsumexp and logit (plus softmax_lse, which returns both the softmax and its log-sum-exp).
While most of these functions already exist in PyTorch, we define more generic functions that accept an "implicit" class.
This implicit class exists because discrete probabilities are constrained to sum to one, so their space ("the simplex") has one fewer dimension than the number of classes.
Similarly, we can restrict the logit (i.e., unnormalised log-probability) space to K-1 dimensions, where K is the number of classes, by fixing the logit of one class to an arbitrary constant (e.g., zero). This trick makes functions like softmax invertible.
Note that in the 2-class case, it is extremely common to work in this implicit setting by using the sigmoid function over a single logit instead of the softmax function over two logits.
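To make the 2-class remark concrete, here is a minimal plain-Python sketch. It is not nitorch code: `softmax_implicit` and `sigmoid` are hypothetical helpers operating on 1-D lists rather than tensors. The sketch re-inserts the implicit class with a logit of zero and checks that the probability of the single explicit class equals the sigmoid of its logit.

```python
import math


def softmax_implicit(logits, implicit_index=0):
    """Softmax over K-1 explicit logits plus one implicit logit fixed at zero.

    Hypothetical helper: returns all K probabilities, implicit class included.
    """
    full = list(logits)
    full.insert(implicit_index, 0.0)  # the implicit class has logit zero
    m = max(full)  # subtract the max before exp() to avoid overflow
    exps = [math.exp(v - m) for v in full]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


x = 1.3
probs = softmax_implicit([x])  # [P(implicit class), P(explicit class)]
print(probs[1], sigmoid(x))  # equal up to floating-point rounding
```

With two classes, \(e^x / (e^0 + e^x) = 1 / (1 + e^{-x})\), which is exactly the sigmoid.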
All functions below accept an argument `implicit`, which takes either one
boolean value or (for all functions except logsumexp) a tuple of two boolean
values. The first value specifies whether the input tensor has an implicit class,
while the second value specifies whether the output tensor should have an implicit class.
Note that to minimize the memory footprint and numerical errors, most
backward passes are reimplemented explicitly (rather than relying on
automatic differentiation). This is because these functions involve multiple
calls to log and exp, each of which must store its input in order to
backpropagate, whereas backpropagating through the entire softmax function
requires storing only a single tensor (e.g., its output).
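The memory argument above can be checked on the softmax itself: its vector-Jacobian product only needs the output \(s\), since \(\partial L / \partial z_j = s_j \, (g_j - \sum_i g_i s_i)\) with \(g = \partial L / \partial s\). The plain-Python sketch below (hypothetical helpers `softmax_forward` and `softmax_backward` on 1-D lists, not nitorch's implementation) computes this gradient from \(s\) alone and compares it with central finite differences. In PyTorch, such a formula would typically live in the `backward` method of a custom `torch.autograd.Function`, with only \(s\) saved via `ctx.save_for_backward`.

```python
import math


def softmax_forward(z):
    """Safe softmax of a 1-D list."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def softmax_backward(s, grad_out):
    """Vector-Jacobian product of softmax, using only its output `s`."""
    # dL/dz_j = s_j * (g_j - sum_i g_i * s_i)
    dot = sum(g * p for g, p in zip(grad_out, s))
    return [p * (g - dot) for g, p in zip(grad_out, s)]


def loss(z, g):
    """Scalar L(z) = sum_i g_i * softmax(z)_i, so that dL/ds = g."""
    return sum(gi * si for gi, si in zip(g, softmax_forward(z)))


z = [0.5, -1.0, 2.0]
g = [1.0, 0.0, -2.0]
analytic = softmax_backward(softmax_forward(z), g)  # only the output is kept

# Central finite differences as an independent check.
eps = 1e-6
numeric = []
for j in range(len(z)):
    zp, zm = list(z), list(z)
    zp[j] += eps
    zm[j] -= eps
    numeric.append((loss(zp, g) - loss(zm, g)) / (2 * eps))

print(analytic)
print(numeric)
```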
logsumexp
Numerically stabilised log-sum-exp (lse).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Input tensor. | required |
| dim | `int` | The dimension or dimensions to reduce. | -1 |
| keepdim | `bool` | Whether the output tensor has `dim` retained or not. | False |
| implicit | `bool` | Assume that an additional (hidden) channel with value zero exists. | False |
Returns:
| Name | Type | Description |
|---|---|---|
| lse | `tensor` | Log-sum-exp of `input` along `dim`. |
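A minimal plain-Python sketch of a stabilised log-sum-exp over a 1-D list, assuming that `implicit=True` simply appends a hidden entry equal to zero, as the parameter table describes. `logsumexp_sketch` is a hypothetical name, not the nitorch function. Subtracting the maximum first keeps `exp` from overflowing; with `implicit=True` and a single input value \(x\), the result is \(\log(1 + e^x)\), i.e., the softplus.

```python
import math


def logsumexp_sketch(values, implicit=False):
    """Stable log(sum(exp(values))) over a 1-D list (hypothetical helper)."""
    vals = list(values)
    if implicit:
        vals.append(0.0)  # hidden channel with value zero, contributing exp(0) = 1
    m = max(vals)
    if math.isinf(m):
        # All entries are -inf (result -inf), or one is +inf (result +inf);
        # returning early avoids computing inf - inf = nan.
        return m
    return m + math.log(sum(math.exp(v - m) for v in vals))


# A naive math.log(math.exp(1000.0) + ...) would raise OverflowError here.
print(logsumexp_sketch([1000.0, 1000.0]))      # 1000 + log(2)
print(logsumexp_sketch([0.3], implicit=True))  # log(1 + exp(0.3))
```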
softmax
SoftMax (safe).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Input tensor (logits). | required |
| dim | `int` | Dimension along which to take the softmax (the last one by default). | -1 |
| implicit | `bool or (bool, bool)` | The first value relates to the input tensor and the second relates to the output tensor. | False |
| implicit_index | `int` | Index of the implicit class. | 0 |
Returns:
| Name | Type | Description |
|---|---|---|
| output | `tensor` | Softmaxed tensor. |
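The sketch below illustrates one plausible reading of the `implicit` and `implicit_index` parameters on a 1-D list: an implicit input class is re-inserted with logit zero at `implicit_index`, and an implicit output drops that class's probability (it equals one minus the sum of the others). `softmax_sketch` is a hypothetical plain-Python helper, and the convention that a single boolean applies to both input and output is an assumption, not documented behaviour.

```python
import math


def softmax_sketch(logits, implicit=False, implicit_index=0):
    """Safe softmax over a 1-D list with an optional implicit class (hypothetical)."""
    # ASSUMPTION: a single bool applies to both the input and the output.
    if isinstance(implicit, bool):
        implicit = (implicit, implicit)
    implicit_in, implicit_out = implicit
    z = list(logits)
    if implicit_in:
        z.insert(implicit_index, 0.0)  # hidden class whose logit is zero
    m = max(z)  # "safe": subtract the max so exp() cannot overflow
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    p = [v / total for v in e]
    if implicit_out:
        del p[implicit_index]  # recoverable as 1 - sum(p), so it is dropped
    return p


p_full = softmax_sketch([1.0, 2.0], implicit=(True, False))  # 3 probabilities
p_short = softmax_sketch([1.0, 2.0], implicit=True)          # 2 probabilities
print(p_full, sum(p_full))
print(p_short, 1.0 - sum(p_short))  # the second value is P(implicit class)
```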
logit
(Multiclass) logit function
Notes
- \(\operatorname{logit}(\mathbf{x})_k = \log(x_k) - \log(x_K)\), where K is an arbitrary channel.
- The `logit` function is the inverse of the `softmax` function:
  `logit(softmax(x, implicit=True), implicit=True) == x`
  `softmax(logit(x, implicit=True), implicit=True) == x`
- Note that when `implicit=False`, `softmax` is not injective (many possible logits map to the same simplex value, since adding the same constant to every logit leaves the output unchanged). We only have:
  `softmax(logit(x, implicit=False), implicit=False) == x`
- `logit(x, implicit=True)` with `x.shape[dim] == 1` is equivalent to the "classical" binary logit function (inverse of the sigmoid).
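The inverse relations above can be checked with a plain-Python sketch on 1-D lists. `logit_sketch` and `softmax_sketch` are hypothetical helpers, not the nitorch functions. They assume that an implicit input probability is recovered as one minus the sum of the others, that the reference channel \(K\) is `implicit_index`, and that a single boolean applies to both input and output.

```python
import math


def _as_pair(implicit):
    # ASSUMPTION: a single bool applies to both the input and the output.
    return (implicit, implicit) if isinstance(implicit, bool) else tuple(implicit)


def softmax_sketch(logits, implicit=False, implicit_index=0):
    implicit_in, implicit_out = _as_pair(implicit)
    z = list(logits)
    if implicit_in:
        z.insert(implicit_index, 0.0)  # hidden class with logit zero
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    p = [v / total for v in e]
    if implicit_out:
        del p[implicit_index]
    return p


def logit_sketch(probs, implicit=False, implicit_index=0):
    implicit_in, implicit_out = _as_pair(implicit)
    p = list(probs)
    if implicit_in:
        # Recover the hidden class's probability from the sum-to-one constraint.
        p.insert(implicit_index, 1.0 - sum(p))
    ref = math.log(p[implicit_index])
    z = [math.log(v) - ref for v in p]  # logit_k = log(x_k) - log(x_K)
    if implicit_out:
        del z[implicit_index]  # always zero, so it can stay implicit
    return z


x = [0.2, 0.5]  # explicit classes; the implicit class holds 0.3
print(softmax_sketch(logit_sketch(x, implicit=True), implicit=True))  # approx. [0.2, 0.5]

# Binary case: a single probability gives the classical logit log(p / (1 - p)).
p = 0.8
print(logit_sketch([p], implicit=True)[0], math.log(p / (1 - p)))
```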
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Tensor of probabilities. | required |
| dim | `int` | Simplex dimension, along which the logit is computed. | -1 |
| implicit | `bool or (bool, bool)` | The first value relates to the input tensor and the second relates to the output tensor. | False |
| implicit_index | `int` | Index of the implicit channel, i.e., the channel whose logit is assumed to be zero. | 0 |
Returns:
| Name | Type | Description |
|---|---|---|
| output | `tensor` | Tensor of logits. |
log_softmax
Log(SoftMax).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Input tensor (logits). | required |
| dim | `int` | Dimension along which to take the softmax (the last one by default). | -1 |
| implicit | `bool or (bool, bool)` | The first value relates to the input tensor and the second relates to the output tensor. | False |
| implicit_index | `int` | Index of the implicit class. | 0 |
|
Returns:
| Name | Type | Description |
|---|---|---|
| output | `tensor` | Log-softmaxed tensor. |
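Log-softmax equals \(z_k - \operatorname{lse}(z)\), which avoids taking the log of a probability that has underflowed to zero. The plain-Python sketch below illustrates this on a 1-D list. `log_softmax_sketch` is a hypothetical helper, not the nitorch function, and it assumes that an implicit input class is re-inserted with logit zero at `implicit_index`, that an implicit output drops that entry, and that a single boolean applies to both.

```python
import math


def log_softmax_sketch(logits, implicit=False, implicit_index=0):
    """log(softmax(z)) computed as z - logsumexp(z) (hypothetical helper)."""
    # ASSUMPTION: a single bool applies to both the input and the output.
    if isinstance(implicit, bool):
        implicit = (implicit, implicit)
    implicit_in, implicit_out = implicit
    z = list(logits)
    if implicit_in:
        z.insert(implicit_index, 0.0)  # hidden class with logit zero
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    out = [v - lse for v in z]
    if implicit_out:
        del out[implicit_index]
    return out


# exp(-1000) underflows to 0.0, so math.log of the softmax output would raise
# ValueError; the log-domain formula stays finite.
print(log_softmax_sketch([0.0, -1000.0]))  # [0.0, -1000.0]
```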
softmax_lse
SoftMax (safe), also returning the log-sum-exp.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Input tensor (logits). | required |
| dim | `int` | Dimension along which to take the softmax (the last one by default). | -1 |
| weights | `tensor`, optional | Observation weights (only used in the log-sum-exp). | None |
| implicit | `bool or (bool, bool)` | The first value relates to the input tensor and the second relates to the output tensor. | False |
Returns:
| Name | Type | Description |
|---|---|---|
| softmax | `tensor` | Softmaxed tensor. |
| lse | `tensor` | Log-sum-exp. |
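Computing both outputs together lets them share one max-subtraction and one set of exponentials, since \(\operatorname{softmax}(z)_k = \exp(z_k - \operatorname{lse}(z))\). The plain-Python sketch below (hypothetical `softmax_lse_sketch` on a 1-D list, not the nitorch function) shows this sharing; it deliberately omits the `weights` and `implicit` arguments.

```python
import math


def softmax_lse_sketch(logits):
    """Return (softmax, logsumexp) of a 1-D list from one pass of exponentials."""
    m = max(logits)  # shared max-subtraction for numerical safety
    e = [math.exp(v - m) for v in logits]
    total = sum(e)
    lse = m + math.log(total)
    softmax = [v / total for v in e]
    return softmax, lse


z = [1.0, 2.0, 3.0]
s, lse = softmax_lse_sketch(z)
print(s, lse)
```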