"""Autograd based differentiable distribution classes"""
import logging
from typing import Any, Callable, Dict, List, Literal, Union
import autograd.numpy as np
from autograd.scipy.special import gammaln
from autograd.scipy.stats.norm import logpdf
from autograd.scipy.stats.poisson import logpmf
from scipy.stats import multivariate_normal, norm, poisson
from spey import log_once
from spey.system.exceptions import DistributionError
# pylint: disable=E1101, W1203, E1121
# Logger registered under the name "Spey"; used for debug output in this module.
log = logging.getLogger("Spey")
# Precomputed log(2*pi): the constant term of the Gaussian log-density that
# ``MultivariateNormal.log_prob`` adds once per dimension.
_LOG_2PI = np.log(2.0 * np.pi)
# Public names of the module (also what the module-level ``__dir__`` returns).
# NOTE(review): ``MixtureModel`` is defined in this module but not listed here.
__all__ = ["Poisson", "Normal", "MultivariateNormal", "MainModel", "ConstraintModel"]
def __dir__():
    """Return the public names of this module, i.e. the ``__all__`` list itself."""
    return __all__
class Poisson:
    """
    Poisson distribution

    Args:
        loc (``np.ndarray``): rate(s) of the distribution. Values below ``1e-20``
            are raised to ``1e-20``.
    """

    def __init__(self, loc: np.ndarray):
        # A Poisson rate cannot be negative; bounding it from below also keeps
        # ``np.log(self.loc)`` finite in ``log_prob``.
        self.loc = np.clip(loc, 1e-20, None)

    def expected_data(self) -> np.ndarray:
        """The expectation value of the Poisson distribution."""
        return np.array(self.loc)

    def sample(self, sample_size: int) -> np.ndarray:
        """
        Generate samples

        Args:
            sample_size (``int``): number of draws.

        Returns:
            ``np.ndarray``:
            draws of shape ``(sample_size, len(loc))`` when ``loc`` is an array,
            ``(sample_size,)`` otherwise.
        """
        if isinstance(self.loc, np.ndarray):
            size = [sample_size, len(self.loc)]
        else:
            size = [sample_size]
        return poisson(self.loc).rvs(size=size)

    def log_prob(self, value: np.ndarray) -> np.ndarray:
        """
        Compute log-probability

        Args:
            value (``np.ndarray``): observed counts.

        Returns:
            ``np.ndarray``:
            log-probability of ``value`` as ``float64``.
        """
        integer_dtypes = (np.int32, np.int16, np.int64)
        if np.array(value).dtype in integer_dtypes:
            # Integer data can go straight through the dedicated pmf (kept for
            # code efficiency).
            result = logpmf(value, self.loc)
        else:
            # Non-integer input: evaluate the Poisson log-density explicitly,
            # with the factorial expressed through the log-gamma function.
            result = value * np.log(self.loc) - self.loc - gammaln(value + 1.0)
        return result.astype(np.float64)
class Normal:
    """
    Normal distribution

    Args:
        loc (``np.ndarray``): Mean of the distribution.
        scale (``np.ndarray``): standard deviation. A callable is accepted as well;
            it is then evaluated on the parameters selected by ``domain``.
        domain (``slice``, default ``slice(None, None)``): set of parameters to be used within
            the distribution.
    """

    def __init__(
        self,
        loc: np.ndarray,
        scale: np.ndarray,
        domain: slice = slice(None, None),
    ):
        self.loc = loc
        self.domain = domain
        """Which parameters should be used during the computation of the pdf"""
        # Store the scale uniformly as a function of the parameters; a constant
        # scale is wrapped so that it simply ignores its argument.
        self.scale = scale if callable(scale) else (lambda pars: scale)

    def expected_data(self) -> np.ndarray:
        """The expectation value of the Normal distribution."""
        return np.array(self.loc)

    def sample(self, value: np.ndarray, sample_size: int) -> np.ndarray:
        """
        Generate samples

        Args:
            value (``np.ndarray``): parameters; the ``domain`` slice of it is passed
                to the scale function.
            sample_size (``int``): number of draws.

        Returns:
            ``np.ndarray``:
            draws of shape ``(sample_size, len(loc))`` when ``loc`` is an array,
            ``(sample_size,)`` otherwise.
        """
        if isinstance(self.loc, np.ndarray):
            size = [sample_size, len(self.loc)]
        else:
            size = [sample_size]
        std = self.scale(value[self.domain])
        return norm(self.loc, std).rvs(size=size)

    def log_prob(self, value: np.ndarray) -> np.ndarray:
        """
        Compute log-probability

        Args:
            value (``np.ndarray``): parameters; only the ``domain`` slice is evaluated.

        Returns:
            ``np.ndarray``:
            element-wise log-density of the selected parameters.
        """
        selected = value[self.domain]
        return logpdf(selected, self.loc, self.scale(selected))
