def lognormal_pdf(parameters, x: float) -> float:
    """
    Evaluate the log-normal probability density at ``x``.

    Parameters
    ----------
    parameters
        Object exposing ``mu`` and ``sigma`` attributes (location and scale
        of the logarithm of the variable).
    x : float
        Point at which the density is evaluated.

    Returns
    -------
    float
        Density value; ``0.0`` for every non-positive ``x``.
    """
    if x > 0:
        sigma = parameters.sigma
        normalizer = x * sigma * math.sqrt(2.0 * math.pi)
        exponent = -((math.log(x) - parameters.mu) ** 2) / (2 * sigma**2)
        return 1.0 / normalizer * math.exp(exponent)
    # The density has no mass on the non-positive half-line.
    return 0.0
= ParametricFamily(\n", + " name=\"Lognormal Family\",\n", + " distr_type=UnivariateContinuous,\n", + " distr_parametrizations=[\"canonical\", \"meanvar\"],\n", + " distr_characteristics={PDF: lognormal_pdf},\n", + " sampling_strategy=DefaultSamplingUnivariateStrategy(),\n", + " computation_strategy=DefaultComputationStrategy(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6709fd2b-45c6-4f66-8b0f-c9e5cd31bfad", + "metadata": {}, + "source": [ + "We specified that there will be two parametrizations: canonical (which will be treat as base) and mean-var. Let's introduce them" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5a7da55b-4396-4cf8-bc87-328e7590e0d7", + "metadata": {}, + "outputs": [], + "source": [ + "@parametrization(family=Lognormal, name=\"canonical\")\n", + "class NormalParametrization:\n", + " mu: float\n", + " sigma: float\n", + "\n", + " @constraint(description=\"sigma > 0\")\n", + " def check_sigma_positive(self) -> bool:\n", + " return self.sigma > 0\n", + "\n", + "\n", + "@parametrization(family=Lognormal, name=\"meanvar\")\n", + "class MeanVarParametrization:\n", + " mean: float\n", + " var: float\n", + "\n", + " @constraint(description=\"mean > 0\")\n", + " def check_mean_positive(self) -> bool:\n", + " return self.mean > 0\n", + "\n", + " @constraint(description=\"var > 0\")\n", + " def check_var_positive(self) -> bool:\n", + " return self.var > 0\n", + "\n", + " def transform_to_base_parametrization(self) -> NormalParametrization:\n", + " mu = math.log(self.mean**2 / math.sqrt(self.mean**2 + self.var))\n", + " sigma = math.sqrt(math.log(1 + self.var / self.mean**2))\n", + " return NormalParametrization(mu=mu, sigma=sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "0f11c1b8-4601-4c1c-a087-50eeb3134a89", + "metadata": {}, + "source": [ + "Note that for second parametrization we provided way convert it to base one. 
Now, let's register our family and do some things:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "60b6b5f2-0f6d-48f9-a45c-d927a7cc289e", + "metadata": {}, + "outputs": [], + "source": [ + "ParametricFamilyRegister.register(Lognormal)\n", + "dist = Lognormal(mean=1.0, var=1.0, parametrization_name=\"meanvar\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2b631a99-42c3-4708-ad79-663786283814", + "metadata": {}, + "outputs": [], + "source": [ + "cdf = dist.computation_strategy.query_method(CDF, dist)\n", + "pdf = dist.computation_strategy.query_method(PDF, dist)\n", + "ppf = dist.computation_strategy.query_method(PPF, dist)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "cc1702cf-fda3-4897-b918-21515aff2a7b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "delta_t = 0.01\n", + "x_max = 2\n", + "x = [i * delta_t for i in range(int(x_max / delta_t))]\n", + "y_pdf = [pdf(xx) for xx in x]\n", + "y_cdf = [cdf(xx) for xx in x]\n", + "fig, ax = plt.subplots(2, 1)\n", + "ax[0].plot(x, y_pdf)\n", + "ax[1].plot(x, y_cdf)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "adeb977a-4f70-4773-a98a-02865b437d67", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0\n" + ] + } + ], + "source": [ + "base_parameters = Lognormal.parametrizations.get_base_parameters(dist.parameters)\n", + "y_true_pdf = [lognormal_pdf(base_parameters, xx) for xx in x]\n", + "print(max([y_true_pdf[i] - y_pdf[i] for i in range(len(x))]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index c19737f..63c7ffd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,10 @@ coverage = { version = ">=7.5", extras = [ "toml" ] } pre-commit = ">=3.6" types-setuptools = "*" +[tool.poetry.group.docs.dependencies] +jupyter = "^1.1.1" +matplotlib = "^3.10.6" + [tool.ruff] target-version = "py312" line-length = 100 diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py new file mode 100644 index 0000000..b22cbf0 --- /dev/null +++ b/src/pysatl_core/families/__init__.py @@ -0,0 +1,35 @@ +""" +Parametric Families module for working with statistical distribution families. 
+ +This package provides a comprehensive framework for defining, managing, and +working with parametric families of statistical distributions. It supports +multiple parameterizations, constraint validation, and automatic conversion +between different parameter formats. +""" + +__author__ = "Leonid Elkin, Mikhail, Mikhailov" +__copyright__ = "Copyright (c) 2025 PySATL project" +__license__ = "SPDX-License-Identifier: MIT" + + +from pysatl_core.families.distribution import ParametricFamilyDistribution +from pysatl_core.families.parametric_family import ParametricFamily +from pysatl_core.families.parametrizations import ( + Parametrization, + ParametrizationConstraint, + ParametrizationSpec, + constraint, + parametrization, +) +from pysatl_core.families.registry import ParametricFamilyRegister + +__all__ = [ + "ParametricFamilyRegister", + "ParametrizationConstraint", + "Parametrization", + "ParametrizationSpec", + "ParametricFamily", + "ParametricFamilyDistribution", + "constraint", + "parametrization", +] diff --git a/src/pysatl_core/families/distribution.py b/src/pysatl_core/families/distribution.py new file mode 100644 index 0000000..750c33e --- /dev/null +++ b/src/pysatl_core/families/distribution.py @@ -0,0 +1,151 @@ +""" +Concrete distribution instances with specific parameter values. + +This module provides the implementation for individual distribution instances +created from parametric families. It handles distribution characteristics +computation, sampling, and provides access to analytical methods for +specific parameter sets. 
@property
def analytical_computations(
    self,
) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]:
    """
    Get analytical computation functions for this distribution.

    A characteristic defined for the current parametrization is bound to the
    current parameters. Every remaining characteristic that is defined for the
    base parametrization is bound to the parameters converted to that base
    parametrization. Characteristics available in neither form are omitted.

    Returns
    -------
    Mapping[GenericCharacteristicName, AnalyticalComputation]
        Mapping from characteristic names to computation functions.
    """
    # ``self.family`` is a property that queries the register on every access,
    # so resolve it (and the other loop invariants) once.
    family = self.family
    characteristics = family.distr_characteristics
    current_name = self.parameters.name

    analytical_computations = {}

    # First form list of all characteristics, available from current parametrization
    for characteristic, forms in characteristics.items():
        if current_name in forms:
            analytical_computations[characteristic] = AnalyticalComputation(
                target=characteristic,
                func=partial(forms[current_name], self.parameters),
            )
    # TODO: Second, apply rule set, for, e.g. approximations

    # Finally, fill other characteristics through the base parametrization
    base_name = family.parametrizations.base_parametrization_name
    base_parameters = None
    for characteristic, forms in characteristics.items():
        if characteristic in analytical_computations or base_name not in forms:
            continue
        if base_parameters is None:
            # Convert lazily: the conversion is only needed when at least one
            # characteristic has to fall back to the base parametrization.
            base_parameters = family.parametrizations.get_base_parameters(self.parameters)
        analytical_computations[characteristic] = AnalyticalComputation(
            target=characteristic, func=partial(forms[base_name], base_parameters)
        )

    return analytical_computations
class ParametricFamily:
    """
    A family of distributions with multiple parametrizations.

    This class represents a parametric family of distributions, such as
    the normal or lognormal family, which can be parameterized in different
    ways (e.g., mean-variance or canonical parametrization).

    Attributes
    ----------
    name : str
        Name of the distribution family (read-only property).
    parametrization_names : list[ParametrizationName]
        Names of the parametrizations declared for this family. The first
        entry is the base parametrization.
    parametrizations : ParametrizationSpec
        Container of parametrization classes. It is created empty here and
        must be filled by the user.
    distr_characteristics : dict[GenericCharacteristicName, dict[ParametrizationName, Callable]]
        For every characteristic, the functions computing it keyed by the
        parametrization they are written for.
    sampling_strategy : SamplingStrategy
        Strategy for sampling from distributions in this family.
    computation_strategy : ComputationStrategy
        Strategy for computing distribution characteristics.
    """

    def __init__(
        self,
        name: str,
        distr_type: DistributionType | Callable[[Parametrization], DistributionType],
        distr_parametrizations: list[ParametrizationName],
        distr_characteristics: dict[
            GenericCharacteristicName,
            dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction,
        ],
        sampling_strategy: SamplingStrategy,
        computation_strategy: ComputationStrategy[Any, Any],
    ):
        """
        Initialize a new parametric family.

        Parameters
        ----------
        name : str
            Name of the distribution family.

        distr_type : DistributionType | Callable[[Parametrization], DistributionType]
            Type of distributions in this family or, if the type is parameter-dependent,
            a function that takes the *base* parametrization and infers the type from it.

        distr_parametrizations : list[ParametrizationName]
            List of parametrizations for this distribution. *First parametrization is always
            base parametrization*.

        distr_characteristics :
            Mapping from characteristic names to computation functions. A value is either
            a single function (taken to be written for the base parametrization) or a
            dictionary of functions keyed by parametrization name.

        sampling_strategy : SamplingStrategy
            Strategy for sampling from distributions in this family.

        computation_strategy : ComputationStrategy
            Strategy for computing distribution characteristics.

        Raises
        ------
        ValueError
            If ``distr_parametrizations`` is empty, so no base parametrization exists.
        """
        if not distr_parametrizations:
            raise ValueError(
                f"Family {name!r} requires at least one parametrization "
                "(the first one is the base parametrization)"
            )

        self._name = name
        # A constant type is wrapped so that the type is always obtained by a call
        self._distr_type: Callable[[Parametrization], DistributionType] = (
            (lambda params: distr_type) if isinstance(distr_type, DistributionType) else distr_type
        )

        # Parametrizations must be built by user
        base_name = distr_parametrizations[0]
        self.parametrization_names = distr_parametrizations
        self.parametrizations = ParametrizationSpec(base_name)

        self.sampling_strategy = sampling_strategy
        self.computation_strategy = computation_strategy

        # A bare function is shorthand for "defined for the base parametrization"
        self.distr_characteristics: dict[
            GenericCharacteristicName, dict[ParametrizationName, ParametrizedFunction]
        ] = {
            characteristic: forms if isinstance(forms, dict) else {base_name: forms}
            for characteristic, forms in distr_characteristics.items()
        }

    @property
    def name(self) -> str:
        """Name of the distribution family."""
        return self._name

    def distribution(
        self, parametrization_name: str | None = None, **parameters_values: Any
    ) -> ParametricFamilyDistribution:
        """
        Create a distribution instance with the given parameters.

        Parameters
        ----------
        parametrization_name : str | None, optional
            Name of the parametrization to use, or None for base parametrization.
        **parameters_values
            Parameter values for the distribution.

        Returns
        -------
        ParametricFamilyDistribution
            A distribution instance with the specified parameters.

        Raises
        ------
        ValueError
            If the parameters don't satisfy the parametrization constraints.
        KeyError
            If the requested parametrization has not been registered for this family.
        """
        if parametrization_name is None:
            parametrization_class = self.parametrizations.base
        else:
            parametrization_class = self.parametrizations.parametrizations[parametrization_name]

        parameters = parametrization_class(**parameters_values)
        # Validate BEFORE converting: the conversion formulas may be undefined for
        # invalid values (division by zero, log of a non-positive number), which
        # would hide the constraint violation behind an unrelated error.
        parameters.validate()
        base_parameters = self.parametrizations.get_base_parameters(parameters)
        distribution_type = self._distr_type(base_parameters)
        return ParametricFamilyDistribution(self.name, distribution_type, parameters)

    __call__ = distribution
class Parametrization(ABC):
    """
    Abstract base class for distribution parametrizations.

    Concrete parametrizations provide ``name`` and ``parameters``; constraint
    checking and conversion to the base parametrization are implemented here.

    Attributes
    ----------
    _constraints : ClassVar[list[ParametrizationConstraint]]
        Class-level list of constraints checked by :meth:`validate`.
    """

    _constraints: ClassVar[list[ParametrizationConstraint]] = []

    @property
    @abstractmethod
    def name(self) -> str:
        """
        Name of this parametrization.

        Returns
        -------
        str
            The name of the parametrization.
        """

    @property
    @abstractmethod
    def parameters(self) -> dict[str, Any]:
        """
        Parameters of this parametrization as a dictionary.

        Returns
        -------
        dict[str, Any]
            Dictionary mapping parameter names to values.
        """

    @property
    def constraints(self) -> list[ParametrizationConstraint]:
        """
        Constraints attached to this parametrization.

        Returns
        -------
        list[ParametrizationConstraint]
            The class-level constraint list.
        """
        return self._constraints

    def validate(self) -> None:
        """
        Check every constraint against this instance.

        Raises
        ------
        ValueError
            For the first constraint (in list order) whose check fails.
        """
        violated = next(
            (candidate for candidate in self._constraints if not candidate.check(self)),
            None,
        )
        if violated is not None:
            raise ValueError(f'Constraint "{violated.description}" does not hold')

    def transform_to_base_parametrization(self) -> Parametrization:
        """
        Convert this parametrization to the base parametrization.

        Returns
        -------
        Parametrization
            This very object: the default implementation treats the instance
            as already being in base form. Non-base parametrizations override it.
        """
        return self
def parametrization(
    family: ParametricFamily, name: str
) -> Callable[[type[Parametrization]], type[Parametrization]]:
    """
    Decorator to register a class as a parametrization for a family.

    The decorated class is converted to a dataclass (unless it already is one),
    receives ``name`` and ``parameters`` properties and a ``validate`` method,
    gets its ``@constraint``-marked methods collected into ``_constraints`` and
    is added to ``family.parametrizations`` under ``name``.

    Parameters
    ----------
    family : ParametricFamily
        The family to register the parametrization with.
    name : str
        Name of the parametrization.

    Returns
    -------
    Callable
        Decorator function that registers the class as a parametrization.

    Examples
    --------
    >>> @parametrization(family=NormalFamily, name='meanvar')
    ... class MeanVarParametrization:
    ...     mean: float
    ...     var: float
    """
    # Local imports keep this block self-contained; both names are stdlib.
    from abc import update_abstractmethods
    from dataclasses import fields

    def decorator(cls: type[Parametrization]) -> type[Parametrization]:
        # Convert to dataclass if not already
        if not hasattr(cls, "__dataclass_fields__"):
            cls = dataclass(cls)

        # Add name property
        def name_property(self):  # type: ignore
            return name

        cls.name = property(name_property)  # type: ignore

        # Add parameters property. ``fields()`` lists real dataclass fields only,
        # whereas ``__dataclass_fields__`` also holds ClassVar/InitVar pseudo-fields.
        def parameters_property(self):  # type: ignore
            return {field.name: getattr(self, field.name) for field in fields(self)}

        cls.parameters = property(parameters_property)  # type: ignore

        # Collect constraints: every attribute marked by the ``constraint`` decorator
        constraints = []
        for attr_name in dir(cls):
            attr = getattr(cls, attr_name)
            if getattr(attr, "_is_constraint", False):
                constraints.append(
                    ParametrizationConstraint(description=attr._constraint_description, check=attr)
                )
        cls._constraints = constraints

        # Add validate method
        cls.validate = Parametrization.validate  # type: ignore

        # ``name``/``parameters`` were attached after class creation. For classes
        # deriving from the ``Parametrization`` ABC the set of abstract methods has
        # to be recomputed, otherwise they stay uninstantiable. No-op for non-ABCs.
        update_abstractmethods(cls)

        # Register with family
        family.parametrizations.add_parametrization(name, cls)

        return cls

    return decorator
class ParametricFamilyRegister:
    """
    Singleton registry for parametric distribution families.

    One shared instance stores families keyed by their name, so a family can
    be looked up by name from anywhere in the codebase.

    Examples
    --------
    >>> registry = ParametricFamilyRegister()
    >>> family = registry.get('Lognormal Family')
    """

    _instance: ClassVar[ParametricFamilyRegister | None] = None
    _registered_families: dict[str, ParametricFamily]

    def __new__(cls) -> ParametricFamilyRegister:
        """
        Create the singleton on first use, return it afterwards.

        Returns
        -------
        ParametricFamilyRegister
            The singleton registry instance.
        """
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            instance._registered_families = {}
            cls._instance = instance
        return instance

    @classmethod
    def get(cls, name: str) -> ParametricFamily:
        """
        Retrieve a parametric family by name.

        Parameters
        ----------
        name : str
            The name of the family to retrieve.

        Returns
        -------
        ParametricFamily
            The requested parametric family.

        Raises
        ------
        ValueError
            If no family with the given name exists in the registry.
        """
        families = cls()._registered_families
        if name in families:
            return families[name]
        raise ValueError(f"No family {name} found in register")

    @classmethod
    def register(cls, family: ParametricFamily) -> None:
        """
        Register a new parametric family.

        Parameters
        ----------
        family : ParametricFamily
            The family to register; it is stored under ``family.name``.

        Raises
        ------
        ValueError
            If a family with the same name has already been registered.
        """
        families = cls()._registered_families
        if family.name in families:
            raise ValueError(f"Family {family.name} already found in register")
        families[family.name] = family
class TestParametricFamily:
    """Checks for constructing a ParametricFamily."""

    def test_family_creation(self):
        """A freshly built family exposes the values it was constructed with."""

        # Stand-ins for the strategies; only their identity is checked below
        class StubSampling:
            pass

        class StubComputation:
            pass

        # Stand-in characteristic function
        def doubled(parameters, x):
            return x * 2

        sampling = StubSampling()
        computation = StubComputation()

        family = ParametricFamily(
            name="TestFamily",
            distr_type="Continuous",
            distr_parametrizations=["mock"],
            distr_characteristics={PDF: doubled},
            sampling_strategy=sampling,
            computation_strategy=computation,
        )
        ParametricFamilyRegister().register(family)

        # Constructor arguments must be reflected on the family
        assert family.name == "TestFamily"
        assert family._distr_type == "Continuous"
        assert PDF in family.distr_characteristics
        assert isinstance(family.sampling_strategy, StubSampling)
        assert isinstance(family.computation_strategy, StubComputation)
parameters(self): + return {"value": self.value} + + # Add to family + family.parametrizations.add_parametrization("mock", MockParametrization) + + # Register family + register = ParametricFamilyRegister() + register.register(family) + + # Create distribution + params = MockParametrization(2.0) + + dist = ParametricFamilyDistribution("MockFamily", "Continuous", params) + + # Check properties + assert dist.distr_name == "MockFamily" + assert dist.distribution_type == "Continuous" + assert dist.parameters is params + assert dist.family is family + + # Test sampling + samples = dist.sample(3) + assert samples == [1, 2, 3] + + # Test analytical computations + computations = dist.analytical_computations + assert PDF in computations + assert computations[PDF].func(5.0) == 10.0 # 5.0 * 2.0 diff --git a/tests/test_parameters.py b/tests/test_parameters.py new file mode 100644 index 0000000..8a79b69 --- /dev/null +++ b/tests/test_parameters.py @@ -0,0 +1,187 @@ +import pytest + +from pysatl_core.families import ( + ParametricFamily, + Parametrization, + ParametrizationConstraint, + ParametrizationSpec, + constraint, + parametrization, +) + + +class TestParametrizationConstraint: + """Test the ParametrizationConstraint class.""" + + def test_constraint_creation(self): + """Test creating a constraint with description and check function.""" + + def check_func(obj): + return obj.value > 0 + + constraint = ParametrizationConstraint( + description="Value must be positive", check=check_func + ) + + assert constraint.description == "Value must be positive" + assert constraint.check is check_func + + +class TestParametrization: + """Test the base Parametrization class.""" + + def test_abstract_methods(self): + """Test that Parametrization is abstract and requires name and parameters properties.""" + with pytest.raises(TypeError): + Parametrization() # Can't instantiate abstract class + + # Create a concrete implementation + class ConcreteParametrization(Parametrization): + @property 
+ def name(self): + return "concrete" + + @property + def parameters(self): + return {"param": 1.0} + + # Should be able to instantiate + param = ConcreteParametrization() + assert param.name == "concrete" + assert param.parameters == {"param": 1.0} + + +class TestParametrizationSpec: + """Test the ParametrizationSpec class.""" + + def test_add_and_get_parametrization(self): + """Test adding and retrieving parametrizations.""" + spec = ParametrizationSpec(base_name="mock") + + # Create a mock parametrization class + class MockParametrization(Parametrization): + @property + def name(self): + return "mock" + + @property + def parameters(self): + return {} + + # Add parametrization + spec.add_parametrization("mock", MockParametrization) + + # Check it was added + assert "mock" in spec.parametrizations + assert spec.parametrizations["mock"] is MockParametrization + assert spec.base_parametrization_name == "mock" + assert spec.base is MockParametrization + + def test_get_base_parameters(self): + """Test converting parameters to base parametrization.""" + spec = ParametrizationSpec(base_name="base") + + # Create mock parametrizations + class BaseParametrization(Parametrization): + def __init__(self, value): + self.value = value + + @property + def name(self): + return "base" + + @property + def parameters(self): + return {"value": self.value} + + class OtherParametrization(Parametrization): + def __init__(self, other_value): + self.other_value = other_value + + @property + def name(self): + return "other" + + @property + def parameters(self): + return {"other_value": self.other_value} + + def transform_to_base_parametrization(self): + return BaseParametrization(self.other_value * 2) + + # Add parametrizations + spec.add_parametrization("base", BaseParametrization) + spec.add_parametrization("other", OtherParametrization) + + # Test with base parametrization + base_params = BaseParametrization(5.0) + result = spec.get_base_parameters(base_params) + assert result is 
base_params + + # Test with other parametrization + other_params = OtherParametrization(3.0) + result = spec.get_base_parameters(other_params) + assert isinstance(result, BaseParametrization) + assert result.value == 6.0 # 3.0 * 2 + + +class TestDecorators: + """Test the constraint and parametrization decorators.""" + + def test_constraint_decorator(self): + """Test the constraint decorator.""" + + @constraint("Value must be positive") + def check_positive(self): + return self.value > 0 + + # Check that the decorator added the required attributes + assert hasattr(check_positive, "_is_constraint") + assert hasattr(check_positive, "_constraint_description") + assert check_positive._is_constraint is True + assert check_positive._constraint_description == "Value must be positive" + + def test_parametrization_decorator(self): + """Test the parametrization decorator.""" + # Create a mock family + family = ParametricFamily( + name="TestFamily", + distr_type="Continuous", + distr_parametrizations=["test"], + distr_characteristics={}, + sampling_strategy=None, + computation_strategy=None, + ) + + # Apply the decorator + @parametrization(family=family, name="test") + class TestParametrization: + value: float + + @constraint("Value must be positive") + def check_positive(self): + return self.value > 0 + + # Check that the class was modified + assert hasattr(TestParametrization, "name") + assert hasattr(TestParametrization, "parameters") + assert hasattr(TestParametrization, "_constraints") + assert hasattr(TestParametrization, "validate") + + # Check that it was registered with the family + assert "test" in family.parametrizations.parametrizations + assert family.parametrizations.parametrizations["test"] is TestParametrization + assert family.parametrizations.base_parametrization_name == "test" + + # Test instantiation and validation + instance = TestParametrization(value=5.0) + assert instance.name == "test" + assert instance.parameters == {"value": 5.0} + + # This should 
not raise an exception + instance.validate() + + # Test with invalid parameters + invalid_instance = TestParametrization(value=-1.0) + with pytest.raises(ValueError, match="Constraint.*does not hold"): + invalid_instance.validate() diff --git a/tests/test_parametric_families_registry.py b/tests/test_parametric_families_registry.py new file mode 100644 index 0000000..3c931e5 --- /dev/null +++ b/tests/test_parametric_families_registry.py @@ -0,0 +1,33 @@ +import pytest + +from pysatl_core.families import ParametricFamilyRegister + + +class TestParametricFamiliesRegister: + """Test the ParametricFamiliesRegister singleton.""" + + def test_singleton_pattern(self): + """Test that only one instance exists.""" + register1 = ParametricFamilyRegister() + register2 = ParametricFamilyRegister() + assert register1 is register2 + + def test_register_and_get_family(self): + """Test registering and retrieving a family.""" + register = ParametricFamilyRegister() + + # Create a mock family + mock_family = type("MockFamily", (), {"name": "TestFamily"})() + + # Register and retrieve + register.register(mock_family) + retrieved = register.get("TestFamily") + + assert retrieved is mock_family + + def test_get_nonexistent_family(self): + """Test error when getting a non-existent family.""" + register = ParametricFamilyRegister() + + with pytest.raises(ValueError, match="No family Nonexistent found in register"): + register.get("Nonexistent")