A Stable Numerical Computation
Apr 9, 2021
First, we compute the Log-Sum-Exp (LSE), log(Σᵢ exp(xᵢ)), directly from its definition in PyTorch, reducing over the last dimension.
```python
>>> import torch
>>> x = torch.randn(3, 3) * 100  # entries with standard deviation 100
>>> x.exp().sum(-1, keepdim=True).log()  # naive LSE along the last dimension
tensor([[    inf],
        [    inf],
        [83.7103]])
```
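Two of the three rows overflow. PyTorch's default dtype is float32, whose largest finite value is about 3.4e38 ≈ exp(88.7), so exponentiating any entry above roughly 88.7 yields `inf`, and the log of a sum containing `inf` is `inf` as well. (Because `x` is random, the exact numbers will differ between runs.)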
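The same failure is easy to reproduce without PyTorch. The sketch below is a plain-Python translation of the naive formula; the function name `naive_logsumexp` is ours, chosen for illustration. Python floats are float64, so the threshold moves up to about 709.78, and instead of returning `inf`, `math.exp` raises `OverflowError`.

```python
import math
import sys

def naive_logsumexp(xs):
    # Direct translation of log(sum(exp(x_i))), with no protection against overflow.
    return math.log(sum(math.exp(x) for x in xs))

# The largest x for which exp(x) is still a finite float64.
print(round(math.log(sys.float_info.max), 2))  # 709.78

# Small inputs are fine (about 3.4076).
print(naive_logsumexp([1.0, 2.0, 3.0]))

# Large inputs overflow inside math.exp.
try:
    naive_logsumexp([1000.0, 999.0, 998.0])
except OverflowError as err:
    print("overflow:", err)
```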
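Now we apply the max trick: subtract each row's maximum before exponentiating, then add it back after taking the logarithm.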
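The trick is an exact identity, not an approximation. Writing m for the largest entry and factoring exp(m) out of the sum:

$$
\log \sum_{i} e^{x_i}
  = \log\Big( e^{m} \sum_{i} e^{x_i - m} \Big)
  = m + \log \sum_{i} e^{x_i - m},
  \qquad m = \max_i x_i .
$$

Since every shifted exponent satisfies x_i − m ≤ 0, each term exp(x_i − m) lies in (0, 1], and the term for the maximum is exactly 1. The sum therefore lies between 1 and n, so it can neither overflow nor underflow to zero before the logarithm is taken.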
```python
>>> xmax = x.max(-1, keepdim=True).values  # row-wise maximum, shape (3, 1)
>>> xmax + (x - xmax).exp().sum(-1, keepdim=True).log()  # shifted exponents are <= 0
tensor([[167.5672],
        [165.4862],
        [ 83.7103]])
```
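For comparison, here is a minimal plain-Python sketch of the same max trick on a list of floats, so it runs without PyTorch. The function name `stable_logsumexp` and the guard for infinite maxima are our additions for illustration; the core line mirrors `xmax + (x - xmax).exp().sum(...).log()` from the session.

```python
import math

def stable_logsumexp(xs):
    """Log-sum-exp of a non-empty sequence of floats, using the max trick."""
    m = max(xs)
    # Guard (our addition): if the maximum is infinite, x - m would give
    # inf - inf = nan. The correct result is m itself (+inf, or -inf when
    # every entry is -inf).
    if math.isinf(m):
        return m
    # Every shifted exponent is <= 0, so each term is in (0, 1] and the sum
    # lies in [1, len(xs)]: no overflow, and log() never receives 0.
    return m + math.log(sum(math.exp(x - m) for x in xs))

# Inputs that overflow the naive formula in float64; the result is about 1000.4076.
print(stable_logsumexp([1000.0, 999.0, 998.0]))
```

On inputs small enough for the naive formula, both give the same value up to rounding, which is the same kind of check the session performs with `torch.logsumexp`.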
Finally, we use PyTorch’s implementation of LSE for validation.
```python
>>> x.logsumexp(-1, keepdim=True)
tensor([[167.5672],
        [165.4862],
        [ 83.7103]])
```
The outputs agree to the printed precision on every row, including the two rows that overflowed in the naive computation.