From ae69a52bedf05770dc9a46960866ab66b70e4933 Mon Sep 17 00:00:00 2001
From: Alican Bozkurt
Date: Tue, 26 Jun 2018 00:31:02 -0700
Subject: [PATCH 1/2] add binomial entropy and kl

---
 test/test_distributions.py      |  4 +++-
 torch/distributions/binomial.py | 34 +++++++++++++++++++++++++++++----
 torch/distributions/kl.py       | 24 ++++++++++++++++++++++-
 3 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index e47c16d2f0347..630fe01cce0dd 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -821,7 +821,6 @@ def test_binomial(self):
         self._gradcheck_log_prob(lambda p: Binomial(total_count, p), [p])
         self._gradcheck_log_prob(lambda p: Binomial(total_count, None, p.log()), [p])
         self.assertRaises(NotImplementedError, Binomial(10, p).rsample)
-        self.assertRaises(NotImplementedError, Binomial(10, p).entropy)
 
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     def test_binomial_log_prob(self):
@@ -2757,6 +2756,7 @@ def __init__(self, probs):
                                      [0.33, 0.33, 0.34],
                                      [0.2, 0.2, 0.4]])
     exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
+    geometric = pairwise(Geometric, [0.1, 0.2, 0.6, 0.9])
     gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
     gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
     halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
@@ -2792,6 +2792,8 @@ def __init__(self, probs):
         (beta, gamma),
         (beta, normal),
         (binomial30, binomial30),
+        (binomial30, poisson),
+        (binomial30, geometric),
         (binomial_vectorized_count, binomial_vectorized_count),
         (categorical, categorical),
         (chi2, chi2),
diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py
index 28756d405383f..1fcbe6a3e759a 100644
--- a/torch/distributions/binomial.py
+++ b/torch/distributions/binomial.py
@@ -77,6 +77,11 @@ def probs(self):
     def param_shape(self):
         return self._param.size()
 
+    def _log1pmprobs(self):
+        # Note that: torch.log1p(-self.probs) = max_val - torch.log1p((self.logits + 2 * max_val).exp())
+        max_val = (-self.logits).clamp(min=0.0)
+        return max_val - torch.log1p((self.logits + 2 * max_val).exp())
+
     def sample(self, sample_shape=torch.Size()):
         with torch.no_grad():
             max_count = max(int(self.total_count.max()), 1)
@@ -94,11 +99,8 @@ def log_prob(self, value):
         log_factorial_n = torch.lgamma(self.total_count + 1)
         log_factorial_k = torch.lgamma(value + 1)
         log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
-        max_val = (-self.logits).clamp(min=0.0)
-        # Note that: torch.log1p(-self.probs)) = max_val - torch.log1p((self.logits + 2 * max_val).exp()))
-        return (log_factorial_n - log_factorial_k - log_factorial_nmk +
-                value * self.logits + self.total_count * max_val -
-                self.total_count * torch.log1p((self.logits + 2 * max_val).exp()))
+        return (log_factorial_n - log_factorial_k - log_factorial_nmk +
+                value * self.logits + self.total_count * self._log1pmprobs())
 
     def enumerate_support(self):
         total_count = int(self.total_count.max())
@@ -109,3 +111,27 @@ def enumerate_support(self):
         values = values.view((-1,) + (1,) * len(self._batch_shape))
         values = values.expand((-1,) + self._batch_shape)
         return values
+
+    def _Elnchoosek(self):
+        # return expected value of log(nchoosek), log(n!), log(k!), log((n-k)!)
+        # where k ~ Bin(n, p)
+        s = self.enumerate_support()
+        s[0] = 1  # 0! = 1
+        # x is factorial matrix i.e. x[k,...] = k!
+        x = torch.cumsum(s.log(), dim=0)
+        s[0] = 0
+        indices = [slice(None)] * x.dim()
+        indices[0] = torch.arange(x.size(0) - 1, -1, -1,
+                                  dtype=torch.long, device=x.device)
+        # x[tuple(indices)] is x reversed on first axis
+        lnchoosek = x[-1] - x - x[tuple(indices)]
+        elognfac = x[-1]
+        elogkfac = ((lnchoosek + s * self.logits + self.total_count * self._log1pmprobs()).exp() *
+                    x).sum(dim=0)
+        elognmkfac = ((lnchoosek + s * self.logits + self.total_count * self._log1pmprobs()).exp() *
+                      x[tuple(indices)]).sum(dim=0)
+        return elognfac - elogkfac - elognmkfac, (elognfac, elogkfac, elognmkfac)
+
+    def entropy(self):
+        elnchoosek, _ = self._Elnchoosek()
+        return - elnchoosek - self.mean * self.logits - self.total_count * self._log1pmprobs()
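A note on the `_log1pmprobs` helper factored out above: it evaluates log(1 - p) directly in logit space. For logits l = log(p) - log(1 - p), log(1 - p) = -softplus(l) = max(-l, 0) - log1p(exp(l + 2 * max(-l, 0))). A standalone sketch of the identity and the failure mode it avoids; the free-function name `log1pm_from_logits` is mine, not part of the patch:

```python
import torch

def log1pm_from_logits(logits):
    # Stable log(1 - p) computed from logits = log(p) - log(1 - p);
    # mirrors the patch's _log1pmprobs helper.
    max_val = (-logits).clamp(min=0.0)
    return max_val - torch.log1p((logits + 2 * max_val).exp())

logits = torch.tensor([-3.0, 0.0, 3.0])
print(torch.allclose(log1pm_from_logits(logits),
                     torch.log1p(-torch.sigmoid(logits))))  # True

# For saturating logits, sigmoid rounds to exactly 1.0 and the naive form
# returns -inf; the logit-space form stays finite:
extreme = torch.tensor([40.0])
print(torch.log1p(-torch.sigmoid(extreme)))  # tensor([-inf])
print(log1pm_from_logits(extreme))           # tensor([-40.])
```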
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index 2ae67fc28ccbc..ad420f31819aa 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -199,12 +199,29 @@ def _kl_binomial_binomial(p, q):
     # kullback-leibler-divergence-for-binomial-distributions-p-and-q
     if (p.total_count < q.total_count).any():
         raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented')
-    kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
+    kl = p.total_count * (p.probs * (p.logits - q.logits) + p._log1pmprobs() - q._log1pmprobs())
     inf_idxs = p.total_count > q.total_count
     kl[inf_idxs] = _infinite_like(kl[inf_idxs])
     return kl
 
 
+@register_kl(Binomial, Poisson)
+def _kl_binomial_poisson(p, q):
+    _, (e1, _, e3) = p._Elnchoosek()
+    return (e1 - e3 +
+            p.mean * (p.logits - q.rate.log()) +
+            p.total_count * p._log1pmprobs() +
+            q.rate)
+
+
+@register_kl(Binomial, Geometric)
+def _kl_binomial_geometric(p, q):
+    elnchoosek, _ = p._Elnchoosek()
+    return (elnchoosek +
+            (p.logits - (-q.probs).log1p()) * p.mean +
+            p.total_count * p._log1pmprobs() - q.probs.log())
+
+
 @register_kl(Categorical, Categorical)
 def _kl_categorical_categorical(p, q):
     t = p.probs * (p.logits - q.logits)
@@ -273,6 +290,11 @@ def _kl_geometric_geometric(p, q):
     return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits
 
 
+@register_kl(Geometric, Binomial)
+def _kl_geometric_infinity(p, q):
+    return _infinite_like(p.probs)
+
+
 @register_kl(HalfNormal, HalfNormal)
 def _kl_halfnormal_halfnormal(p, q):
     return _kl_normal_normal(p.base_dist, q.base_dist)
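On the new `Binomial` to `Poisson` kernel: with log p(k) = ln C(n,k) + k * logits + n * log(1 - p) and log q(k) = k * log(rate) - rate - log(k!), taking expectations under p gives KL = E[log n!] - E[log (n-k)!] + n*p*(logits - log(rate)) + n*log(1 - p) + rate, which is the `e1 - e3 + ...` expression above (the E[log k!] terms cancel). A brute-force enumeration cross-check, sketched under the assumption that this patch is applied; `kl_by_enumeration` is a throwaway name of mine:

```python
import torch
from torch.distributions import Binomial, Poisson, kl_divergence

def kl_by_enumeration(p, q):
    # KL(p || q) summed directly over the Binomial's finite support.
    k = p.enumerate_support()
    log_p = p.log_prob(k)
    return (log_p.exp() * (log_p - q.log_prob(k))).sum(0)

p = Binomial(10, torch.tensor([0.2, 0.5, 0.8]))
q = Poisson(torch.tensor([2.0, 5.0, 8.0]))
print(kl_by_enumeration(p, q))
print(kl_divergence(p, q))  # should match the enumeration above
```

The `Binomial` to `Geometric` kernel can be cross-checked the same way, since the sum again runs only over the Binomial's support.

From ac2e622e386ec042e42a3fd37fcb4cd36ed7784c Mon Sep 17 00:00:00 2001
From: Alican Bozkurt
Date: Tue, 26 Jun 2018 17:39:35 -0700
Subject: [PATCH 2/2] address review comments

---
 test/test_distributions.py      |  1 +
 torch/distributions/binomial.py | 60 ++++++++++++++++++---------------
 torch/distributions/kl.py       | 48 +++++++++++++-------------
 3 files changed, 57 insertions(+), 52 deletions(-)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index 630fe01cce0dd..28a6bbed92152 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -2851,6 +2851,7 @@ def __init__(self, probs):
         (Exponential(1), Beta(2, 3)),
         (Exponential(1), Pareto(2, 3)),
         (Exponential(1), Uniform(-2, 3)),
+        (Geometric(0.3), Binomial(10, 0.2)),
         (Gamma(1, 2), Beta(3, 4)),
         (Gamma(1, 2), Pareto(3, 4)),
         (Gamma(1, 2), Uniform(-3, 4)),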
diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py
index 1fcbe6a3e759a..ad228a8f54257 100644
--- a/torch/distributions/binomial.py
+++ b/torch/distributions/binomial.py
@@ -5,6 +5,35 @@
 from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs
 
 
+def _log1pmtensor(logit_tensor):
+    """
+    Calculates (-tensor).log1p() given logit_tensor = tensor.log() - (-tensor).log1p().
+    Useful for distributions with extreme probs.
+    Note that: (-probs).log1p() = max_val - (logits + 2 * max_val).exp().log1p()
+    """
+    max_val = (-logit_tensor).clamp(min=0.0)
+    return max_val - torch.log1p((logit_tensor + 2 * max_val).exp())
+
+
+def _Elnchoosek(p):
+    """
+    Returns expected values of log(nchoosek), log(n!), log(k!), log((n-k)!),
+    where k ~ p and p is a Binomial distribution.
+    """
+    s = p.enumerate_support()
+    s[0] = 1  # 0! = 1
+    # x is log factorial matrix i.e. x[k,...] = log(k!)
+    x = torch.cumsum(s.log(), dim=0)
+    s[0] = 0
+    lnchoosek = x[-1] - x - x.flip(0)
+    elognfac = x[-1]
+    elogkfac = ((lnchoosek + s * p.logits + p.total_count * _log1pmtensor(p.logits)).exp() *
+                x).sum(dim=0)
+    elognmkfac = ((lnchoosek + s * p.logits + p.total_count * _log1pmtensor(p.logits)).exp() *
+                  x.flip(0)).sum(dim=0)
+    return elognfac - elogkfac - elognmkfac, (elognfac, elogkfac, elognmkfac)
+
+
 class Binomial(Distribution):
     r"""
     Creates a Binomial distribution parameterized by `total_count` and
@@ -77,11 +106,6 @@ def probs(self):
     def param_shape(self):
         return self._param.size()
 
-    def _log1pmprobs(self):
-        # Note that: torch.log1p(-self.probs) = max_val - torch.log1p((self.logits + 2 * max_val).exp())
-        max_val = (-self.logits).clamp(min=0.0)
-        return max_val - torch.log1p((self.logits + 2 * max_val).exp())
-
     def sample(self, sample_shape=torch.Size()):
         with torch.no_grad():
             max_count = max(int(self.total_count.max()), 1)
@@ -100,7 +124,7 @@ def log_prob(self, value):
         log_factorial_k = torch.lgamma(value + 1)
         log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
         return (log_factorial_n - log_factorial_k - log_factorial_nmk +
-                value * self.logits + self.total_count * self._log1pmprobs())
+                value * self.logits + self.total_count * _log1pmtensor(self.logits))
 
     def enumerate_support(self):
         total_count = int(self.total_count.max())
@@ -112,26 +136,6 @@ def enumerate_support(self):
         values = values.view((-1,) + (1,) * len(self._batch_shape))
         values = values.expand((-1,) + self._batch_shape)
         return values
 
-    def _Elnchoosek(self):
-        # return expected value of log(nchoosek), log(n!), log(k!), log((n-k)!)
-        # where k ~ Bin(n, p)
-        s = self.enumerate_support()
-        s[0] = 1  # 0! = 1
-        # x is factorial matrix i.e. x[k,...] = k!
-        x = torch.cumsum(s.log(), dim=0)
-        s[0] = 0
-        indices = [slice(None)] * x.dim()
-        indices[0] = torch.arange(x.size(0) - 1, -1, -1,
-                                  dtype=torch.long, device=x.device)
-        # x[tuple(indices)] is x reversed on first axis
-        lnchoosek = x[-1] - x - x[tuple(indices)]
-        elognfac = x[-1]
-        elogkfac = ((lnchoosek + s * self.logits + self.total_count * self._log1pmprobs()).exp() *
-                    x).sum(dim=0)
-        elognmkfac = ((lnchoosek + s * self.logits + self.total_count * self._log1pmprobs()).exp() *
-                      x[tuple(indices)]).sum(dim=0)
-        return elognfac - elogkfac - elognmkfac, (elognfac, elogkfac, elognmkfac)
-
     def entropy(self):
-        elnchoosek, _ = self._Elnchoosek()
-        return - elnchoosek - self.mean * self.logits - self.total_count * self._log1pmprobs()
+        elnchoosek, _ = _Elnchoosek(self)
+        return - elnchoosek - self.mean * self.logits - self.total_count * _log1pmtensor(self.logits)
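The `entropy()` above uses H = -E[ln C(n,k)] - E[k] * logits - n * log(1 - p), which follows from log p(k) = ln C(n,k) + k * logits + n * log(1 - p) with E[k] = `self.mean`. A quick sanity sketch against the direct definition, assuming the patch is applied:

```python
import torch
from torch.distributions import Binomial

d = Binomial(10, torch.tensor([0.1, 0.5, 0.9]))
k = d.enumerate_support()
log_p = d.log_prob(k)
brute_force = -(log_p.exp() * log_p).sum(0)  # -sum_k p(k) log p(k)
print(brute_force)
print(d.entropy())  # should match brute_force
```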
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index ad420f31819aa..50f8c6cae834c 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -6,7 +6,7 @@
 from .bernoulli import Bernoulli
 from .beta import Beta
-from .binomial import Binomial
+from .binomial import Binomial, _log1pmtensor, _Elnchoosek
 from .categorical import Categorical
 from .dirichlet import Dirichlet
 from .distribution import Distribution
@@ -199,29 +199,12 @@ def _kl_binomial_binomial(p, q):
     # kullback-leibler-divergence-for-binomial-distributions-p-and-q
     if (p.total_count < q.total_count).any():
         raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented')
-    kl = p.total_count * (p.probs * (p.logits - q.logits) + p._log1pmprobs() - q._log1pmprobs())
+    kl = p.total_count * (p.probs * (p.logits - q.logits) + _log1pmtensor(p.logits) - _log1pmtensor(q.logits))
     inf_idxs = p.total_count > q.total_count
     kl[inf_idxs] = _infinite_like(kl[inf_idxs])
     return kl
 
 
-@register_kl(Binomial, Poisson)
-def _kl_binomial_poisson(p, q):
-    _, (e1, _, e3) = p._Elnchoosek()
-    return (e1 - e3 +
-            p.mean * (p.logits - q.rate.log()) +
-            p.total_count * p._log1pmprobs() +
-            q.rate)
-
-
-@register_kl(Binomial, Geometric)
-def _kl_binomial_geometric(p, q):
-    elnchoosek, _ = p._Elnchoosek()
-    return (elnchoosek +
-            (p.logits - (-q.probs).log1p()) * p.mean +
-            p.total_count * p._log1pmprobs() - q.probs.log())
-
-
 @register_kl(Categorical, Categorical)
 def _kl_categorical_categorical(p, q):
     t = p.probs * (p.logits - q.logits)
@@ -290,11 +273,6 @@ def _kl_geometric_geometric(p, q):
     return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits
 
 
-@register_kl(Geometric, Binomial)
-def _kl_geometric_infinity(p, q):
-    return _infinite_like(p.probs)
-
-
 @register_kl(HalfNormal, HalfNormal)
 def _kl_halfnormal_halfnormal(p, q):
     return _kl_normal_normal(p.base_dist, q.base_dist)
@@ -418,6 +396,23 @@ def _kl_beta_uniform(p, q):
     return result
 
 
+@register_kl(Binomial, Poisson)
+def _kl_binomial_poisson(p, q):
+    _, (e1, _, e3) = _Elnchoosek(p)
+    return (e1 - e3 +
+            p.mean * (p.logits - q.rate.log()) +
+            p.total_count * _log1pmtensor(p.logits) +
+            q.rate)
+
+
+@register_kl(Binomial, Geometric)
+def _kl_binomial_geometric(p, q):
+    elnchoosek, _ = _Elnchoosek(p)
+    return (elnchoosek +
+            (p.logits - (-q.probs).log1p()) * p.mean +
+            p.total_count * _log1pmtensor(p.logits) - q.probs.log())
+
+
 @register_kl(Exponential, Beta)
 @register_kl(Exponential, Pareto)
 @register_kl(Exponential, Uniform)
@@ -490,6 +485,11 @@ def _kl_gamma_normal(p, q):
     return t1 + (p.concentration - 1) * p.concentration.digamma() + (t2 - t3 + t4) / var_normal
 
 
+@register_kl(Geometric, Binomial)
+def _kl_geometric_infinity(p, q):
+    return _infinite_like(p.probs)
+
+
 @register_kl(Gumbel, Beta)
 @register_kl(Gumbel, Exponential)
 @register_kl(Gumbel, Gamma)
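The `Geometric` to `Binomial` kernel returns infinity by construction: KL(p || q) = sum_k p(k) log(p(k) / q(k)) diverges whenever p puts mass where q has none, and a Geometric is supported on all of {0, 1, 2, ...} while a Binomial stops at n. A one-line check, assuming the patch is applied:

```python
from torch.distributions import Geometric, Binomial, kl_divergence

# Matches the (Geometric(0.3), Binomial(10, 0.2)) infinite-KL pair added
# to the tests above.
print(kl_divergence(Geometric(0.3), Binomial(10, 0.2)))  # tensor(inf)
```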