Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions distreqx/distributions/_bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ def probs(self) -> Array:
def event_shape(self) -> tuple[int, ...]:
return self.probs.shape

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`.

The Bernoulli is discrete on `{0, 1}`.
"""
dtype = self.probs.dtype
return (jnp.array(0.0, dtype=dtype), jnp.array(1.0, dtype=dtype))

def _log_probs_parameter(self) -> tuple[Array, Array]:
if self._logits is None:
if self._probs is None:
Expand Down
6 changes: 6 additions & 0 deletions distreqx/distributions/_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def __init__(
def event_shape(self) -> tuple:
return ()

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`."""
dtype = jnp.result_type(self.alpha, self.beta)
return (jnp.array(0.0, dtype=dtype), jnp.array(1.0, dtype=dtype))

def sample(self, key: Key[Array, ""]) -> Array:
"""See `Distribution.sample`."""
dtype = jnp.result_type(self.alpha, self.beta)
Expand Down
12 changes: 12 additions & 0 deletions distreqx/distributions/_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ def event_shape(self) -> tuple:
"""Shape of event of distribution samples."""
return ()

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`.

The Categorical is discrete on `{0, ..., K-1}`.
"""
dtype = self.probs.dtype
return (
jnp.array(0.0, dtype=dtype),
jnp.array(self.num_categories - 1, dtype=dtype),
)

@property
def logits(self) -> Array:
"""The logits for each event."""
Expand Down
6 changes: 6 additions & 0 deletions distreqx/distributions/_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def event_shape(self) -> EventT:
"""Shape of event of distribution samples."""
raise NotImplementedError

@property
@abstractmethod
def support(self) -> tuple[Array, Array]:
"""Range `(lower, upper)` spanning the distribution's support."""
raise NotImplementedError

@property
def dtype(self) -> jnp.dtype:
"""Data type of a sample"""
Expand Down
6 changes: 6 additions & 0 deletions distreqx/distributions/_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ def __init__(
def event_shape(self) -> tuple:
return ()

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`."""
dtype = jnp.result_type(self.concentration, self.rate)
return (jnp.array(0.0, dtype=dtype), jnp.array(jnp.inf, dtype=dtype))

def sample(self, key: Key[Array, ""]) -> Array:
"""See `Distribution.sample`."""
dtype = jnp.result_type(self.concentration, self.rate)
Expand Down
5 changes: 5 additions & 0 deletions distreqx/distributions/_independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def event_shape(self) -> tuple:
"""Shape of event of distribution samples."""
return self.distribution.event_shape

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`."""
return self.distribution.support

def sample(self, key: Key[Array, ""]) -> Array:
"""See `Distribution.sample`."""
return self.distribution.sample(key)
Expand Down
6 changes: 6 additions & 0 deletions distreqx/distributions/_logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ def event_shape(self) -> tuple[int, ...]:
"""Shape of event of distribution samples."""
return self.loc.shape

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`."""
dtype = jnp.result_type(self.loc, self.scale)
return (jnp.array(-jnp.inf, dtype=dtype), jnp.array(jnp.inf, dtype=dtype))

def _standardize(self, value: Array) -> Array:
return (value - self.loc) / self.scale

Expand Down
13 changes: 13 additions & 0 deletions distreqx/distributions/_mixture_same_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ def event_shape(self):
"""Shape of event of distribution samples."""
return self.components_distribution.event_shape

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`.

The mixture's support is the union of its components' supports.
"""
lower, upper = self.components_distribution.support
if jnp.ndim(lower):
lower = jnp.min(lower, axis=0)
if jnp.ndim(upper):
upper = jnp.max(upper, axis=0)
return lower, upper

def sample(self, key) -> Array:
"""See `AbstractDistribution._sample`."""
key_mix, key_components = jax.random.split(key)
Expand Down
6 changes: 6 additions & 0 deletions distreqx/distributions/_mvn_from_bijector.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class AbstractMultivariateNormalFromBijector(AbstractTransformed):
loc: eqx.AbstractVar[Array]
scale: eqx.AbstractVar[AbstractLinearBijector]

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`."""
dtype = self.loc.dtype
return (jnp.array(-jnp.inf, dtype=dtype), jnp.array(jnp.inf, dtype=dtype))

def mean(self) -> Array:
"""Calculates the mean."""
return self.loc
Expand Down
6 changes: 6 additions & 0 deletions distreqx/distributions/_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def event_shape(self) -> tuple[int, ...]:
"""Shape of event of distribution samples."""
return self.loc.shape

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`."""
dtype = jnp.result_type(self.loc, self.scale)
return (jnp.array(-jnp.inf, dtype=dtype), jnp.array(jnp.inf, dtype=dtype))

def _sample_from_std_normal(self, key: Key[Array, ""]) -> Array:
dtype = jnp.result_type(self.loc, self.scale)
return jax.random.normal(key, shape=self.event_shape, dtype=dtype)
Expand Down
6 changes: 6 additions & 0 deletions distreqx/distributions/_one_hot_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def event_shape(self) -> tuple:
"""Shape of event of distribution samples."""
return (self.num_categories,)

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`."""
dtype = self.probs.dtype
return (jnp.array(0.0, dtype=dtype), jnp.array(1.0, dtype=dtype))

@property
def logits(self) -> Array:
"""The logits for each event."""
Expand Down
4 changes: 4 additions & 0 deletions distreqx/distributions/_transformed.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ def mode(self) -> Array:
def icdf(self, value: PyTree[Array]) -> PyTree[Array]:
raise NotImplementedError

@property
def support(self) -> tuple[Array, Array]:
raise NotImplementedError

def log_cdf(self, value: PyTree[Array]) -> PyTree[Array]:
raise NotImplementedError

Expand Down
9 changes: 9 additions & 0 deletions distreqx/distributions/_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ def range(self) -> Array:
def event_shape(self) -> tuple[int, ...]:
return self.low.shape

@property
def support(self) -> tuple[Array, Array]:
"""See `Distribution.support`.

Unlike most distributions, the bounds vary per element and are
returned as arrays of shape `event_shape`.
"""
return (self.low, self.high)

def sample(self, key: Key[Array, ""]) -> Array:
"""See `Distribution.sample`."""
uniform = jax.random.uniform(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ readme = "README.md"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "Owen Lockwood"},
{name = "lockwo"},
]
keywords = ["jax", "probability", "distributions", "equinox", "machine-learning"]
classifiers = [
Expand Down
8 changes: 8 additions & 0 deletions tests/abstractdistribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def log_prob(self, value):
def event_shape(self):
return (1,)

@property
def support(self):
raise NotImplementedError

def entropy(self):
raise NotImplementedError

Expand Down Expand Up @@ -71,6 +75,10 @@ def log_prob(self, value):
def event_shape(self):
return self._dimension

@property
def support(self):
raise NotImplementedError

def entropy(self):
raise NotImplementedError

Expand Down
Loading