class MultivariateNormal:
    """
    Multivariate normal distribution

    Args:
        mean (``np.ndarray``): Mean of the distribution.
        cov (``np.ndarray`` or ``Callable[[np.ndarray], np.ndarray]``): Symmetric
            positive (semi)definite covariance matrix of the distribution, or a
            callable returning it for the parameters selected by ``domain``.
        domain (``slice``, default ``slice(None, None)``): set of parameters to be used within
            the distribution.
    """

    def __init__(
        self,
        mean: np.ndarray,
        cov: Union[np.ndarray, Callable[[np.ndarray], np.ndarray]],
        domain: slice = slice(None, None),
    ):
        self.mean = mean
        """Mean of the distribution."""
        self.cov = cov if callable(cov) else lambda val: cov
        """Symmetric positive (semi)definite covariance matrix of the distribution."""
        self.domain = domain
        """Which parameters should be used during the computation of the pdf"""

        if callable(cov):
            # The covariance depends on the parameters, so the inverse and the
            # log-determinant are re-evaluated for every input.
            self._inv_cov = lambda val: np.linalg.inv(cov(val))
            # log|cov| is negative whenever det(cov) < 1, so it must not be
            # bounded from below; leaving it untouched also makes this branch
            # agree with the constant-covariance branch.
            self._logdet_cov = lambda val: np.linalg.slogdet(cov(val))[1]
        else:
            # for code efficiency: a constant covariance is inverted only once
            inv = np.linalg.inv(cov)
            logdet = np.linalg.slogdet(cov)[1]
            if np.isinf(logdet):
                log_once(
                    "det(cov) is infinite, this might cause numeric problems. "
                    "Please check the covariance matrix.",
                    log_type="error",
                )
            self._inv_cov = lambda val: inv
            self._logdet_cov = lambda val: logdet

    def expected_data(self) -> np.ndarray:
        """The expectation value of the Multivariate Normal distribution."""
        return self.mean

    def sample(self, value: np.ndarray, sample_size: int) -> np.ndarray:
        """
        Generate samples

        Args:
            value (``np.ndarray``): parameters; the ``domain`` slice of it is passed
                to the covariance function.
            sample_size (``int``): number of draws.

        Returns:
            ``np.ndarray``:
            sampled data
        """
        return multivariate_normal(self.mean, self.cov(value[self.domain])).rvs(
            size=(sample_size,)
        )

    def log_prob(self, value: np.ndarray) -> np.ndarray:
        """
        Compute log-probability

        Args:
            value (``np.ndarray``): parameters; only the ``domain`` slice is evaluated.

        Returns:
            ``np.ndarray``:
            log-density of the selected parameters as ``float64``.
        """
        # NOTE: The reason for not going with multivariate_normal.logpdf is twofold
        #   1) the open computation allows logdet and inverse to be precomputed once
        #   2) Scipy has an issue with its logdet value. Inside multivariate normal
        #      a cov_object is created and a "pseudo logdet" is computed through it.
        #      This value should match `np.linalg.slogdet(cov)` or
        #      `np.log(np.prod(np.linalg.eig(cov)[0]))`, however it is slightly different.
        x = value[self.domain]
        var = x - self.mean
        return (
            -0.5 * (var @ self._inv_cov(x) @ var)
            - 0.5 * (len(x) * _LOG_2PI + self._logdet_cov(x))
        ).astype(np.float64)
