
nitorch_fastmath.simplex

Overview

This module is concerned with functions that deal with data lying on the simplex, i.e., probabilities. Specifically, we implement softmax, log_softmax, logsumexp and logit.

While most of these functions already exist in PyTorch, we define more generic functions that accept an "implicit" class.

This implicit class exists because discrete probabilities are constrained to sum to one, so their space ("the simplex") has one fewer dimension than the number of classes.

Similarly, we can restrict the logit space (logits are log-probabilities up to an additive constant) to K-1 dimensions, where K is the number of classes, by fixing the logit of one class to an arbitrary value (e.g., zero). This trick makes functions like softmax invertible.
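
To make the invertibility argument concrete, here is a minimal pure-Python sketch. Plain lists of floats stand in for tensors, and `softmax` is a local helper, not this module's function. It shows that adding a constant to every logit leaves the probabilities unchanged, and that pinning one logit to zero picks a single K-1-dimensional representative.

```python
import math

def softmax(logits):
    """Explicit softmax over a plain list of floats (stand-in for a tensor)."""
    m = max(logits)  # shift by the max to avoid overflow in exp
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

x = [1.0, 2.0, 3.0]
shifted = [v + 10.0 for v in x]  # same logits plus a constant

# Adding a constant to every logit leaves the probabilities unchanged,
# so softmax alone cannot be inverted.
print(all(math.isclose(a, b) for a, b in zip(softmax(x), softmax(shifted))))  # True

# Pinning class 0 to a logit of zero selects one representative:
# only the K-1 remaining logits need to be stored.
implicit_logits = [v - x[0] for v in x][1:]  # [1.0, 2.0]
restored = softmax([0.0] + implicit_logits)  # re-insert the hidden zero
print(all(math.isclose(a, b) for a, b in zip(softmax(x), restored)))  # True
```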

Note that in the 2-class case, it is extremely common to work in this implicit setting by using the sigmoid function over a single logit instead of the softmax function over two logits.
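
A quick check of the two-class remark, written as a pure-Python sketch with local `sigmoid` and `softmax` helpers (not this module's API): a softmax over the logits [0, t], where class 0 plays the role of the implicit class, gives the same positive-class probability as sigmoid(t).

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def softmax(logits):
    m = max(logits)  # shift by the max to avoid overflow in exp
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

t = 0.7  # the single logit of the binary model
# Two-class softmax with the implicit class (index 0) pinned to logit zero
p_hidden, p_positive = softmax([0.0, t])
print(math.isclose(p_positive, sigmoid(t)))      # True
print(math.isclose(p_hidden, 1.0 - sigmoid(t)))  # True
```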

All functions below accept an argument `implicit`, which takes either one boolean value or a tuple of two boolean values (logsumexp, which reduces over the class dimension, only takes a single boolean). The first value specifies whether the input tensor has an implicit class, while the second specifies whether the output tensor should have an implicit class.
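
One plausible reading of the `implicit` argument is sketched below in pure Python. `split_implicit` and `output_channels` are hypothetical helpers, not part of this module, and the rule that a single boolean applies to both input and output is an assumption (consistent with the round-trip identities listed under logit).

```python
def split_implicit(implicit):
    """Hypothetical helper: normalize `implicit` to an (input, output) pair.

    Assumption: a single bool applies to both the input and the output.
    """
    if isinstance(implicit, bool):
        return implicit, implicit
    implicit_in, implicit_out = implicit  # must contain exactly two values
    return bool(implicit_in), bool(implicit_out)

def output_channels(n_in, implicit):
    """Hypothetical helper: channel count of a softmax-like output."""
    implicit_in, implicit_out = split_implicit(implicit)
    k = n_in + 1 if implicit_in else n_in  # total number of classes K
    return k - 1 if implicit_out else k

print(split_implicit(True))               # (True, True)
print(output_channels(3, (True, False)))  # 4: hidden class made explicit
print(output_channels(4, (False, True)))  # 3: one class dropped
```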

To minimize the memory footprint and numerical errors, most backward passes are explicitly reimplemented rather than relying on automatic differentiation. These functions involve multiple calls to log and exp, each of which must store its input in order to backpropagate, whereas a single tensor needs to be stored to backpropagate through the entire softmax function.
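
The memory argument rests on the fact that the softmax Jacobian can be written in terms of the output alone: \(\partial s_i / \partial x_j = s_i(\delta_{ij} - s_j)\). Below is a minimal pure-Python sketch of the resulting vector-Jacobian product, checked against central finite differences. `softmax_backward` is a hypothetical name; this is not this module's backward implementation.

```python
import math

def softmax(logits):
    m = max(logits)  # shift by the max to avoid overflow in exp
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_backward(grad_out, out):
    """Vector-Jacobian product of softmax using only the saved output `out`.

    grad_in_j = sum_i g_i * s_i * (delta_ij - s_j) = s_j * (g_j - sum_i g_i * s_i)
    """
    dot = sum(g * s for g, s in zip(grad_out, out))
    return [s * (g - dot) for g, s in zip(grad_out, out)]

x = [0.5, -1.0, 2.0]
g = [1.0, 0.0, -0.5]  # upstream gradient
s = softmax(x)        # the only tensor that needs to be saved
analytic = softmax_backward(g, s)

# Central finite differences of f(x) = sum_i g_i * softmax(x)_i
eps = 1e-6
numeric = []
for j in range(len(x)):
    xp = list(x)
    xm = list(x)
    xp[j] += eps
    xm[j] -= eps
    fp = sum(gi * si for gi, si in zip(g, softmax(xp)))
    fm = sum(gi * si for gi, si in zip(g, softmax(xm)))
    numeric.append((fp - fm) / (2 * eps))

print(all(abs(a - n) < 1e-6 for a, n in zip(analytic, numeric)))  # True
```

The backward pass never touches `x` itself, which is why storing the softmax output is enough.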


logsumexp

logsumexp(input, dim=-1, keepdim=False, implicit=False)

Numerically stabilised log-sum-exp (lse).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Input tensor. | required |
| dim | `int` | The dimension or dimensions to reduce. | -1 |
| keepdim | `bool` | Whether the output tensor retains the reduced dimension. | False |
| implicit | `bool` | Assume that an additional (hidden) channel with value zero exists. | False |

Returns:

| Name | Type | Description |
|---|---|---|
| lse | `tensor` | Output tensor. |
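
A minimal sketch of the stabilisation trick on a plain list of floats, assuming the hidden channel simply contributes a term exp(0) = 1 inside the sum. Shifting by the maximum (including the hidden zero when `implicit=True`) keeps every exponent non-positive, so nothing overflows. This is an illustration, not this module's implementation.

```python
import math

def logsumexp(values, implicit=False):
    """Stable log(sum(exp(values))) over a list of floats.

    With implicit=True, a hidden extra value of zero is included in the sum.
    """
    m = max(values)
    if implicit:
        m = max(m, 0.0)  # the hidden zero may be the largest term
    total = sum(math.exp(v - m) for v in values)
    if implicit:
        total += math.exp(-m)  # exp(0 - m) for the hidden channel
    return m + math.log(total)

print(logsumexp([1000.0, 1000.0]))      # about 1000.6931; math.exp(1000.0) alone overflows
print(logsumexp([0.0], implicit=True))  # log(2), about 0.6931
```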

softmax

softmax(input, dim=-1, implicit=False, implicit_index=0)

SoftMax (safe).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Tensor with values. | required |
| dim | `int` | Dimension along which to take the softmax. Defaults to the last dimension. | -1 |
| implicit | `bool or (bool, bool)` | The first value relates to the input tensor and the second to the output tensor. `implicit[0] == True` assumes that an additional (hidden) channel with value zero exists; `implicit[1] == True` drops the last class from the softmaxed tensor. | False |
| implicit_index | `int` | Index of the implicit class. | 0 |

Returns:

| Name | Type | Description |
|---|---|---|
| output | `tensor` | Softmaxed tensor. |
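
A pure-Python sketch of how the two halves of `implicit` could interact with `implicit_index`. It assumes that a single boolean applies to both input and output, that `implicit_index` is non-negative, and that an implicit output drops the class at `implicit_index` (which matches "the last class" only when the index points at the last position). `softmax` here is a local stand-in, not this module's function.

```python
import math

def softmax(logits, implicit=False, implicit_index=0):
    """Sketch of softmax with an optional implicit class, on a list of floats.

    Assumptions: a single bool applies to input and output; implicit_index
    is non-negative; the dropped output class is the one at implicit_index.
    """
    if isinstance(implicit, bool):
        implicit = (implicit, implicit)
    implicit_in, implicit_out = implicit
    logits = list(logits)
    if implicit_in:
        logits.insert(implicit_index, 0.0)  # hidden class with logit zero
    m = max(logits)  # shift by the max to avoid overflow in exp
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    if implicit_out:
        del probs[implicit_index]  # keep only K-1 probabilities
    return probs

print(len(softmax([1.0, 2.0], implicit=(True, False))))  # 3
print(len(softmax([1.0, 2.0], implicit=True)))           # 2, summing to less than 1
```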

logit

logit(input, dim=-1, implicit=False, implicit_index=0)

(Multiclass) logit function

Notes

  • \(\operatorname{logit}(\mathbf{x})_k = \log(x_k) - \log(x_K)\), where K is an arbitrary channel.
  • The logit function is the inverse of the softmax function:
    • logit(softmax(x, implicit=True), implicit=True) == x
    • softmax(logit(x, implicit=True), implicit=True) == x
  • Note that when implicit=False, softmax is not injective: many possible logits map to the same simplex value, since adding a constant to all logits leaves the softmax unchanged. We only have:
    • softmax(logit(x, implicit=False), implicit=False) == x
  • logit(x, implicit=True) with x.shape[dim] == 1 is equivalent to the "classical" binary logit function (the inverse of the sigmoid).
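
The identities above can be checked with a small pure-Python sketch. Both `softmax` and `logit` are local stand-ins operating on lists, not this module's functions. They assume that a single boolean applies to input and output, and that the implicit class sits at `implicit_index` (non-negative) on both sides.

```python
import math

def _split(implicit):
    # Assumption: a single bool applies to both the input and the output
    return (implicit, implicit) if isinstance(implicit, bool) else tuple(implicit)

def softmax(logits, implicit=False, implicit_index=0):
    implicit_in, implicit_out = _split(implicit)
    logits = list(logits)
    if implicit_in:
        logits.insert(implicit_index, 0.0)  # hidden class with logit zero
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    if implicit_out:
        del probs[implicit_index]
    return probs

def logit(probs, implicit=False, implicit_index=0):
    """Sketch: logit_k = log(p_k) - log(p_ref), p_ref being the implicit class."""
    implicit_in, implicit_out = _split(implicit)
    probs = list(probs)
    if implicit_in:
        # Recover the hidden class so that the probabilities sum to one
        probs.insert(implicit_index, 1.0 - sum(probs))
    ref = math.log(probs[implicit_index])
    out = [math.log(p) - ref for p in probs]
    if implicit_out:
        del out[implicit_index]  # always zero, so it carries no information
    return out

x = [0.3, -1.2]
roundtrip = logit(softmax(x, implicit=True), implicit=True)
print(all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(roundtrip, x)))  # True

# Explicit case: x is only recovered up to a constant shift
print(logit(softmax([1.0, 2.0, 3.0])))  # approximately [0.0, 1.0, 2.0]

# Binary case: matches the classical logit log(p / (1 - p))
print(math.isclose(logit([0.8], implicit=True)[0], math.log(0.8 / 0.2)))  # True
```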

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Tensor of probabilities. | required |
| dim | `int` | Simplex dimension, along which the logit is computed. | -1 |
| implicit | `bool or (bool, bool)` | The first value relates to the input tensor and the second to the output tensor. `implicit[0] == True` assumes that an additional (hidden) channel exists, such that the sum along dim is one; `implicit[1] == True` drops the implicit channel from the logit tensor. | False |
| implicit_index | `int` | Index of the implicit channel, i.e., the channel whose logits are assumed equal to zero. | 0 |

Returns:

| Name | Type | Description |
|---|---|---|
| output | `tensor` | Tensor of logits. |

log_softmax

log_softmax(input, dim=-1, implicit=False, implicit_index=0)

Log(SoftMax).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Tensor with values. | required |
| dim | `int` | Dimension along which to take the log-softmax. Defaults to the last dimension. | -1 |
| implicit | `bool or (bool, bool)` | The first value relates to the input tensor and the second to the output tensor. `implicit[0] == True` assumes that an additional (hidden) channel with value zero exists; `implicit[1] == True` drops the last class from the log-softmaxed tensor. | False |
| implicit_index | `int` | Index of the implicit class. | 0 |

Returns:

| Name | Type | Description |
|---|---|---|
| output | `tensor` | Log-softmaxed tensor. |
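
Computing log-softmax as x - logsumexp(x), rather than taking the log of a softmax, avoids log(0) when a probability underflows. Below is a minimal pure-Python sketch; `log_softmax` is a local stand-in on lists, not this module's function, and it assumes that a single boolean applies to input and output and that the hidden class (logit zero) sits at a non-negative `implicit_index`.

```python
import math

def log_softmax(logits, implicit=False, implicit_index=0):
    """Sketch of log-softmax as x - logsumexp(x) on a list of floats.

    Assumptions: a single bool applies to input and output, implicit_index is
    non-negative, and the hidden class (logit zero) sits at implicit_index.
    """
    if isinstance(implicit, bool):
        implicit = (implicit, implicit)
    implicit_in, implicit_out = implicit
    logits = list(logits)
    if implicit_in:
        logits.insert(implicit_index, 0.0)  # hidden class with logit zero
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    out = [v - lse for v in logits]
    if implicit_out:
        del out[implicit_index]
    return out

print(log_softmax([1000.0, 0.0]))  # [0.0, -1000.0]
```

Taking `math.log` of the softmax instead would fail here, because the second probability underflows to 0.0.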

softmax_lse

softmax_lse(input, dim=-1, weights=None, implicit=False)

SoftMax (safe), also returning the log-sum-exp.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | `tensor` | Tensor with values. | required |
| dim | `int` | Dimension along which to take the softmax. Defaults to the last dimension. | -1 |
| weights | `tensor`, optional | Observation weights (only used in the log-sum-exp). | None |
| implicit | `bool or (bool, bool)` | The first value relates to the input tensor and the second to the output tensor. `implicit[0] == True` assumes that an additional (hidden) channel with value zero exists; `implicit[1] == True` drops the last class from the softmaxed tensor. | False |

Returns:

| Name | Type | Description |
|---|---|---|
| softmax | `tensor` | Softmaxed tensor. |
| lse | `tensor` | Log-sum-exp. |
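
A pure-Python sketch of why returning both outputs together is cheap: the softmax and the log-sum-exp share the same shifted exponentials and their sum. It covers only the explicit case. `weights` and `implicit` are not modeled, and `softmax_lse` here is a local stand-in rather than this module's function.

```python
import math

def softmax_lse(logits):
    """Sketch: softmax and log-sum-exp from one set of shifted exponentials.

    Covers only the explicit case; the weights and implicit arguments of the
    documented function are not modeled here.
    """
    m = max(logits)  # shift by the max to avoid overflow in exp
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    lse = m + math.log(total)  # log(sum(exp(logits))) without overflow
    return probs, lse

x = [1.0, 2.0, 3.0]
probs, lse = softmax_lse(x)
# Each probability equals exp(x_k - lse)
print(all(math.isclose(p, math.exp(v - lse)) for p, v in zip(probs, x)))  # True
```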