Multinomial PMF in Python (scipy/numpy)

Posted 2020-07-03 04:11

Question:

Is there a built-in function in scipy/numpy for getting the PMF of a Multinomial? I'm not sure if binom generalizes in the correct way, e.g.

import scipy.stats

# Attempt to define multinomial with n = 10, p = [0.1, 0.1, 0.8]
rv = scipy.stats.binom(10, [0.1, 0.1, 0.8])
# Score the outcome 4, 4, 2
rv.pmf([4, 4, 2])

What is the correct way to do this? Thanks.

Answer 1:

There's no built-in function that I know of, and the binomial probabilities do not generalize. The counts in a multinomial outcome must sum to n, so you need to normalise over a different set of possible outcomes, and independent binomials do not enforce that constraint. However, it's fairly straightforward to implement yourself, for example:

import math

class Multinomial(object):
  def __init__(self, params):
    # params: the category probabilities (expected to sum to 1)
    self._params = params

  def pmf(self, counts):
    if len(counts) != len(self._params):
      raise ValueError("Dimensionality of count vector is incorrect")

    # Product of p_i ** c_i over all categories
    prob = 1.
    for p, c in zip(self._params, counts):
      prob *= p**c

    return prob * math.exp(self._log_multinomial_coeff(counts))

  def log_pmf(self, counts):
    if len(counts) != len(self._params):
      raise ValueError("Dimensionality of count vector is incorrect")

    # Sum of c_i * log(p_i) over all categories
    log_prob = 0.
    for p, c in zip(self._params, counts):
      log_prob += c*math.log(p)

    return log_prob + self._log_multinomial_coeff(counts)

  def _log_multinomial_coeff(self, counts):
    # log(n! / (c_1! * ... * c_k!)) with n = sum(counts)
    return self._log_factorial(sum(counts)) - sum(self._log_factorial(c)
                                                    for c in counts)

  def _log_factorial(self, num):
    # Zero counts are valid (0! = 1), so only reject negatives and non-integers
    if round(num) != num or num < 0:
      raise ValueError("Can only compute the factorial of non-negative ints")
    return sum(math.log(n) for n in range(1, int(num)+1))

m = Multinomial([0.1, 0.1, 0.8])
print("%.4g" % m.pmf([4, 4, 2]))

>>2.016e-05
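
To make the point about independent binomials concrete, here is a small sketch that scores the outcome from the question both ways: as a product of three separate binomial PMFs, and with the multinomial formula. It uses only the standard library (`math.comb` needs Python 3.8 or later), and the helper names `binom_pmf` and `multinomial_pmf` are made up for this illustration.

```python
import math

def binom_pmf(k, n, p):
    # Binomial PMF: C(n, k) * p^k * (1 - p)^(n - k)
    return math.comb(n, k) * p**k * (1 - p)**(n - k)

def multinomial_pmf(counts, probs):
    # Multinomial PMF: n! / (c_1! * ... * c_k!) * prod(p_i ^ c_i)
    n = sum(counts)
    coeff = math.factorial(n)
    for c in counts:
        coeff //= math.factorial(c)  # stays an exact integer at each step
    prob = float(coeff)
    for c, p in zip(counts, probs):
        prob *= p**c
    return prob

counts = [4, 4, 2]
probs = [0.1, 0.1, 0.8]
n = sum(counts)

# Treating each category as its own binomial ignores the constraint
# that the counts must add up to n.
independent = 1.0
for c, p in zip(counts, probs):
    independent *= binom_pmf(c, n, p)

print("product of binomials:", independent)
print("multinomial pmf:     ", multinomial_pmf(counts, probs))
```

The two numbers differ by several orders of magnitude, which is why the binomial attempt in the question cannot simply be reused.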

My implementation of the multinomial coefficient is somewhat naive, but it works in log space to prevent overflow. Also be aware that n is superfluous as a parameter, since it's given by the sum of the counts (and the same parameter set works for any n). Furthermore, since the PMF will quickly underflow for moderate n or large dimensionality, you're better off working in log space (a log_pmf method is provided here too!).
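
As a sketch of a less naive log-space version, the log factorials can be taken from `math.lgamma`, since lgamma(n + 1) = log(n!). The function name `multinomial_log_pmf` and the example counts are assumptions for this illustration, not part of the class above. Newer SciPy releases also ship `scipy.stats.multinomial` with `pmf` and `logpmf` methods, which may be the simpler choice when SciPy is available; the code here sticks to the standard library.

```python
import math

def multinomial_log_pmf(counts, probs):
    if len(counts) != len(probs):
        raise ValueError("Dimensionality of count vector is incorrect")
    n = sum(counts)
    # log(n!) - sum(log(c_i!)), using lgamma(x + 1) = log(x!)
    log_coeff = math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in counts)
    log_prob = 0.0
    for c, p in zip(counts, probs):
        if c > 0:  # a zero count contributes nothing, even when p == 0
            log_prob += c * math.log(p)
    return log_coeff + log_prob

probs = [0.1, 0.1, 0.8]

# Small n: exponentiating the log PMF is safe.
small = multinomial_log_pmf([4, 4, 2], probs)
print(math.exp(small))

# Larger n: the log PMF is still a finite number, but the PMF itself
# is too small for a double and exponentiating gives 0.0.
big = multinomial_log_pmf([400, 400, 200], probs)
print(big, math.exp(big))
```

Comparing log probabilities directly avoids the underflow shown in the second case.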