class MainModel:
    """
    Main statistical model. By default it is modelled as a Poisson distribution
    with a variable lambda; a Normal or Multivariate Normal distribution can be
    selected through ``pdf_type``.

    Args:
        loc (``Callable[[np.ndarray], np.ndarray]``): callable function that represents
            lambda values of poisson distribution (or the mean of the Gaussian
            variants). It takes nuisance parameters as input.
        cov (``np.ndarray`` or ``Callable[[np.ndarray], np.ndarray]``, default ``None``):
            scale (``"normal"``) or covariance matrix (``"multivariate_normal"``),
            either constant or as a function of the parameters. Not used for ``"poiss"``.
        pdf_type (``Literal["poiss", "normal", "multivariate_normal"]``, default ``"poiss"``):
            which distribution to build.

    Raises:
        ``DistributionError``: if ``pdf_type`` is not recognised, or a Gaussian type
            is requested without ``cov``.
    """

    def __init__(
        self,
        loc: Callable[[np.ndarray], np.ndarray],
        cov: Union[np.ndarray, Callable[[np.ndarray], np.ndarray]] = None,
        pdf_type: Literal["poiss", "normal", "multivariate_normal"] = "poiss",
    ):
        self.pdf_type = pdf_type
        """Type of the PDF"""

        if pdf_type == "poiss":
            self._pdf = lambda pars: Poisson(loc(pars))
            return

        # Both Gaussian variants need a scale / covariance to be meaningful.
        if cov is None or pdf_type not in ("normal", "multivariate_normal"):
            raise DistributionError("Unknown pdf type or associated input.")

        if pdf_type == "normal":
            # Resolve the scale once so that a single factory covers both the
            # constant and the parameter-dependent case.
            scale_at = cov if callable(cov) else (lambda pars: cov)

            def _normal_pdf(pars):
                return Normal(loc=loc(pars), scale=scale_at(pars))

            self._pdf = _normal_pdf
            return

        if callable(cov):
            # The covariance may change with the parameters, hence the
            # distribution (and its inverse / log-determinant) is rebuilt on
            # every call.
            def _varying_cov_pdf(pars):
                return MultivariateNormal(mean=loc(pars), cov=cov(pars))

            self._pdf = _varying_cov_pdf
            return

        # Constant covariance: build one distribution up front so the inverse
        # and log-determinant are computed a single time. Every call hands back
        # this same object with only its mean replaced.
        shared = MultivariateNormal(mean=np.zeros(len(cov)), cov=cov)

        def _static_cov_pdf(pars):
            shared.mean = loc(pars)
            return shared

        self._pdf = _static_cov_pdf

    def expected_data(self, pars: np.ndarray) -> np.ndarray:
        """The expectation value of the main model."""
        return self._pdf(pars).expected_data()

    def sample(self, pars: np.ndarray, sample_size: int) -> np.ndarray:
        r"""
        Generate samples

        Args:
            pars (``np.ndarray``): parameter of interest and nuisance parameters
              :math:`\mu` and :math:`\theta` combined.
            sample_size (``int``): size of the sample to return

        Returns:
            ``np.ndarray``:
            sampled data
        """
        pdf = self._pdf(pars)
        if self.pdf_type == "poiss":
            # The Poisson sampler does not take the parameters as an argument.
            return pdf.sample(sample_size)
        return pdf.sample(pars, sample_size)

    def log_prob(self, pars: np.ndarray, data: np.ndarray) -> np.ndarray:
        r"""
        Compute log-probability

        Args:
            pars (``np.ndarray``): parameter of interest and nuisance parameters
              :math:`\mu` and :math:`\theta` combined.
            data (``np.ndarray``): actual data

        Returns:
            ``np.ndarray``:
            log-probability of the main model
        """
        per_entry = self._pdf(pars).log_prob(data)
        return np.sum(per_entry)
