
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import gammaln, digamma
from collections import Counter

Text Example

tokens = """
the cat sat on the mat . the cat ate the fish . the dog sat on the mat .
the dog chased the cat . the cat ran from the dog . you see the cat .
you see the dog . you see the mat . i see the cat . i see the dog .
the cat sat on the fish . the dog ate the fish . you see the fish .
""".split()
vocab = sorted(set(tokens))
w2i = {w: idx for idx, w in enumerate(vocab)}
i2w = {idx: w for w, idx in w2i.items()}
k = len(vocab)
n = len(tokens)
print(f"Tokens n = {n},  Vocabulary size k = {k}\n")
print("Vocabulary:", vocab)
print("tokens:", tokens)
Tokens n = 76,  Vocabulary size k = 15

Vocabulary: ['.', 'ate', 'cat', 'chased', 'dog', 'fish', 'from', 'i', 'mat', 'on', 'ran', 'sat', 'see', 'the', 'you']
tokens: ['the', 'cat', 'sat', 'on', 'the', 'mat', '.', 'the', 'cat', 'ate', 'the', 'fish', '.', 'the', 'dog', 'sat', 'on', 'the', 'mat', '.', 'the', 'dog', 'chased', 'the', 'cat', '.', 'the', 'cat', 'ran', 'from', 'the', 'dog', '.', 'you', 'see', 'the', 'cat', '.', 'you', 'see', 'the', 'dog', '.', 'you', 'see', 'the', 'mat', '.', 'i', 'see', 'the', 'cat', '.', 'i', 'see', 'the', 'dog', '.', 'the', 'cat', 'sat', 'on', 'the', 'fish', '.', 'the', 'dog', 'ate', 'the', 'fish', '.', 'you', 'see', 'the', 'fish', '.']

i.i.d. Model

The observed counts for each of the $k$ words are given below.

counts = Counter(tokens)
print("Word counts:")
for w in vocab:
    print(f"  {w:>8s}: {counts[w]}")
Word counts:
         .: 13
       ate: 2
       cat: 7
    chased: 1
       dog: 6
      fish: 4
      from: 1
         i: 2
       mat: 3
        on: 3
       ran: 1
       sat: 3
       see: 6
       the: 20
       you: 4

We can use Dirichlet-Multinomial inference here with the prior Dirichlet$(0, \dots, 0)$. This leads to the posterior Dirichlet$(x_1, \dots, x_k)$, where $x_i$ is the observed count of the $i$-th word. (The prior is improper, but the posterior is a proper Dirichlet because every word in the vocabulary has been observed at least once.) We can use this fitted model for sentence generation: for each sentence, we first draw $(p_1, \dots, p_k)$ from Dirichlet$(x_1, \dots, x_k)$ and then draw words sequentially from Multinomial$(1; p_1, \dots, p_k)$ until we get a ‘.’ (period).
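A short worked calculation shows what this generation scheme implies about sentence length. Given $p$, the number of words $L$ up to and including the first ‘.’ is geometric with success probability $p_{.}$, so $E[L \mid p] = 1/p_{.}$. Under the posterior Dirichlet$(x_1, \dots, x_k)$ the marginal of $p_{.}$ is Beta$(13, 76 - 13) =$ Beta$(13, 63)$, and for Beta$(\alpha, \beta)$ with $\alpha > 1$ one has $E[1/p] = (\alpha + \beta - 1)/(\alpha - 1)$, here $75/12 = 6.25$. The sketch below checks this by simulation; it hard-codes the counts from the table above, and the seed and number of draws are arbitrary choices.

```python
import numpy as np

# Counts from the table above: '.' occurs 13 times among n = 76 tokens.
x_period, n = 13, 76
alpha, beta = x_period, n - x_period   # marginal posterior of p_. is Beta(13, 63)

# Closed form: E[1/p] = (alpha + beta - 1) / (alpha - 1), valid for alpha > 1.
exact = (alpha + beta - 1) / (alpha - 1)

# Monte Carlo: draw p_. from its posterior, then a geometric sentence length.
rng = np.random.default_rng(0)          # arbitrary seed
p = rng.beta(alpha, beta, size=200_000)
lengths = rng.geometric(p)              # trials up to and including the first '.'
mc = lengths.mean()

print(f"exact E[L] = {exact:.4f}, Monte Carlo estimate = {mc:.4f}")
```

Plugging in the posterior mean $\hat p_{.} = 13/76$ instead would give $76/13 \approx 5.85$; the gap is Jensen's inequality, $E[1/p_{.}] > 1/E[p_{.}]$.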

np.random.seed(42)
# Dirichlet posterior parameters
alpha = np.array([counts[w] for w in vocab], dtype=float)
print("Generating sentences from the i.i.d. model posterior:\n")
M = 20  # number of sentences to generate
for s in range(M):
    # Step 1: Draw a probability vector from the Dirichlet posterior
    p = np.random.dirichlet(alpha)
    # Step 2: Generate words until '.'
    sentence = []
    while True:
        word_idx = np.random.choice(k, p=p)
        word = i2w[word_idx]
        sentence.append(word)
        if word == '.':
            break
    print(f"  Sentence {s+1}: {' '.join(sentence)}")
Generating sentences from the i.i.d. model posterior:

  Sentence 1: .
  Sentence 2: the .
  Sentence 3: .
  Sentence 4: mat cat .
  Sentence 5: dog the the sat .
  Sentence 6: the ran cat cat see the the the ran .
  Sentence 7: the chased .
  Sentence 8: the dog dog the the the the on mat the the the the the dog fish .
  Sentence 9: .
  Sentence 10: sat see .
  Sentence 11: the see the cat .
  Sentence 12: chased the the the the .
  Sentence 13: from the .
  Sentence 14: .
  Sentence 15: the i i cat .
  Sentence 16: ate .
  Sentence 17: fish sat .
  Sentence 18: cat the the mat mat .
  Sentence 19: cat .
  Sentence 20: sat chased sat ate i mat .

These sentences are of course not realistic, which shows that this model is far too crude for this dataset. It works only with the individual word counts and ignores all information about the order in which words occur; for instance, a sentence starts with ‘.’ with probability about $13/76 \approx 0.17$, which is why several of the “sentences” above consist of a lone period. A slightly improved model can be obtained by working with bigram counts instead of individual word counts.

AR(1) model (also known as the Markov or Bigram model)

This model uses bigram counts (rather than the raw token sequence directly). Below we compute the bigram counts $x_{j \mid i}$, the number of times word $j$ immediately follows word $i$, as well as $x_i = \sum_{j=1}^k x_{j \mid i}$.

# --- 2. Bigram counts: X[j, i] = x_{j|i} = # times word j follows word i ---
X = np.zeros((k, k), dtype=np.int64)
for t in range(1, n):
    i_prev = w2i[tokens[t - 1]]  # i = previous word
    j_next = w2i[tokens[t]]      # j = next word
    X[j_next, i_prev] += 1

# x_i = sum_j x_{j|i}  (times i appears as the previous word)
x_prev = X.sum(axis=0)  # shape (k,)

print(X)
print(x_prev)
[[0 0 3 0 3 4 0 0 3 0 0 0 0 0 0]
 [0 0 1 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 7 0]
 [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 6 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 4 0]
 [0 0 0 0 0 0 0 0 0 0 1 0 0 0 0]
 [2 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 3 0]
 [0 0 0 0 0 0 0 0 0 0 0 3 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 2 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 2 0 0 0 0 0 0 4]
 [6 2 0 1 0 0 1 0 0 3 0 0 6 0 0]
 [4 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[12  2  7  1  6  4  1  2  3  3  1  3  6 20  4]

The first column of `X` (previous word ‘.’) gives the number of times each word appears right after ‘.’: ‘i’ appears 2 times, ‘the’ 6 times and ‘you’ 4 times. These sum to 12, the first entry of `x_prev`, which is the number of times ‘.’ appears as the previous word. This is one less than the 13 occurrences of ‘.’ in the text, because the final ‘.’ is not followed by any word.
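The counting loop above can also be vectorized with `np.add.at`, which, unlike fancy-index `+=`, accumulates repeated `(j, i)` pairs. The sketch below rebuilds the corpus, checks that both versions agree, and confirms two bookkeeping facts: there are $n - 1 = 75$ bigrams, and $x_i$ equals the unigram count of word $i$ except for the final ‘.’, which has no successor.

```python
import numpy as np
from collections import Counter

tokens = """
the cat sat on the mat . the cat ate the fish . the dog sat on the mat .
the dog chased the cat . the cat ran from the dog . you see the cat .
you see the dog . you see the mat . i see the cat . i see the dog .
the cat sat on the fish . the dog ate the fish . you see the fish .
""".split()
vocab = sorted(set(tokens))
w2i = {w: idx for idx, w in enumerate(vocab)}
k, n = len(vocab), len(tokens)

# Loop version (as in the text): X[j, i] = # times word j follows word i.
X_loop = np.zeros((k, k), dtype=np.int64)
for t in range(1, n):
    X_loop[w2i[tokens[t]], w2i[tokens[t - 1]]] += 1

# Vectorized version: np.add.at accumulates duplicate (j, i) index pairs.
ids = np.array([w2i[w] for w in tokens])
X_vec = np.zeros((k, k), dtype=np.int64)
np.add.at(X_vec, (ids[1:], ids[:-1]), 1)

x_prev = X_vec.sum(axis=0)              # x_i = sum_j x_{j|i}
counts = Counter(tokens)
unigram = np.array([counts[w] for w in vocab])

# Only the last token ('.') never serves as a previous word.
expected = unigram.copy()
expected[w2i[tokens[-1]]] -= 1

print(np.array_equal(X_loop, X_vec), X_vec.sum(), x_prev[w2i['.']])
```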

The next function computes the log-evidence for given values of $a_1, \dots, a_k$. Here the evidence is the probability of the observed word sequence (conditional on its first word) given $a_1, \dots, a_k$, with the transition probabilities $p_{j \mid i}$ integrated out.
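For reference, here is the hierarchical model that, as we read the code below, the evidence refers to. Each previous word $i$ has its own transition vector, all drawn from one shared Dirichlet prior, and each word $w_t$ is drawn given the previous word $w_{t-1}$:

$$
p_{\cdot \mid i} = (p_{1 \mid i}, \dots, p_{k \mid i}) \overset{\text{i.i.d.}}{\sim} \text{Dirichlet}(a_1, \dots, a_k), \qquad w_t \mid w_{t-1} = i \sim \text{Multinomial}(1; p_{1 \mid i}, \dots, p_{k \mid i}).
$$

Integrating out each $p_{\cdot \mid i}$ gives one Dirichlet-Multinomial factor per previous word, with $A = \sum_{j=1}^k a_j$:

$$
\text{Evidence}(a) = \prod_{i=1}^{k} \frac{\Gamma(A)}{\Gamma(x_i + A)} \prod_{j=1}^{k} \frac{\Gamma(x_{j \mid i} + a_j)}{\Gamma(a_j)}.
$$

A word that never appears as a previous word ($x_i = 0$) contributes a factor of 1, which is why `log_evidence` skips such rows.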

# --- 3. Log marginal likelihood (evidence) as a function of a ---
# Here we parameterize a = (a_1,...,a_k) and A = sum_j a_j.
# The marginal over all rows i is:
#   prod_i [ Γ(A)/Γ(x_i + A) * prod_j Γ(x_{j|i}+a_j)/Γ(a_j) ]
# (this is exactly the probability of the ordered word sequence given its
#  first word, so no multinomial coefficients are needed)

def log_evidence(a):
    A = a.sum()
    val = 0.0
    for i in range(k):
        if x_prev[i] == 0:
            continue
        val += gammaln(A) - gammaln(x_prev[i] + A)
        val += (gammaln(X[:, i] + a) - gammaln(a)).sum()
    return val
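One way to sanity-check the closed form is the sequential (Pólya-urn) identity: given the bigrams seen so far, the predictive probability that word $j$ follows word $i$ is $(c_{j \mid i} + a_j)/(c_i + A)$, where $c$ are the running counts, and multiplying these one-step predictives along the sequence must reproduce the evidence exactly. The sketch below checks this on a short hypothetical token-id sequence and an arbitrary positive $a$, both chosen only for the test.

```python
import numpy as np
from scipy.special import gammaln

def log_evidence_closed_form(seq, a):
    """Sum over previous words i of the log Dirichlet-Multinomial factor."""
    k = len(a)
    A = a.sum()
    X = np.zeros((k, k))
    for prev, nxt in zip(seq[:-1], seq[1:]):
        X[nxt, prev] += 1               # X[j, i] = # times j follows i
    x_prev = X.sum(axis=0)
    val = 0.0
    for i in range(k):                  # rows with x_prev[i] == 0 contribute 0
        val += gammaln(A) - gammaln(x_prev[i] + A)
        val += (gammaln(X[:, i] + a) - gammaln(a)).sum()
    return val

def log_evidence_sequential(seq, a):
    """Add up log one-step posterior predictive probabilities (Polya urn)."""
    k = len(a)
    A = a.sum()
    C = np.zeros((k, k))                # running bigram counts
    val = 0.0
    for prev, nxt in zip(seq[:-1], seq[1:]):
        val += np.log((C[nxt, prev] + a[nxt]) / (C[:, prev].sum() + A))
        C[nxt, prev] += 1
    return val

# Hypothetical toy sequence over k = 3 word ids and an arbitrary prior a.
seq = [0, 1, 2, 0, 1, 1, 2, 0, 2, 2, 1, 0]
a = np.array([0.5, 1.3, 0.2])
print(log_evidence_closed_form(seq, a), log_evidence_sequential(seq, a))
```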

The next function computes the gradient of the log-evidence using the digamma function $\psi$ (the derivative of $\log \Gamma$), which is built into scipy as `scipy.special.digamma`. The notes give the exact formula for the gradient in terms of the digamma function.
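Differentiating the logarithm of the evidence term by term, using $\psi = (\log \Gamma)'$ and $\partial A / \partial a_j = 1$, gives the gradient that `grad_log_evidence` below computes:

$$
\frac{\partial}{\partial a_j} \log \text{Evidence}(a) = \sum_{i:\, x_i > 0} \Big[ \psi(x_{j \mid i} + a_j) - \psi(a_j) + \psi(A) - \psi(x_i + A) \Big], \qquad j = 1, \dots, k.
$$

The last two terms do not depend on $j$, so they are added to every component.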

def grad_log_evidence(a):
    A = a.sum()
    g = np.zeros(k)
    for i in range(k):
        if x_prev[i] == 0:
            continue
        # vector part: sum_i [ ψ(x_{j|i}+a_j) - ψ(a_j) ]
        g += digamma(X[:, i] + a) - digamma(a)
        # scalar part from A: sum_i [ ψ(A) - ψ(x_i + A) ] added to every component
        g += digamma(A) - digamma(x_prev[i] + A)
    return g
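A central finite-difference check is a cheap way to confirm that the digamma gradient matches the log-evidence. The sketch below copies both functions and evaluates them on a small random count matrix; the data, `k = 4`, and the step `h` are hypothetical choices for the test, not values from the text.

```python
import numpy as np
from scipy.special import gammaln, digamma

rng = np.random.default_rng(1)
k = 4
X = rng.integers(0, 5, size=(k, k))     # hypothetical bigram counts X[j, i]
x_prev = X.sum(axis=0)

def log_evidence(a):
    A = a.sum()
    val = 0.0
    for i in range(k):
        if x_prev[i] == 0:
            continue
        val += gammaln(A) - gammaln(x_prev[i] + A)
        val += (gammaln(X[:, i] + a) - gammaln(a)).sum()
    return val

def grad_log_evidence(a):
    A = a.sum()
    g = np.zeros(k)
    for i in range(k):
        if x_prev[i] == 0:
            continue
        g += digamma(X[:, i] + a) - digamma(a)
        g += digamma(A) - digamma(x_prev[i] + A)
    return g

a = rng.uniform(0.2, 2.0, size=k)       # arbitrary positive test point
h = 1e-6
# Central differences along each coordinate direction e_j.
fd = np.array([(log_evidence(a + h * e) - log_evidence(a - h * e)) / (2 * h)
               for e in np.eye(k)])
print("max abs difference:", np.max(np.abs(fd - grad_log_evidence(a))))
```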

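The code below maximizes the log-evidence with respect to $a_1, \dots, a_k$ using simple gradient ascent. To keep every $a_j$ positive, it works with $\log a_j$: at each iteration it computes the gradient of the log-evidence with respect to $\log a$ (which is $a_j$ times the gradient with respect to $a_j$) at the current value of $a = (a_1, \dots, a_k)$ and takes a step in that direction. The step size is chosen by a simple backtracking line search: starting from 0.01, it is halved (at most 10 times) until the log-evidence increases.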

a = np.ones(k)  # initial values of a_1, ..., a_k

print("Optimizing hyperparameters a = (a_1,...,a_k)...")
for it in range(3000):
    log_a = np.log(a)
    le = log_evidence(a)
    g = grad_log_evidence(a) * a  # chain rule: d/d(log a) = a * d/da
    step = 0.01
    for _ in range(10):
        trial = np.exp(log_a + step * g).clip(1e-10)
        if log_evidence(trial) > le:
            break
        step *= 0.5
    a = np.exp(log_a + step * g).clip(1e-10)
    if it % 100 == 0:
        print(f"  iter {it:3d}: log_ev = {le:.2f}, A = {a.sum():.3f}")

A = a.sum()
Optimizing hyperparameters a = (a_1,...,a_k)...
  iter   0: log_ev = -165.87, A = 14.829
  iter 100: log_ev = -140.73, A = 7.111
  iter 200: log_ev = -132.25, A = 3.348
  iter 300: log_ev = -128.33, A = 2.033
  iter 400: log_ev = -126.79, A = 1.500
  iter 500: log_ev = -126.19, A = 1.247
  iter 600: log_ev = -125.95, A = 1.111
  iter 700: log_ev = -125.85, A = 1.034
  iter 800: log_ev = -125.81, A = 0.987
  iter 900: log_ev = -125.80, A = 0.958
  iter 1000: log_ev = -125.79, A = 0.940
  iter 1100: log_ev = -125.79, A = 0.928
  iter 1200: log_ev = -125.79, A = 0.920
  iter 1300: log_ev = -125.79, A = 0.915
  iter 1400: log_ev = -125.79, A = 0.912
  iter 1500: log_ev = -125.79, A = 0.910
  iter 1600: log_ev = -125.79, A = 0.909
  iter 1700: log_ev = -125.79, A = 0.908
  iter 1800: log_ev = -125.79, A = 0.907
  iter 1900: log_ev = -125.79, A = 0.907
  iter 2000: log_ev = -125.79, A = 0.907
  iter 2100: log_ev = -125.79, A = 0.906
  iter 2200: log_ev = -125.79, A = 0.906
  iter 2300: log_ev = -125.79, A = 0.906
  iter 2400: log_ev = -125.79, A = 0.906
  iter 2500: log_ev = -125.79, A = 0.906
  iter 2600: log_ev = -125.79, A = 0.906
  iter 2700: log_ev = -125.79, A = 0.906
  iter 2800: log_ev = -125.79, A = 0.906
  iter 2900: log_ev = -125.79, A = 0.906
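The loop above takes many small steps. As an alternative (not the method used in the text), the same maximization can be handed to `scipy.optimize.minimize` with the L-BFGS-B method, again working in $\theta = \log a$ so that $a$ stays positive; by the chain rule $\partial / \partial \theta_j = a_j \, \partial / \partial a_j$, which is the `grad_log_evidence(a) * a` factor in the loop. This is a sketch: the starting point $a = (1, \dots, 1)$ matches the text, while the optimizer and the box bounds on $\log a$ (a numerical safeguard only) are our assumptions.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import gammaln, digamma

tokens = """
the cat sat on the mat . the cat ate the fish . the dog sat on the mat .
the dog chased the cat . the cat ran from the dog . you see the cat .
you see the dog . you see the mat . i see the cat . i see the dog .
the cat sat on the fish . the dog ate the fish . you see the fish .
""".split()
vocab = sorted(set(tokens))
w2i = {w: idx for idx, w in enumerate(vocab)}
k = len(vocab)
ids = np.array([w2i[w] for w in tokens])
X = np.zeros((k, k))
np.add.at(X, (ids[1:], ids[:-1]), 1)    # X[j, i] = # times word j follows word i
x_prev = X.sum(axis=0)
rows = x_prev > 0                       # only previous words with data contribute

def log_evidence(a):
    A = a.sum()
    return (np.sum(gammaln(A) - gammaln(x_prev[rows] + A))
            + np.sum(gammaln(X[:, rows] + a[:, None]) - gammaln(a)[:, None]))

def grad_log_evidence(a):
    A = a.sum()
    return (np.sum(digamma(X[:, rows] + a[:, None]) - digamma(a)[:, None], axis=1)
            + np.sum(digamma(A) - digamma(x_prev[rows] + A)))

def neg_objective(theta):
    """Negative log-evidence and its gradient with respect to theta = log a."""
    a = np.exp(theta)
    return -log_evidence(a), -grad_log_evidence(a) * a   # chain rule

# Start from a = (1, ..., 1), i.e. theta = 0; bounds are only a safeguard.
res = minimize(neg_objective, x0=np.zeros(k), jac=True, method="L-BFGS-B",
               bounds=[(-20.0, 5.0)] * k)
a_opt = np.exp(res.x)
print(f"log_ev = {-res.fun:.2f}, A = {a_opt.sum():.3f}")
```

If both procedures reach the same maximum, the result should be close to the log-evidence of about $-125.79$ and $A \approx 0.906$ reported above.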

Below are the values of the estimated $a_1, \dots, a_k$ (and their sum $A = a_1 + \dots + a_k$).

print(f"\nOptimal A = {A:.3f}\n")
print(f"{'word':>8s}  {'a_j':>10s}")
print("-" * 30)
for j in range(k):
    print(f"{vocab[j]:>8s}  {a[j]:10.4f}")

Optimal A = 0.906

    word         a_j
------------------------------
       .      0.1523
     ate      0.0627
     cat      0.0339
  chased      0.0313
     dog      0.0337
    fish      0.0332
    from      0.0313
       i      0.0323
     mat      0.0328
      on      0.0328
     ran      0.0313
     sat      0.0646
     see      0.0684
     the      0.2321
     you      0.0332

Having obtained $a_1, \dots, a_k$, we compute below the posterior mean estimates of $p_{j \mid i}$.
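Since $p_{\cdot \mid i}$ has posterior Dirichlet$(x_{1 \mid i} + a_1, \dots, x_{k \mid i} + a_k)$, its posterior mean is a weighted average of the empirical frequency $x_{j \mid i}/x_i$ and the prior mean $m_j = a_j / A$:

$$
\hat p_{j \mid i} = \frac{x_{j \mid i} + a_j}{x_i + A} = (1 - \lambda_i)\,\frac{x_{j \mid i}}{x_i} + \lambda_i\, m_j, \qquad \lambda_i = \frac{A}{x_i + A}, \quad m_j = \frac{a_j}{A}.
$$

Words that are often seen as the previous word (large $x_i$) get little shrinkage toward $m$; for example, for ‘the’, $\lambda = 0.906 / 20.906 \approx 0.043$.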

P_hat = np.zeros((k, k))
for i in range(k):
    P_hat[:, i] = (X[:, i] + a) / (x_prev[i] + A)
print(P_hat)
[[0.01179979 0.05240418 0.39871914 0.07989796 0.45645401 0.84636298
  0.07989796 0.05240418 0.80702976 0.03898798 0.07989796 0.03898798
  0.0220515  0.00728443 0.03104104]
 [0.0048559  0.02156559 0.13441252 0.03287995 0.15387557 0.01277414
  0.03287995 0.02156559 0.0160445  0.0160445  0.03287995 0.0160445
  0.00907473 0.00299772 0.01277414]
 [0.00262547 0.01166    0.00428589 0.0177774  0.00490649 0.00690667
  0.0177774  0.01166    0.00867488 0.00867488 0.0177774  0.00867488
  0.00490649 0.3364523  0.00690667]
 [0.00242796 0.01078282 0.00396346 0.01644002 0.14933821 0.00638709
  0.01644002 0.01078282 0.00802227 0.00802227 0.01644002 0.00802227
  0.00453738 0.00149886 0.00638709]
 [0.00261083 0.01159499 0.00426199 0.01767828 0.00487913 0.00686816
  0.01767828 0.01159499 0.00862651 0.00862651 0.01767828 0.00862651
  0.00487913 0.28861019 0.00686816]
 [0.00257218 0.01142334 0.0041989  0.01741658 0.0048069  0.00676649
  0.01741658 0.01142334 0.00849881 0.00849881 0.01741658 0.00849881
  0.0048069  0.19292019 0.00676649]
 [0.00242796 0.01078282 0.00396346 0.01644002 0.00453738 0.00638709
  0.01644002 0.01078282 0.00802227 0.00802227 0.54108867 0.00802227
  0.00453738 0.00149886 0.00638709]
 [0.15747019 0.01112037 0.00408754 0.01695466 0.00467942 0.00658703
  0.01695466 0.01112037 0.0082734  0.0082734  0.01695466 0.0082734
  0.00467942 0.00154578 0.00658703]
 [0.00254437 0.01129983 0.0041535  0.01722827 0.00475493 0.00669333
  0.01722827 0.01129983 0.00840692 0.00840692 0.01722827 0.00840692
  0.00475493 0.14506995 0.00669333]
 [0.00254437 0.01129983 0.0041535  0.01722827 0.00475493 0.00669333
  0.01722827 0.01129983 0.00840692 0.00840692 0.01722827 0.7764487
  0.00475493 0.00157073 0.00669333]
 [0.00242796 0.01078282 0.13044908 0.01644002 0.00453738 0.00638709
  0.01644002 0.01078282 0.00802227 0.00802227 0.01644002 0.00802227
  0.00453738 0.00149886 0.00638709]
 [0.00500317 0.02221962 0.26113854 0.03387711 0.15415078 0.01316155
  0.03387711 0.02221962 0.01653109 0.01653109 0.03387711 0.01653109
  0.00934994 0.00308863 0.01316155]
 [0.00530131 0.0235437  0.00865401 0.03589588 0.00990711 0.01394586
  0.03589588 0.71176607 0.01751619 0.01751619 0.03589588 0.01751619
  0.00990711 0.00327269 0.8292678 ]
 [0.48288391 0.76809673 0.02935957 0.64642899 0.03361086 0.0473127
  0.64642899 0.07987436 0.05942541 0.8274672  0.12178034 0.05942541
  0.90241588 0.01110291 0.0473127 ]
 [0.31250464 0.01142334 0.0041989  0.01741658 0.0048069  0.00676649
  0.01741658 0.01142334 0.00849881 0.00849881 0.01741658 0.00849881
  0.0048069  0.0015879  0.00676649]]
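As a check on the matrix above, the sketch below rebuilds $\hat P$ from the corpus using the rounded $a_j$ values from the table (so entries agree only to about three decimals), confirms that every column sums to one, and verifies the shrinkage form $\hat p_{j \mid i} = (1 - \lambda_i)\, x_{j \mid i}/x_i + \lambda_i m_j$ with $\lambda_i = A/(x_i + A)$ and $m_j = a_j / A$.

```python
import numpy as np

tokens = """
the cat sat on the mat . the cat ate the fish . the dog sat on the mat .
the dog chased the cat . the cat ran from the dog . you see the cat .
you see the dog . you see the mat . i see the cat . i see the dog .
the cat sat on the fish . the dog ate the fish . you see the fish .
""".split()
vocab = sorted(set(tokens))
w2i = {w: idx for idx, w in enumerate(vocab)}
k = len(vocab)

# Rounded a_j values copied from the table above (vocab order).
a = np.array([0.1523, 0.0627, 0.0339, 0.0313, 0.0337, 0.0332, 0.0313, 0.0323,
              0.0328, 0.0328, 0.0313, 0.0646, 0.0684, 0.2321, 0.0332])
A = a.sum()

ids = np.array([w2i[w] for w in tokens])
X = np.zeros((k, k))
np.add.at(X, (ids[1:], ids[:-1]), 1)    # X[j, i] = # times word j follows word i
x_prev = X.sum(axis=0)                  # every x_i > 0 in this corpus

P_hat = (X + a[:, None]) / (x_prev + A) # column i = posterior mean of p_{.|i}

lam = A / (x_prev + A)                  # shrinkage weight for each previous word i
m = a / A                               # prior mean
P_shrink = (1 - lam) * X / x_prev + lam * m[:, None]

print(np.allclose(P_hat.sum(axis=0), 1.0), np.allclose(P_hat, P_shrink))
print(f"P(cat | the) = {P_hat[w2i['cat'], w2i['the']]:.3f}")
```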
print(f"\n{'='*45}")
print("NEXT-WORD PREDICTIONS (posterior mean)")
print(f"{'='*45}")

#The following code predicts the top five most probable next words for a given word
def predict_next(prev_word, top_k=5):
    i = w2i[prev_word]
    probs = P_hat[:, i]                 # probs over j given i
    ranked = np.argsort(probs)[::-1]
    return [(i2w[j], probs[j]) for j in ranked[:top_k]]


for ctx in ['the', 'cat', 'you', 'dog', '.', 'on']:
    i = w2i[ctx]
    lam = A / (x_prev[i] + A)  # shrinkage weight λ_i = A / (x_i + A) on the prior mean m = a / A
    print(f"\nAfter '{ctx}'  (x_i={int(x_prev[i])}, λ={lam:.3f}):")
    for word, prob in predict_next(ctx, top_k=5):
        bar = '█' * int(prob * 40)
        print(f"  {word:>8s}  {prob:.3f}  {bar}")

=============================================
NEXT-WORD PREDICTIONS (posterior mean)
=============================================

After 'the'  (x_i=20, λ=0.043):
       cat  0.336  █████████████
       dog  0.289  ███████████
      fish  0.193  ███████
       mat  0.145  █████
       the  0.011  

After 'cat'  (x_i=7, λ=0.115):
         .  0.399  ███████████████
       sat  0.261  ██████████
       ate  0.134  █████
       ran  0.130  █████
       the  0.029  █

After 'you'  (x_i=4, λ=0.185):
       see  0.829  █████████████████████████████████
       the  0.047  █
         .  0.031  █
       sat  0.013  
       ate  0.013  

After 'dog'  (x_i=6, λ=0.131):
         .  0.456  ██████████████████
       sat  0.154  ██████
       ate  0.154  ██████
    chased  0.149  █████
       the  0.034  █

After '.'  (x_i=12, λ=0.070):
       the  0.483  ███████████████████
       you  0.313  ████████████
         i  0.157  ██████
         .  0.012  
       see  0.005  

After 'on'  (x_i=3, λ=0.232):
       the  0.827  █████████████████████████████████
         .  0.039  █
       see  0.018  
       sat  0.017  
       ate  0.016  

The code below generates sentences using this model. It takes a starting word as input (e.g., ‘.’) and generates words one at a time until a period ‘.’ appears or `length` (default 12) further words have been generated; this cap is why two of the sentences below end without a period. Sentence generation can be done in two different ways here. The first way, used immediately below, plugs in the posterior mean estimates $\hat p_{j \mid i}$ and generates each word from the corresponding multinomial. The second way first draws $p_{j \mid i}$ from their posterior (once per sentence) and then generates from the resulting multinomials.

def generate(start_word, length=12, seed=45):
    rng = np.random.default_rng(seed)
    words = [start_word]
    for _ in range(length):
        i = w2i[words[-1]]
        j = rng.choice(k, p=P_hat[:, i])
        words.append(i2w[j])
        if words[-1] == '.':
            break
    return ' '.join(words)

print(f"\n{'='*45}")
print("GENERATED SENTENCES")
print(f"{'='*45}\n")
for s in range(20):
    print("  " + generate('.', seed=45 + s))

=============================================
GENERATED SENTENCES
=============================================

  . the dog chased the dog i see the dog chased the mat
  . you i see the mat .
  . you see the cat the cat .
  . the dog ate the fish cat ran from the cat .
  . the dog .
  . you see the mat .
  . you see the cat sat on the fish .
  . the dog ate .
  . .
  . the dog .
  . you see the cat ate the fish the cat ate the dog
  . you .
  . the dog sat on the cat ate the cat .
  . the mat .
  . the fish .
  . the cat .
  . the fish .
  . the fish .
  . the the cat ate see the dog .
  . you see on the mat .

You can ignore the starting period while interpreting each of these sentences.

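In the above, we held the posterior mean estimates $\hat p_{j \mid i}$ (stored as `P_hat[j, i]`) fixed while generating the sentences. We can instead draw the whole transition matrix $P$ from its posterior before generating each sentence; under this posterior the columns are independent, with $p_{\cdot \mid i} \sim \text{Dirichlet}(x_{1 \mid i} + a_1, \dots, x_{k \mid i} + a_k)$. This is done below.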

def sample_P_posterior(rng):
    """
    Sample a full transition matrix P from the posterior.
    Each column i is p_{·|i} ~ Dirichlet(X[:, i] + a).
    """
    P = np.zeros((k, k))
    for i in range(k):
        alpha = X[:, i] + a
        P[:, i] = rng.dirichlet(alpha)
    return P

def generate_from_P(start_word, P, length=12, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    words = [start_word]
    for _ in range(length):
        i = w2i[words[-1]]
        j = rng.choice(k, p=P[:, i])
        words.append(i2w[j])
        if words[-1] == '.':
            break
    return ' '.join(words)

def generate_posterior_sampled(start_word, length=12, seed=45):
    """
    For THIS sentence: sample P ~ posterior, then generate using that P.
    """
    rng = np.random.default_rng(seed)
    P = sample_P_posterior(rng)
    return generate_from_P(start_word, P, length=length, rng=rng)

# --- 9. Generate sentences ---
print(f"\n{'='*45}")
print("GENERATED SENTENCES (posterior-sampled P per sentence)")
print(f"{'='*45}\n")
for s in range(20):
    print(" " + generate_posterior_sampled('.', seed=45 + s))

=============================================
GENERATED SENTENCES (posterior-sampled P per sentence)
=============================================

 . the cat .
 . the cat see the dog .
 . i see the mat .
 . the cat .
 . the fish .
 . you see the mat .
 . the cat sat fish .
 . you see the fish .
 . the dog .
 . the cat sat on dog .
 . the cat .
 . the mat .
 . the fish .
 . you see the dog .
 . the mat .
 . you see the cat .
 . on the cat sat on the dog cat .
 . i see the fish .
 . you see the see the fish .
 . the dog .
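Because each column is drawn from Dirichlet$(x_{\cdot \mid i} + a)$, averaging many posterior draws of $P$ should recover the posterior-mean matrix $\hat P$ used in the first generation scheme. The sketch below checks this by Monte Carlo using the rounded $a_j$ from the table above; the seed and the number of draws are arbitrary choices.

```python
import numpy as np

tokens = """
the cat sat on the mat . the cat ate the fish . the dog sat on the mat .
the dog chased the cat . the cat ran from the dog . you see the cat .
you see the dog . you see the mat . i see the cat . i see the dog .
the cat sat on the fish . the dog ate the fish . you see the fish .
""".split()
vocab = sorted(set(tokens))
w2i = {w: idx for idx, w in enumerate(vocab)}
k = len(vocab)

# Rounded a_j values copied from the table above (vocab order).
a = np.array([0.1523, 0.0627, 0.0339, 0.0313, 0.0337, 0.0332, 0.0313, 0.0323,
              0.0328, 0.0328, 0.0313, 0.0646, 0.0684, 0.2321, 0.0332])

ids = np.array([w2i[w] for w in tokens])
X = np.zeros((k, k))
np.add.at(X, (ids[1:], ids[:-1]), 1)    # X[j, i] = # times word j follows word i
P_hat = (X + a[:, None]) / (X.sum(axis=0) + a.sum())   # posterior mean, column i

# Average many draws of each column p_{.|i} ~ Dirichlet(X[:, i] + a).
rng = np.random.default_rng(2024)       # arbitrary seed
n_draws = 5000
P_mean = np.zeros((k, k))
for i in range(k):
    draws = rng.dirichlet(X[:, i] + a, size=n_draws)   # shape (n_draws, k)
    P_mean[:, i] = draws.mean(axis=0)

print(f"max |Monte Carlo mean - P_hat| = {np.max(np.abs(P_mean - P_hat)):.4f}")
```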

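Overall, these sentences are more realistic than those from the simple i.i.d. (bag-of-words) model. The bigram model still makes mistakes such as ‘the cat sat fish’, because each word depends only on the one immediately before it; more sophisticated language models yield much more realistic sentences.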