
import numpy as np
import matplotlib.pyplot as plt

Lecture 12: Dirichlet-Multinomial Inference

Simulations from Beta and Dirichlet distributions

#Simulate from Beta(a, b)
a = .000002
#a = 0.5
b = .000002
#b = 0.5
beta_samples = np.random.beta(a, b, size=1000)
print(beta_samples)
print(np.mean(beta_samples))
[0. 1. 0. 0. 1. 0. ... 1. 1.]  (1000 samples, each essentially 0 or 1; output truncated)
0.526

By setting $a = b = \epsilon$ for some small $\epsilon$ in the above code, we can check that samples from $\text{Beta}(a, b)$ are 0 or 1 with equal frequency (up to Monte Carlo error).

By setting $a = \epsilon$ and $b$ to be a fixed positive constant (e.g., $b = 0.5$), we get all samples equal to 0.

Analogously, by setting $a$ to be a fixed positive constant (e.g., $a = 0.5$) and $b = \epsilon$, we get all samples equal to 1.
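The three limiting behaviors above can also be checked in a single cell rather than by editing `a` and `b` (a quick sketch; the seed and sample size are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
eps = 1e-6

# a = b = eps: samples are essentially 0 or 1, each with frequency ~ 1/2
s = rng.beta(eps, eps, size=2000)
print(np.mean(s))   # close to 0.5

# a = eps, b fixed: essentially all samples are 0
s0 = rng.beta(eps, 0.5, size=2000)
print(np.mean(s0))  # close to 0

# a fixed, b = eps: essentially all samples are 1
s1 = rng.beta(0.5, eps, size=2000)
print(np.mean(s1))  # close to 1
```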

Below we obtain samples from the Dirichlet distribution with $k = 4$.

#Samples from Dirichlet distribution
k = 4
#alpha = np.array([0.02, 0.02, 0.02, 0.02])
#alpha = np.array([0.5, 0.5, 0.5, 0.0002])
alpha = np.array([3, 1, 0.000002, 0.000002])
dirichlet_samples = np.random.dirichlet(alpha, size=1000)
print(dirichlet_samples)
print(np.mean(dirichlet_samples, axis=0))
[[0.86443728 0.13556272 0.         0.        ]
 [0.94432839 0.05567161 0.         0.        ]
 [0.93369658 0.06630342 0.         0.        ]
 ...
 [0.77504867 0.22495133 0.         0.        ]
 [0.94097736 0.05902264 0.         0.        ]
 [0.84199999 0.15800001 0.         0.        ]]
[7.50627682e-001 2.49372318e-001 0.00000000e+000 1.46093836e-121]

If $a_1 = a_2 = a_3 = a_4 = \epsilon$ for some small $\epsilon$, then each Dirichlet sample will be one of $(1, 0, 0, 0)$, $(0, 1, 0, 0)$, $(0, 0, 1, 0)$, $(0, 0, 0, 1)$ with equal frequency (up to Monte Carlo error).

If $a_1, a_2, a_3$ are some fixed numbers (not small) and $a_4 = \epsilon$ for a small $\epsilon$, then $p_4$ will be zero in every sample, and the other three coordinates $(p_1, p_2, p_3)$ are drawn from $\text{Dirichlet}(a_1, a_2, a_3)$.

If $a_1, a_2$ are fixed numbers (not small) and $a_3 = a_4 = \epsilon$ for a small $\epsilon$, then $p_3, p_4$ will be zero in every sample, and the other two coordinates $(p_1, p_2)$ are drawn from $\text{Dirichlet}(a_1, a_2)$.
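The equal-frequency claim for small parameters can be verified empirically (a quick sketch; the seed and the value of $\epsilon$ are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)

# All four parameters small: each sample is essentially a one-hot vector,
# and each corner of the simplex appears with frequency ~ 1/4
eps = 0.01
samples = rng.dirichlet([eps] * 4, size=4000)
corners = np.argmax(samples, axis=1)   # coordinate carrying (essentially) all the mass
freqs = np.bincount(corners, minlength=4) / len(corners)
print(freqs)  # each entry close to 0.25
```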

A simple problem to illustrate Dirichlet-Multinomial Inference

Consider the following problem, which we studied previously (see, e.g., Lecture 4):

Suppose a scientist makes 6 numerical measurements $26.6, 38.5, 34.4, 34, 31, 23.6$ on an unknown real-valued physical quantity $\theta$. On the basis of these measurements, what can be inferred about $\theta$?

Previously we solved this problem using the Bayesian model:

\begin{align*}
X_1, \dots, X_n \mid \theta, \sigma \overset{\text{i.i.d.}}{\sim} N(\theta, \sigma^2) ~~\text{ and }~~ \theta, \log \sigma \overset{\text{i.i.d.}}{\sim} \text{uniform}(-\infty, \infty).
\end{align*}

This gave the following posterior for $\theta$:

\begin{align*}
\frac{\sqrt{n}(\theta - \hat{\theta})}{\sqrt{S(\hat{\theta})/(n-1)}} \,\Big|\, \text{data} \sim t_{n-1}
\end{align*}

where $S(\theta) := \sum_{i=1}^n (x_i - \theta)^2$ and $\hat{\theta} = \bar{x} = (x_1 + \dots + x_n)/n$. This led to the 95% interval $[25.598, 37.102]$.
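As a quick numerical check (a sketch assuming `scipy` is available), this interval can be reproduced directly from the $t_{n-1}$ posterior:

```python
import numpy as np
from scipy import stats

x = np.array([26.6, 38.5, 34.4, 34, 31, 23.6])
n = len(x)
theta_hat = x.mean()                           # sample mean = 31.35
# posterior scale: sqrt(S(theta_hat)/(n-1)) / sqrt(n)
se = np.sqrt(np.sum((x - theta_hat) ** 2) / (n - 1) / n)
t_crit = stats.t.ppf(0.975, df=n - 1)          # 97.5% quantile of t_{n-1}
lo, hi = theta_hat - t_crit * se, theta_hat + t_crit * se
print(f"[{lo:.3f}, {hi:.3f}]")  # [25.598, 37.102]
```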

We now use the model:

\begin{align*}
X_1, \dots, X_n \mid P \overset{\text{i.i.d.}}{\sim} P ~~\text{ and }~~ \theta = \text{mean corresponding to } P.
\end{align*}

We use discretization and assume that $P$ is supported on a large finite set $G = \{g_1, \dots, g_k\}$, e.g., $G = \{0.1, 0.2, \dots, 99.9, 100.0\}$. Let the probabilities assigned by $P$ to $g_1, \dots, g_k$ be $p_1, \dots, p_k$. We use the noninformative prior:

\begin{align*}
(p_1, \dots, p_k) \sim \text{Dirichlet}(0, \dots, 0).
\end{align*}

By conjugacy, the posterior of $(p_1, \dots, p_k)$ is $\text{Dirichlet}(c_1, \dots, c_k)$, where $c_j$ is the number of observations equal to $g_j$. By the limiting behavior illustrated above, coordinates with $c_j = 0$ receive zero weight, so (when the observations are distinct) this leads to the posterior:

\begin{align*}
P \mid X_1 = x_1, \dots, X_n = x_n \sim \sum_{i=1}^n w_i \delta_{\{x_i\}}, ~\text{ where }~ (w_1, \dots, w_n) \sim \text{Dirichlet}(1, \dots, 1).
\end{align*}

The posterior of $\theta$ is therefore:

\begin{align*}
\theta \mid \text{data} \sim \sum_{i=1}^n w_i x_i, ~\text{ where }~ (w_1, \dots, w_n) \sim \text{Dirichlet}(1, \dots, 1).
\end{align*}

Below we simulate from this posterior and obtain the posterior mean and a 95% credible interval.

x_obs = np.array([26.6, 38.5, 34.4, 34, 31, 23.6])
n = len(x_obs)
M = 100000 #number of posterior samples
np.random.seed(42) #for reproducibility
W = np.random.dirichlet(alpha = np.ones(n), size=M) #M samples from the Dirichlet distribution
theta_samples = W @ x_obs #compute the posterior samples of theta

theta_mean = np.mean(theta_samples)
theta_ci = np.quantile(theta_samples, [0.025, 0.975])

print(f"Posterior mean of theta: {theta_mean:.2f}")
print(f"95% credible interval: [{theta_ci[0]:.2f}, {theta_ci[1]:.2f}]")
Posterior mean of theta: 31.35
95% credible interval: [27.53, 34.90]

Below is a histogram of all the posterior samples.

plt.hist(theta_samples, bins=40, density=True)
plt.xlabel(r"$\theta$")
plt.ylabel("Posterior density")
plt.title("Posterior Samples of $\\theta$")
plt.show()
[Figure: histogram of the posterior samples of θ]

This method is known as the Bayesian bootstrap, because operationally it is very close to the (classical) bootstrap. We shall explore these connections in the next lecture.
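To preview the connection (an illustrative sketch, not part of the lecture): the classical bootstrap resamples the data with replacement, which is equivalent to giving the data points multinomial weights $(\text{counts}/n)$, while the Bayesian bootstrap gives them $\text{Dirichlet}(1, \dots, 1)$ weights. Both weight vectors average $(1/n, \dots, 1/n)$, so the two posterior-like distributions for $\theta$ are centered at the same place:

```python
import numpy as np

rng = np.random.default_rng(42)
x = np.array([26.6, 38.5, 34.4, 34, 31, 23.6])
n, M = len(x), 50000

# Classical bootstrap: weights are multinomial counts divided by n
w_classical = rng.multinomial(n, np.ones(n) / n, size=M) / n
# Bayesian bootstrap: weights are Dirichlet(1, ..., 1)
w_bayes = rng.dirichlet(np.ones(n), size=M)

# Both induce distributions for theta = sum_i w_i x_i centered at x.mean() = 31.35
print(np.mean(w_classical @ x), np.mean(w_bayes @ x))
print(np.std(w_classical @ x), np.std(w_bayes @ x))   # similar spreads
```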