class ConstraintModel:
    """
    Constraint term modelled as a Gaussian distribution.

    Args:
        pdf_descriptions (``List[Dict[Text, Any]]``): description of the pdf component.
            Dictionary elements should contain the following keywords

            * ``"distribution_type"`` (``Text``): ``"normal"`` or ``"multivariatenormal"``
              (matched case-insensitively)
            * ``"args"``: Input arguments for the distribution
            * ``"kwargs"``: Input keyword arguments for the distribution

    Raises:
        ``AssertionError``: if a description names an unknown distribution type.
    """

    def __init__(self, pdf_descriptions: List[Dict[str, Any]]):
        self._pdfs = []
        distributions = {"normal": Normal, "multivariatenormal": MultivariateNormal}

        log.debug("Adding constraint terms:")
        for desc in pdf_descriptions:
            # Normalise the name once and use the same key for validation and
            # lookup, so that e.g. "Normal" is accepted consistently.
            dist_type = desc["distribution_type"].lower()
            if dist_type not in distributions:
                # Raised explicitly (instead of via ``assert``) so the check is
                # not removed under ``python -O``; the exception type is kept
                # for backward compatibility.
                raise AssertionError(
                    f"Unknown distribution type: {desc['distribution_type']}"
                )
            log.debug(f"{desc}")
            self._pdfs.append(
                distributions[dist_type](
                    *desc.get("args", []), **desc.get("kwargs", {})
                )
            )

    def __len__(self):
        """Number of constraint terms."""
        return len(self._pdfs)

    def expected_data(self) -> np.ndarray:
        """The expectation value of the constraint model."""
        if len(self) > 1:
            return np.hstack([pdf.expected_data() for pdf in self._pdfs])
        return self._pdfs[0].expected_data()

    def sample(self, pars: np.ndarray, sample_size: int) -> np.ndarray:
        r"""
        Generate samples

        Args:
            pars (``np.ndarray``): parameter of interest and nuisance parameters
              :math:`\mu` and :math:`\theta` combined.
            sample_size (``int``): size of the sample to return

        Returns:
            ``np.ndarray``:
            sampled data
        """
        if len(self) > 1:
            return np.hstack([pdf.sample(pars, sample_size) for pdf in self._pdfs])
        return self._pdfs[0].sample(pars, sample_size)

    def log_prob(self, pars: np.ndarray) -> np.ndarray:
        r"""
        Compute log-probability

        Args:
            pars (``np.ndarray``): parameter of interest and nuisance parameters
              :math:`\mu` and :math:`\theta` combined.

        Returns:
            ``np.ndarray``:
            log-probability of the constraint model, i.e. the sum over all
            constraint terms
        """
        return sum(pdf.log_prob(pars).sum() for pdf in self._pdfs)
class MixtureModel:
    """
    Generate probability distribution from combination of different distributions

    Args:
        args: Distributions. Arguments without a ``log_prob`` attribute are ignored.

    .. warning::

        All distributions have been assumed to have same shape.
    """

    def __init__(self, *args):
        self.distributions = [dist for dist in args if hasattr(dist, "log_prob")]

    def __iter__(self):
        for dist in self.distributions:
            yield dist

    def sample(self, sample_shape: int) -> np.ndarray:
        """
        Generate samples

        Every component is sampled ``sample_shape`` times; for each draw one
        component is then picked at random with equal probability.

        Args:
            sample_shape (``int``): number of draws.

        Returns:
            ``np.ndarray``:
            sampled data
        """
        n_components = len(self.distributions)
        # The last axis indexes the component a draw came from; the width is
        # read from the ``shape`` attribute of the last distribution.
        draws = np.zeros(
            (sample_shape, self.distributions[-1].shape, n_components)
        )
        for component, dist in enumerate(self):
            draws[:, ..., component] = dist.sample(sample_shape)

        uniform_weights = [1.0 / n_components] * n_components
        chosen = np.random.choice(
            np.arange(n_components),
            size=(sample_shape,),
            p=uniform_weights,
        )
        return draws[np.arange(sample_shape), ..., chosen]

    def log_prob(self, value: np.ndarray) -> np.ndarray:
        """Compute log-probability as the sum of the components' log-probabilities"""
        component_log_probs = [dist.log_prob(value) for dist in self.distributions]
        return np.sum(component_log_probs)