add binomial entropy and kl #149
base: master
Conversation
torch/distributions/binomial.py (Outdated)
@@ -77,6 +77,11 @@ def probs(self):
    def param_shape(self):
        return self._param.size()

    def _log1pmprobs(self):
Since it is a function for internal use, I think this can be moved to the top of the module, like in MVN. Something like:

    def _log1pmtensor(tensor):
        # Do the same thing

Uses of the function in kl.py can be done by importing this function along with Binomial.
torch/distributions/binomial.py (Outdated)
@@ -109,3 +111,27 @@ def enumerate_support(self):
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        values = values.expand((-1,) + self._batch_shape)
        return values

    def _Elnchoosek(self):
Same idea here.
torch/distributions/binomial.py (Outdated)
        s = self.enumerate_support()
        s[0] = 1  # 0! = 1
        # x is factorial matrix i.e. x[k,...] = k!
        x = torch.cumsum(s.log(), dim=0)
x is the log of the factorial matrix, right?
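The reviewer is right: a running sum of log k produces log k!, not k!. A scalar sketch of what `torch.cumsum(s.log(), dim=0)` computes, checked against `lgamma`:

```python
import math

def log_factorials(n):
    # Running sum of logs: out[k] = log(k!); the scalar analogue of
    # torch.cumsum(s.log(), dim=0) with s[0] set to 1 (since log 1 = 0).
    out = [0.0]  # log 0! = 0
    for k in range(1, n + 1):
        out.append(out[-1] + math.log(k))
    return out
```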
torch/distributions/binomial.py (Outdated)
        indices[0] = torch.arange(x.size(0) - 1, -1, -1,
                                  dtype=torch.long, device=x.device)
        # x[tuple(indices)] is x reversed on first axis
        lnchoosek = x[-1] - x - x[tuple(indices)]
I think x.flip(dim=0) will exhibit the same behaviour.
Weird, I tried using flip and it didn't work before - maybe I messed up the arguments...
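The equivalence the reviewer suggests can be demonstrated (numpy is used here only for illustration; torch's `Tensor.flip` takes positional dims, so `x.flip(0)` or `x.flip(dims=[0])` is the correct call and a `dim=` keyword raises a TypeError, which may explain the earlier failure):

```python
import numpy as np

x = np.arange(12.0).reshape(4, 3)
# Explicit reversed index, as constructed in the diff:
idx = np.arange(x.shape[0] - 1, -1, -1)
reversed_by_index = x[idx]
# Equivalent one-liner that reverses the first axis; the torch
# analogue is x.flip(0), not x.flip(dim=0).
reversed_by_flip = np.flip(x, axis=0)
assert np.array_equal(reversed_by_index, reversed_by_flip)
```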
torch/distributions/binomial.py (Outdated)
        elognfac = x[-1]
        elogkfac = ((lnchoosek + s * self.logits + self.total_count * self._log1pmprobs()).exp() *
                    x).sum(dim=0)
        elognmkfac = ((lnchoosek + s * self.logits + self.total_count * self._log1pmprobs()).exp() *
E[log(n-k)!] = E[log k!] but for Bin(n, (1 - p)). Can we use this fact here?
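The identity the reviewer invokes follows from the fact that if K ~ Bin(n, p) then n - K ~ Bin(n, 1 - p). A pure-Python check by direct enumeration (helper names here are for illustration, not from the PR):

```python
import math

def binom_pmf(n, p, k):
    # P(K = k) for K ~ Bin(n, p)
    return math.comb(n, k) * p ** k * (1 - p) ** (n - k)

def e_log_k_factorial(n, p):
    # E[log K!] for K ~ Bin(n, p), by enumerating the support
    return sum(binom_pmf(n, p, k) * math.lgamma(k + 1) for k in range(n + 1))

def e_log_nmk_factorial(n, p):
    # E[log (n - K)!] for K ~ Bin(n, p)
    return sum(binom_pmf(n, p, k) * math.lgamma(n - k + 1) for k in range(n + 1))

# If K ~ Bin(n, p) then n - K ~ Bin(n, 1 - p), so the two expectations agree:
n, p = 20, 0.3
assert abs(e_log_nmk_factorial(n, p) - e_log_k_factorial(n, 1 - p)) < 1e-9
```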
torch/distributions/kl.py (Outdated)
        inf_idxs = p.total_count > q.total_count
        kl[inf_idxs] = _infinite_like(kl[inf_idxs])
        return kl


@register_kl(Binomial, Poisson)
Heterogeneous combinations were placed below. This section was for homogeneous combinations.
torch/distributions/kl.py (Outdated)
                            q.rate)


@register_kl(Binomial, Geometric)
Same as above comment.
torch/distributions/kl.py (Outdated)
@@ -273,6 +290,11 @@ def _kl_geometric_geometric(p, q):
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(Geometric, Binomial)
Same as above comment.
Some comments have been given; please check them.
Could you also check whether the KL test passes at a lower tolerance, and how long it takes at the default tolerance setting?
Thanks for adding these!
@vishwakftw thanks for the comments! One thing I want us to work out before wrapping this up is an approximation to E[log k!] for large n. I tried Stirling's approximation but couldn't come up with a closed form. Any ideas?
I think we have to make use of Stirling's inequality and a Taylor series to compute this. I guess the reason you are unable to come up with a closed form is the log(k) term. I tried using them and got about 0.5% relative error. This might help after expanding log k! <= 1 + k log k + 0.5 log k - k: https://en.wikipedia.org/wiki/Taylor_expansions_for_the_moments_of_functions_of_random_variables
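For reference, the approximation being discussed is Stirling's: log k! ≈ k log k - k + 0.5 log(2πk), whose absolute error shrinks roughly like 1/(12k). A quick sketch comparing it against the exact value via `lgamma`:

```python
import math

def stirling_log_factorial(k):
    # Stirling's approximation for log k! (k >= 1):
    #   log k! ~ k log k - k + 0.5 * log(2 * pi * k)
    # with absolute error about 1/(12k), which is what makes an
    # approximation of E[log K!] plausible for large n.
    return k * math.log(k) - k + 0.5 * math.log(2 * math.pi * k)

# The relative error drops quickly as k grows:
for k in (5, 50, 500):
    exact = math.lgamma(k + 1)  # log k! exactly
    rel_err = abs(stirling_log_factorial(k) - exact) / exact
```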
Looks good to me! @fritzo what do you think?
Also, are you going to try the large n case?
@vishwakftw btw I tried it with 0.01 precision as well. Two things on my wishlist:
@alicanb I have a closed form solution for E[log x!], E[log (n - x)!] and E[log n!] (this is simply log n!) for large n.
Great, have you experimented with any large n?
This is the gist for the approximations. I ran some tests: n = {10, 20, 50, 75, 100} and p = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}.
This is a larger PR than I intended, but basically it adds binomial entropy and binomial-poisson and binomial-geometric KL, with some helper functions:

- binomial._log1pmprobs: I used this a lot, so I made it a separate function. It calculates (-probs).log1p() safely.
- binomial._Elnchoosek(): for x ~ Bin(n, p), this calculates E[log(nchoosek)], E[log(n!)], E[log(x!)], E[log((n-x)!)]
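A scalar sketch of the E[log(nchoosek)] term that a helper like binomial._Elnchoosek computes (the PR version is tensorized over the enumerated support; the function name and scalar form here are for illustration only):

```python
import math

def e_log_nchoosek(n, p):
    # E[log C(n, K)] for K ~ Bin(n, p), by enumerating k = 0..n;
    # a pure-Python analogue of the tensorized _Elnchoosek term.
    total = 0.0
    for k in range(n + 1):
        pmf = math.comb(n, k) * p ** k * (1 - p) ** (n - k)
        total += pmf * math.log(math.comb(n, k))
    return total

# Sanity check: with n = 1, C(1, 0) = C(1, 1) = 1, so the expectation is 0.
assert abs(e_log_nchoosek(1, 0.3)) < 1e-12
```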