A Stable Numerical Computation

First, we implement Log-Sum-Exp (LSE) naively in PyTorch.

>>> import torch
>>> x = torch.randn(3,3)*100
>>> x.exp().sum(-1,keepdim=True).log()
tensor([[    inf],
        [    inf],
        [83.7103]])

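In this example, the first two rows overflow to inf while computing LSE, because exp() of their largest entries exceeds the float32 range. Only the third row, whose entries are small enough, gives a finite result.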
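To get a feel for where the overflow starts, here is a rough sketch using only the standard library. It assumes the tensor uses PyTorch's default float32 dtype; the name `FLOAT32_MAX` is introduced here for illustration and is not part of any API.

```python
import math

# Largest finite float32 value: (2 - 2**-23) * 2**127, about 3.4e38.
FLOAT32_MAX = (2 - 2**-23) * 2.0**127

# exp(v) stays finite in float32 only while v <= log(FLOAT32_MAX).
threshold = math.log(FLOAT32_MAX)
print(f"{threshold:.2f}")  # 88.72
```

Since `x` is drawn from a standard normal scaled by 100, entries above this threshold are common, which is why the naive version overflows so easily.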

Now we apply the trick: subtract each row's maximum before exponentiating, then add it back after taking the log.

>>> xmax = x.max(-1)[0].unsqueeze(-1)
>>> xmax + (x-xmax).exp().sum(-1,keepdim=True).log()
tensor([[167.5672],
        [165.4862],
        [ 83.7103]])
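
The shifted computation gives the same value as the direct one because of the following identity. The notation is chosen here for illustration: x_j are the entries of one row and m is that row's maximum, which plays the role of `xmax` in the code.

```latex
\log \sum_{j} e^{x_j}
  = \log\!\left( e^{m} \sum_{j} e^{x_j - m} \right)
  = m + \log \sum_{j} e^{x_j - m},
\qquad m = \max_{j} x_j .
```

Every exponent x_j - m is at most 0, so each term lies in (0, 1] and the largest term is exactly 1. The sum therefore lies between 1 and the number of entries, and neither exp() nor log() can overflow.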

Finally, we use PyTorch’s implementation of LSE for validation.

>>> x.logsumexp(-1, keepdim=True)
tensor([[167.5672],
        [165.4862],
        [ 83.7103]])

As you can see, the two outputs match.
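
The trick is not specific to PyTorch. As a minimal, dependency-free sketch, the same idea can be written with the standard `math` module. The function names `naive_logsumexp` and `stable_logsumexp` are hypothetical, and the sketch assumes a non-empty list of finite floats. One difference from the tensor example: in pure Python, `math.exp` raises `OverflowError` on overflow instead of returning inf.

```python
import math

def naive_logsumexp(values):
    # Direct translation of log(sum(exp(v))); overflows for large v.
    return math.log(sum(math.exp(v) for v in values))

def stable_logsumexp(values):
    # Shift by the maximum so the largest term is exp(0) = 1.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]

# On small inputs both versions agree.
assert math.isclose(naive_logsumexp(small), stable_logsumexp(small))

# On large inputs the naive version fails, the stable one does not.
try:
    naive_logsumexp(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

print(naive_overflowed)         # True
print(stable_logsumexp(large))  # a finite value slightly above 1002
```

Because `large` is just `small` shifted by 999, the stable result for `large` should equal the result for `small` plus 999, which is a quick sanity check on the implementation.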
