Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 19 additions & 50 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.variable import TensorConstant, TensorVariable

from pymc.distributions.custom import CustomDist
from pymc.logprob.abstract import _logprob_helper
from pymc.logprob.basic import TensorLike, icdf
from pymc.pytensorf import normalize_rng_param
Expand Down Expand Up @@ -92,7 +93,7 @@ def polyagamma_cdf(*args, **kwargs):
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
from pymc.distributions.transforms import _default_transform
from pymc.math import invlogit, logdiffexp, logit
from pymc.math import invlogit, logdiffexp

__all__ = [
"AsymmetricLaplace",
Expand Down Expand Up @@ -3603,28 +3604,7 @@ def icdf(value, mu, s):
)


class LogitNormalRV(SymbolicRandomVariable):
name = "logit_normal"
extended_signature = "[rng],[size],(),()->[rng],()"
_print_name = ("LogitNormal", "\\operatorname{LogitNormal}")

@classmethod
def rv_op(cls, mu, sigma, *, size=None, rng=None):
mu = pt.as_tensor(mu)
sigma = pt.as_tensor(sigma)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
draws = pt.expit(normal_draws)

return cls(
inputs=[rng, size, mu, sigma],
outputs=[next_rng, draws],
)(rng, size, mu, sigma)


class LogitNormal(UnitContinuous):
class LogitNormal:
r"""
Logit-Normal distribution.

Expand Down Expand Up @@ -3672,37 +3652,26 @@ class LogitNormal(UnitContinuous):
Defaults to 1.
"""

rv_type = LogitNormalRV
rv_op = LogitNormalRV.rv_op
@staticmethod
def logitnormal_dist(mu, sigma, size):
return invlogit(Normal.dist(mu=mu, sigma=sigma, size=size))

@classmethod
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
def __new__(cls, name, mu=0, sigma=None, tau=None, **kwargs):
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
return super().dist([mu, sigma], **kwargs)

def support_point(rv, size, mu, sigma):
median, _ = pt.broadcast_arrays(invlogit(mu), sigma)
if not rv_size_is_none(size):
median = pt.full(size, median)
return median

def logp(value, mu, sigma):
tau, _ = get_tau_sigma(sigma=sigma)

res = pt.switch(
pt.or_(pt.le(value, 0), pt.ge(value, 1)),
-np.inf,
(
-0.5 * tau * (logit(value) - mu) ** 2
+ 0.5 * pt.log(tau / (2.0 * np.pi))
- pt.log(value * (1 - value))
),
return CustomDist(
name,
mu,
sigma,
dist=cls.logitnormal_dist,
class_name="LogitNormal",
**kwargs,
)

return check_parameters(
res,
tau > 0,
msg="tau > 0",
@classmethod
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
return CustomDist.dist(
mu, sigma, dist=cls.logitnormal_dist, class_name="LogitNormal", **kwargs
)


Expand Down
6 changes: 0 additions & 6 deletions pymc/distributions/moments/means.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
HalfFlatRV,
HalfStudentTRV,
KumaraswamyRV,
LogitNormalRV,
MoyalRV,
PolyaGammaRV,
RiceRV,
Expand Down Expand Up @@ -290,11 +289,6 @@ def logistic_mean(op, rv, rng, size, mu, s):
return maybe_resize(pt.broadcast_arrays(mu, s)[0], size)


@_mean.register(LogitNormalRV)
def logitnormal_mean(op, rv, rng, size, mu, sigma):
raise UndefinedMomentException("The mean of the LogitNormal distribution is undefined")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small note: It isn't really undefined, just doesn't have a closed form solution



@_mean.register(LogNormalRV)
def lognormal_mean(op, rv, rng, size, mu, sigma):
return maybe_resize(pt.exp(mu + 0.5 * sigma**2), size)
Expand Down
16 changes: 14 additions & 2 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class MeasurableTransform(MeasurableElemwise):
Erf,
Erfc,
Erfcx,
Sigmoid,
)

# Cannot use `transform` as name because it would clash with the property added by
Expand Down Expand Up @@ -227,7 +228,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)


MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf)
MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf, Sigmoid)
MONOTONICALLY_DECREASING_OPS = (Erfc, Erfcx)


Expand Down Expand Up @@ -300,7 +301,18 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
value = pt.switch(pt.lt(scale, 0), 1 - value, value)
elif isinstance(op.scalar_op, Pow):
if op.transform_elemwise.power < 0:
raise NotImplementedError
# Note: Negative even powers will be rejected below when inverting the transform
# For the remaining negative powers the function is decreasing with a jump around 0
# We adjust the value with the mass below zero.
# For non-negative RVs with cdf(0)=0, it simplifies to 1 - value
cdf_zero = pt.exp(_logcdf_helper(measurable_input, 0))
# Use nan to not mask invalid values accidentally
value = pt.switch((value >= 0) & (value <= 1), value, np.nan)
value = pt.switch(
(cdf_zero > 0) & (value < cdf_zero),
cdf_zero - value,
1 + cdf_zero - value,
)
else:
raise NotImplementedError

Expand Down
6 changes: 6 additions & 0 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,12 @@ def test_logitnormal(self):
),
decimal=select_by_precision(float64=6, float32=1),
)
check_icdf(
pm.LogitNormal,
{"mu": R, "sigma": Rplus},
lambda q, mu, sigma: sp.expit(mu + sigma * st.norm.ppf(q)),
decimal=select_by_precision(float64=12, float32=5),
)

@pytest.mark.skipif(
condition=(pytensor.config.floatX == "float32"),
Expand Down
14 changes: 9 additions & 5 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,7 @@ def test_reciprocal_rv_transform(self, numerator):
x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))

with pytest.raises(NotImplementedError):
icdf(x_rv, x_vv)
x_icdf_fn = pytensor.function([x_vv], icdf(x_rv, x_vv))

x_test_val = np.r_[-0.5, 1.5]
np.testing.assert_allclose(
Expand All @@ -392,6 +390,10 @@ def test_reciprocal_rv_transform(self, numerator):
x_logcdf_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).logcdf(x_test_val),
)
np.testing.assert_allclose(
x_icdf_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).ppf(x_test_val),
)

def test_reciprocal_real_rv_transform(self):
# 1 / Cauchy(mu, sigma) = Cauchy(mu / (mu^2 + sigma ^2), sigma / (mu ^ 2, sigma ^ 2))
Expand All @@ -406,8 +408,10 @@ def test_reciprocal_real_rv_transform(self):
logcdf(test_rv, test_value).eval(),
sp.stats.cauchy(1 / 5, 2 / 5).logcdf(test_value),
)
with pytest.raises(NotImplementedError):
icdf(test_rv, test_value)
np.testing.assert_allclose(
icdf(test_rv, test_value).eval(),
sp.stats.cauchy(1 / 5, 2 / 5).ppf(test_value),
)

def test_sqr_transform(self):
# The square of a normal with unit variance is a noncentral chi-square with 1 df and nc = mean ** 2
Expand Down
Loading