From 347945ec71b7fc7c14274ea9b8bce6a994e7b30c Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 10 Jun 2026 18:09:11 -0700 Subject: [PATCH] support --- distreqx/distributions/_bernoulli.py | 9 +++++++++ distreqx/distributions/_beta.py | 6 ++++++ distreqx/distributions/_categorical.py | 12 ++++++++++++ distreqx/distributions/_distribution.py | 6 ++++++ distreqx/distributions/_gamma.py | 6 ++++++ distreqx/distributions/_independent.py | 5 +++++ distreqx/distributions/_logistic.py | 6 ++++++ distreqx/distributions/_mixture_same_family.py | 13 +++++++++++++ distreqx/distributions/_mvn_from_bijector.py | 6 ++++++ distreqx/distributions/_normal.py | 6 ++++++ distreqx/distributions/_one_hot_categorical.py | 6 ++++++ distreqx/distributions/_transformed.py | 4 ++++ distreqx/distributions/_uniform.py | 9 +++++++++ pyproject.toml | 2 +- tests/abstractdistribution_test.py | 8 ++++++++ 15 files changed, 103 insertions(+), 1 deletion(-) diff --git a/distreqx/distributions/_bernoulli.py b/distreqx/distributions/_bernoulli.py index 12e7ef7..01be5a8 100644 --- a/distreqx/distributions/_bernoulli.py +++ b/distreqx/distributions/_bernoulli.py @@ -76,6 +76,15 @@ def probs(self) -> Array: def event_shape(self) -> tuple[int, ...]: return self.probs.shape + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`. + + The Bernoulli is discrete on `{0, 1}`. + """ + dtype = self.probs.dtype + return (jnp.array(0.0, dtype=dtype), jnp.array(1.0, dtype=dtype)) + def _log_probs_parameter(self) -> tuple[Array, Array]: if self._logits is None: if self._probs is None: diff --git a/distreqx/distributions/_beta.py b/distreqx/distributions/_beta.py index 741a0cd..ef61d27 100644 --- a/distreqx/distributions/_beta.py +++ b/distreqx/distributions/_beta.py @@ -61,6 +61,12 @@ def __init__( def event_shape(self) -> tuple: return () + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`.""" + dtype = jnp.result_type(self.alpha, self.beta) + return (jnp.array(0.0, dtype=dtype), jnp.array(1.0, dtype=dtype)) + def sample(self, key: Key[Array, ""]) -> Array: """See `Distribution.sample`.""" dtype = jnp.result_type(self.alpha, self.beta) diff --git a/distreqx/distributions/_categorical.py b/distreqx/distributions/_categorical.py index cc33a6f..54d284d 100644 --- a/distreqx/distributions/_categorical.py +++ b/distreqx/distributions/_categorical.py @@ -55,6 +55,18 @@ def event_shape(self) -> tuple: """Shape of event of distribution samples.""" return () + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`. + + The Categorical is discrete on `{0, ..., K-1}`. + """ + dtype = self.probs.dtype + return ( + jnp.array(0.0, dtype=dtype), + jnp.array(self.num_categories - 1, dtype=dtype), + ) + @property def logits(self) -> Array: """The logits for each event.""" diff --git a/distreqx/distributions/_distribution.py b/distreqx/distributions/_distribution.py index 1e3c142..f288cc7 100644 --- a/distreqx/distributions/_distribution.py +++ b/distreqx/distributions/_distribution.py @@ -56,6 +56,12 @@ def event_shape(self) -> EventT: """Shape of event of distribution samples.""" raise NotImplementedError + @property + @abstractmethod + def support(self) -> tuple[Array, Array]: + """Range `(lower, upper)` spanning the distribution's support.""" + raise NotImplementedError + @property def dtype(self) -> jnp.dtype: """Data type of a sample""" diff --git a/distreqx/distributions/_gamma.py b/distreqx/distributions/_gamma.py index c20142f..5ca8ef9 100644 --- a/distreqx/distributions/_gamma.py +++ b/distreqx/distributions/_gamma.py @@ -53,6 +53,12 @@ def __init__( def event_shape(self) -> tuple: return () + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`.""" + dtype = jnp.result_type(self.concentration, self.rate) + return (jnp.array(0.0, dtype=dtype), jnp.array(jnp.inf, dtype=dtype)) + def sample(self, key: Key[Array, ""]) -> Array: """See `Distribution.sample`.""" dtype = jnp.result_type(self.concentration, self.rate) diff --git a/distreqx/distributions/_independent.py b/distreqx/distributions/_independent.py index 3577ff5..5f4dcb0 100644 --- a/distreqx/distributions/_independent.py +++ b/distreqx/distributions/_independent.py @@ -53,6 +53,11 @@ def event_shape(self) -> tuple: """Shape of event of distribution samples.""" return self.distribution.event_shape + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`.""" + return self.distribution.support + def sample(self, key: Key[Array, ""]) -> Array: """See `Distribution.sample`.""" return self.distribution.sample(key) diff --git a/distreqx/distributions/_logistic.py b/distreqx/distributions/_logistic.py index 71eee6b..e8a6c3a 100644 --- a/distreqx/distributions/_logistic.py +++ b/distreqx/distributions/_logistic.py @@ -31,6 +31,12 @@ def event_shape(self) -> tuple[int, ...]: """Shape of event of distribution samples.""" return self.loc.shape + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`.""" + dtype = jnp.result_type(self.loc, self.scale) + return (jnp.array(-jnp.inf, dtype=dtype), jnp.array(jnp.inf, dtype=dtype)) + def _standardize(self, value: Array) -> Array: return (value - self.loc) / self.scale diff --git a/distreqx/distributions/_mixture_same_family.py b/distreqx/distributions/_mixture_same_family.py index f2da9d9..73cc349 100644 --- a/distreqx/distributions/_mixture_same_family.py +++ b/distreqx/distributions/_mixture_same_family.py @@ -48,6 +48,19 @@ def event_shape(self): """Shape of event of distribution samples.""" return self.components_distribution.event_shape + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`. + + The mixture's support is the union of its components' supports. + """ + lower, upper = self.components_distribution.support + if jnp.ndim(lower): + lower = jnp.min(lower, axis=0) + if jnp.ndim(upper): + upper = jnp.max(upper, axis=0) + return lower, upper + def sample(self, key) -> Array: """See `AbstractDistribution._sample`.""" key_mix, key_components = jax.random.split(key) diff --git a/distreqx/distributions/_mvn_from_bijector.py b/distreqx/distributions/_mvn_from_bijector.py index 9cf7c86..4ee3448 100644 --- a/distreqx/distributions/_mvn_from_bijector.py +++ b/distreqx/distributions/_mvn_from_bijector.py @@ -39,6 +39,12 @@ class AbstractMultivariateNormalFromBijector(AbstractTransformed): loc: eqx.AbstractVar[Array] scale: eqx.AbstractVar[AbstractLinearBijector] + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`.""" + dtype = self.loc.dtype + return (jnp.array(-jnp.inf, dtype=dtype), jnp.array(jnp.inf, dtype=dtype)) + def mean(self) -> Array: """Calculates the mean.""" return self.loc diff --git a/distreqx/distributions/_normal.py b/distreqx/distributions/_normal.py index bf5cacb..7cdd07a 100644 --- a/distreqx/distributions/_normal.py +++ b/distreqx/distributions/_normal.py @@ -33,6 +33,12 @@ def event_shape(self) -> tuple[int, ...]: """Shape of event of distribution samples.""" return self.loc.shape + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`.""" + dtype = jnp.result_type(self.loc, self.scale) + return (jnp.array(-jnp.inf, dtype=dtype), jnp.array(jnp.inf, dtype=dtype)) + def _sample_from_std_normal(self, key: Key[Array, ""]) -> Array: dtype = jnp.result_type(self.loc, self.scale) return jax.random.normal(key, shape=self.event_shape, dtype=dtype) diff --git a/distreqx/distributions/_one_hot_categorical.py b/distreqx/distributions/_one_hot_categorical.py index 6358734..9cc39ce 100644 --- a/distreqx/distributions/_one_hot_categorical.py +++ b/distreqx/distributions/_one_hot_categorical.py @@ -50,6 +50,12 @@ def event_shape(self) -> tuple: """Shape of event of distribution samples.""" return (self.num_categories,) + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`.""" + dtype = self.probs.dtype + return (jnp.array(0.0, dtype=dtype), jnp.array(1.0, dtype=dtype)) + @property def logits(self) -> Array: """The logits for each event.""" diff --git a/distreqx/distributions/_transformed.py b/distreqx/distributions/_transformed.py index e56d94c..f0ac414 100644 --- a/distreqx/distributions/_transformed.py +++ b/distreqx/distributions/_transformed.py @@ -199,6 +199,10 @@ def mode(self) -> Array: def icdf(self, value: PyTree[Array]) -> PyTree[Array]: raise NotImplementedError + @property + def support(self) -> tuple[Array, Array]: + raise NotImplementedError + def log_cdf(self, value: PyTree[Array]) -> PyTree[Array]: raise NotImplementedError diff --git a/distreqx/distributions/_uniform.py b/distreqx/distributions/_uniform.py index a5e1bf8..d059bbb 100644 --- a/distreqx/distributions/_uniform.py +++ b/distreqx/distributions/_uniform.py @@ -24,6 +24,15 @@ def range(self) -> Array: def event_shape(self) -> tuple[int, ...]: return self.low.shape + @property + def support(self) -> tuple[Array, Array]: + """See `Distribution.support`. + + Unlike most distributions, the bounds vary per element and are + returned as arrays of shape `event_shape`. + """ + return (self.low, self.high) + def sample(self, key: Key[Array, ""]) -> Array: """See `Distribution.sample`.""" uniform = jax.random.uniform( diff --git a/pyproject.toml b/pyproject.toml index 9a22a48..359aa8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" requires-python = ">=3.10" license = {file = "LICENSE"} authors = [ - {name = "Owen Lockwood"}, + {name = "lockwo"}, ] keywords = ["jax", "probability", "distributions", "equinox", "machine-learning"] classifiers = [ diff --git a/tests/abstractdistribution_test.py b/tests/abstractdistribution_test.py index 4be18a2..6e8b959 100644 --- a/tests/abstractdistribution_test.py +++ b/tests/abstractdistribution_test.py @@ -31,6 +31,10 @@ def log_prob(self, value): def event_shape(self): return (1,) + @property + def support(self): + raise NotImplementedError + def entropy(self): raise NotImplementedError @@ -71,6 +75,10 @@ def log_prob(self, value): def event_shape(self): return self._dimension + @property + def support(self): + raise NotImplementedError + def entropy(self): raise NotImplementedError