PyTorch implementation of the implicit reparametrisation trick for mixture distributions, based on Figurnov et al., 2019, "Implicit Reparameterization Gradients", and on the implementation in TensorFlow Probability.
It can be used directly for variational inference with mixture-distribution variational families.
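The implicit reparametrisation trick avoids inverting the mixture CDF. For a sample x with CDF F(x; φ) and density p(x; φ), differentiating F(x; φ) = u at fixed u gives dx/dφ = -(∂F/∂φ) / p. The sketch below illustrates this for univariate mixtures only, assuming `torch.distributions.MixtureSameFamily` with univariate components; `implicit_rsample` is a hypothetical helper, not this repository's API. It draws an ordinary sample and returns a surrogate whose value equals the sample while its gradient follows the implicit formula.

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal


def implicit_rsample(mixture: MixtureSameFamily, sample_shape=torch.Size()):
    """Sample a univariate mixture with implicit reparametrisation gradients.

    Hypothetical helper (illustrative sketch). The returned tensor has the value
    of an ordinary sample x, but its gradient w.r.t. the mixture parameters phi
    is dx/dphi = -(dF(x; phi)/dphi) / p(x; phi).
    """
    # MixtureSameFamily has no rsample, so draw x without gradient tracking.
    with torch.no_grad():
        x = mixture.sample(sample_shape)
    # Mixture CDF F(x; phi) = sum_k pi_k F_k(x; phi). x is a constant here,
    # so autograd only sees the dependence on phi.
    cdf = mixture.cdf(x)
    # Mixture density p(x; phi), treated as a constant in the surrogate.
    density = mixture.log_prob(x).exp().detach()
    # Forward value: x (the bracket is exactly zero). Backward: -grad_phi F / p.
    return x - (cdf - cdf.detach()) / density
```

For a single Normal component the formula reduces to the familiar dx/dμ = 1 and dx/dσ = (x - μ)/σ, which makes a convenient sanity check. Because the returned tensor carries the gradient, it can be used in a Monte Carlo ELBO estimate in the same way as an `rsample` output.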
Remarks:
- For multivariate mixtures, the class currently supports only the case where the mixture component distributions fully factorise.
- Also added a `StableNormal` distribution, which overrides the default `cdf` method with a more stable implementation from pytorch/pytorch#52973 (comment). The implementation also provides a `_log_cdf` method; however, it is not used for the implicit reparametrisation.
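For fully factorising mixtures, one way to apply the trick, following the distributional-transform idea behind the TensorFlow Probability implementation, is to map x to u_d = F(x_d | x_1, ..., x_{d-1}). This conditional CDF is again a mixture of component CDFs, with weights updated by the preceding coordinates: w_k ∝ π_k ∏_{j<d} p_k(x_j). The Jacobian J = ∂u/∂x is lower triangular, so dx/dφ = -J^{-1} ∂u/∂φ needs only a triangular solve. The sketch assumes diagonal Normal components and a single sample; `conditional_cdfs` and `implicit_rsample_factorised` are hypothetical names, and this is not the repository's implementation.

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal


def conditional_cdfs(logits, loc, scale, x):
    """Hypothetical helper: u_d = F(x_d | x_1, ..., x_{d-1}) for a mixture of
    fully factorised Normal components.

    logits: (K,) mixture logits; loc, scale: (K, D) component parameters; x: (D,).
    """
    comp = Normal(loc, scale)
    log_p = comp.log_prob(x)  # (K, D) per-component, per-dimension log densities
    # Exclusive cumulative sum: log prod_{j<d} p_k(x_j), which is 0 for the first dim.
    log_prefix = torch.cumsum(log_p, dim=-1) - log_p
    # Component weights given the preceding coordinates, normalised over k.
    log_w = torch.log_softmax(logits.unsqueeze(-1) + log_prefix, dim=0)  # (K, D)
    return (log_w.exp() * comp.cdf(x)).sum(dim=0)  # (D,)


def implicit_rsample_factorised(logits, loc, scale):
    """Hypothetical helper: one sample x with gradient dx/dphi = -J^{-1} du/dphi."""
    mixture = MixtureSameFamily(
        Categorical(logits=logits), Independent(Normal(loc, scale), 1)
    )
    with torch.no_grad():
        x = mixture.sample()  # (D,)
    # J = du/dx is lower triangular because u_d depends only on x_1, ..., x_d.
    # It is computed with detached parameters, so it acts as a constant.
    jac = torch.autograd.functional.jacobian(
        lambda y: conditional_cdfs(logits.detach(), loc.detach(), scale.detach(), y), x
    )
    u = conditional_cdfs(logits, loc, scale, x)  # x is fixed; only phi is tracked
    # Forward value: x. Backward: -J^{-1} du/dphi via a triangular solve.
    delta = torch.linalg.solve_triangular(
        jac, (u - u.detach()).unsqueeze(-1), upper=False
    )
    return x - delta.squeeze(-1)
```

Computing the full Jacobian with `torch.autograd.functional.jacobian` keeps the sketch short but costs one backward pass per dimension; exploiting the triangular structure more directly would scale better.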
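The code from the linked comment is not reproduced here. As an illustration of the idea only, the sketch below assumes a `Normal` subclass that evaluates the CDF as 0.5 · erfc(-z/√2) with `torch.erfc`, which does not cancel to zero in the lower tail as 0.5 · (1 + erf(z/√2)) does, and computes `_log_cdf` with `torch.special.log_ndtr` (available in recent PyTorch releases). It may differ from the repository's `StableNormal`.

```python
import torch
from torch.distributions import Normal


class StableNormal(Normal):
    """Normal distribution with tail-stable CDF evaluations (illustrative sketch)."""

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        z = (value - self.loc) / self.scale
        # Phi(z) = 0.5 * erfc(-z / sqrt(2)); unlike 0.5 * (1 + erf(z / sqrt(2))),
        # this keeps tiny lower-tail probabilities instead of rounding them to 0.
        return 0.5 * torch.erfc(-z / 2.0 ** 0.5)

    def _log_cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # log Phi(z), computed directly so it stays finite where Phi(z) underflows.
        return torch.special.log_ndtr((value - self.loc) / self.scale)
```

Such a class can serve as the component distribution of a mixture, since `MixtureSameFamily.cdf` evaluates the mixture CDF through the component's `cdf` method.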