Describe the bug
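`TorchBackend.sqrtm` computes the matrix square root through `torch.linalg.eigh`, whose backward pass is not well defined when the input has repeated eigenvalues. The PyTorch documentation for `torch.linalg.eigh` warns that gradients computed through the eigenvectors are finite only when all eigenvalues are distinct, and are numerically unstable when two eigenvalues are close, because they depend on 1 / min_{i≠j}(λ_i − λ_j). As a result, differentiating `nx.sqrtm` at a matrix with a repeated eigenvalue can produce NaN gradients, as in the identity-matrix example below.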
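The standard first-order perturbation formulas for a symmetric eigendecomposition A = VΛVᵀ (textbook results, not taken from the POT or PyTorch source) show where the singularity comes from. The eigenvector differential divides by eigenvalue gaps, while the differential of the square root itself only divides by sums of square roots, which stay positive for a positive-definite A even when eigenvalues repeat. Here ∘ denotes the elementwise product.

```latex
% Symmetric eigendecomposition A = V \Lambda V^\top, \Lambda = \operatorname{diag}(\lambda_1, \dots, \lambda_n)
d\Lambda = I \circ \left(V^\top \, dA \, V\right), \qquad
dV = V \left(F \circ \left(V^\top \, dA \, V\right)\right), \qquad
F_{ij} = \begin{cases} \dfrac{1}{\lambda_j - \lambda_i}, & i \neq j, \\ 0, & i = j. \end{cases}

% Square root S = V \Lambda^{1/2} V^\top: differentiate S S = A and rotate into the eigenbasis
S \, dS + dS \, S = dA
\;\Longrightarrow\;
\left(V^\top \, dS \, V\right)_{ij} = \frac{\left(V^\top \, dA \, V\right)_{ij}}{\sqrt{\lambda_i} + \sqrt{\lambda_j}}.
```

So the NaN comes from differentiating through V and Λ separately rather than from the square root itself: at A = I every denominator √λ_i + √λ_j equals 2, and the true derivative is simply dS = dA / 2.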
To Reproduce
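```python
import torch
from ot.backend import TorchBackend

torch.set_default_dtype(torch.float64)
torch.autograd.set_detect_anomaly(True)  # raise as soon as a backward pass produces NaN

nx = TorchBackend()
A = torch.eye(3, dtype=torch.float64, requires_grad=True)  # all three eigenvalues equal 1
nx.sqrtm(A)[0, 1].backward()  # fails inside LinalgEighBackward0
print('OK')
```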
Output:
```
RuntimeError: Function 'LinalgEighBackward0' returned nan values in its 0th output.
```
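One possible workaround, sketched here under the assumption that the input is symmetric positive definite, is to keep `torch.linalg.eigh` for the forward pass but supply the backward pass by hand: given G = ∂loss/∂S, solve the Sylvester equation S X + X S = sym(G) in the eigenbasis, where the denominators are √λ_i + √λ_j instead of eigenvalue gaps. `SymSqrtm` is a hypothetical name; it is not part of POT or PyTorch, and this is not how `TorchBackend.sqrtm` is currently implemented.

```python
import torch


class SymSqrtm(torch.autograd.Function):
    """Hypothetical helper: square root of a symmetric positive-definite matrix.

    Forward: A = V diag(l) V^T, then S = V diag(sqrt(l)) V^T.
    Backward: given G = dloss/dS, return V [(V^T sym(G) V) / K] V^T with
    K[i, j] = sqrt(l_i) + sqrt(l_j), which stays positive for repeated eigenvalues.
    """

    @staticmethod
    def forward(ctx, A):
        L, V = torch.linalg.eigh(A)
        sqrtL = L.sqrt()  # assumes every eigenvalue is strictly positive
        ctx.save_for_backward(V, sqrtL)
        # V * sqrtL[..., None, :] scales column k of V by sqrt(l_k), i.e. V @ diag(sqrtL)
        return (V * sqrtL.unsqueeze(-2)) @ V.transpose(-1, -2)

    @staticmethod
    def backward(ctx, G):
        V, sqrtL = ctx.saved_tensors
        Vt = V.transpose(-1, -2)
        # S is symmetric, so only the symmetric part of G affects the loss
        G_sym = 0.5 * (G + G.transpose(-1, -2))
        K = sqrtL.unsqueeze(-1) + sqrtL.unsqueeze(-2)
        # Solve S X + X S = G_sym in the eigenbasis, then rotate back
        return V @ ((Vt @ G_sym @ V) / K) @ Vt


# Same failing case as the reproduction script above
torch.autograd.set_detect_anomaly(True)
A = torch.eye(3, dtype=torch.float64, requires_grad=True)
SymSqrtm.apply(A)[0, 1].backward()
print(A.grad)  # finite: 0.25 at (0, 1) and (1, 0), zeros elsewhere
```

The design choice is to never differentiate through the eigenvectors at all, so repeated eigenvalues cost nothing. The trade-offs are that a zero eigenvalue would make a denominator vanish, and the backward pass is only intended for first-order gradients.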
Environment
- OS: macOS
- Python version: 3.9
- How POT was installed: pip