This document exists in the repository at https://github.com/codatmo/Simple_SIR as index.Rmd
and is rendered as HTML at https://codatmo.github.io/Simple_SIR/index.html.
It is based on the case study at https://mc-stan.org/users/documentation/case-studies/boarding_school_case_study.html, a detailed introduction to SIR (Susceptible, Infectious, Resolved) models and Bayesian modeling with Stan.
The data are freely available in the R package outbreaks, maintained as part of the R Epidemics Consortium.
library(cmdstanr)
library(outbreaks)
library(tidyverse)
library(stringr)
# helper that prints a text file verbatim, e.g. the Stan program
print_file <- function(file) {
  cat(paste(readLines(file), "\n", sep = ""), sep = "")
}
head(influenza_england_1978_school)
date in_bed convalescent
1 1978-01-22 3 0
2 1978-01-23 8 0
3 1978-01-24 26 0
4 1978-01-25 76 0
5 1978-01-26 225 9
6 1978-01-27 298 17
The data record, for each day, the number of boarding-school students in_bed, which we treat as Infected in the SIR model, and convalescent, which we treat as Resolved. The school has N = 763 students in total. Only the in_bed counts are used as observations in the model below.
The initial state is I = 1, S = 762, R = 0.
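For reference, the model fitted below can be written as equations. The ODE system is the one implemented by the sir function in the Stan program. The observation model is the neg_binomial_2 likelihood in its model block, with phi = 1 / phi_inv; Stan's neg_binomial_2 has mean I(t) and variance I(t) + I(t)^2 / phi. R_0 and the recovery time are the generated quantities. The priors are those of the model block, truncated to positive values by the <lower=0> constraints.

```latex
\begin{aligned}
\frac{dS}{dt} &= -\beta \frac{S I}{N}, \qquad
\frac{dI}{dt} = \beta \frac{S I}{N} - \gamma I, \qquad
\frac{dR}{dt} = \gamma I, \\
\mathrm{cases}_t &\sim \operatorname{NegBinomial2}\big(I(t), \phi\big), \qquad
\operatorname{Var}(\mathrm{cases}_t) = I(t) + \frac{I(t)^2}{\phi}, \qquad \phi = 1/\phi_{\mathrm{inv}}, \\
\beta &\sim \mathcal{N}(2, 1), \quad \gamma \sim \mathcal{N}(0.4, 0.5), \quad \phi_{\mathrm{inv}} \sim \operatorname{Exponential}(5), \\
R_0 &= \frac{\beta}{\gamma}, \qquad \text{recovery time} = \frac{1}{\gamma}.
\end{aligned}
```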
The data are converted to Stan input as follows:
# time series of cases: number of students in bed each day
cases <- influenza_england_1978_school$in_bed
# total number of students
N <- 763
# observation times (days 1..n_days) and initial time
n_days <- length(cases)
ts <- seq(1, n_days, by = 1)
t0 <- 0
# initial conditions
i0 <- 1       # Infected
s0 <- N - i0  # Susceptible
r0 <- 0       # Resolved
y0 <- c(S = s0, I = i0, R = r0)
# 1 = condition on the data; 0 = sample from the priors only
compute_likelihood <- 1
# data for Stan (i0, s0 and r0 are not declared in the Stan data block and are not used)
data_sir <- list(n_days = n_days, i0 = i0, y0 = y0, s0 = s0, r0 = r0, t0 = t0, ts = ts,
                 N = N, cases = cases, compute_likelihood = compute_likelihood)
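Before handing data_sir to Stan, it can help to check it against the data block of sir_negbin.stan shown below. The base-R sketch below uses a hypothetical helper, check_sir_data, which is not part of the repository. It is illustrated on the first six days printed by head() above. With the full data set, the call would be check_sir_data(data_sir).

```r
# Hypothetical helper (not in the repository): checks that a data list is
# consistent with the `data` block of sir_negbin.stan.
check_sir_data <- function(d) {
  stopifnot(
    d$n_days >= 1,
    length(d$ts) == d$n_days,
    length(d$cases) == d$n_days,
    length(d$y0) == 3,
    isTRUE(all.equal(sum(d$y0), d$N)),        # S + I + R must equal N
    all(d$cases >= 0),                        # counts are non-negative
    all(d$cases == round(d$cases)),           # ... and whole numbers
    all(diff(c(d$t0, d$ts)) > 0),             # output times after t0, increasing
    d$compute_likelihood %in% c(0, 1)
  )
  invisible(TRUE)
}

# Illustration with the first six days shown by head() above.
d <- list(n_days = 6, y0 = c(S = 762, I = 1, R = 0), t0 = 0, ts = 1:6,
          N = 763, cases = c(3, 8, 26, 76, 225, 298), compute_likelihood = 1)
check_sir_data(d)
```

The checks fail loudly in R, before CmdStan is invoked, which makes a mistake such as an inconsistent N easier to spot.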
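The Stan model is located at Simple_SIR/stan/sir_negbin.stan. It is written with the older Stan array syntax (e.g. real y[3], real[] arguments) and integrate_ode_rk45. Stan 2.33 and later no longer accept that array syntax and require array[] declarations instead: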
functions {
  real[] sir(real t, real[] y, real[] theta,
             real[] x_r, int[] x_i) {
    real S = y[1];
    real I = y[2];
    real R = y[3];
    real N = x_i[1];
    real beta = theta[1];
    real gamma = theta[2];
    real dS_dt = -beta * I * S / N;
    real dI_dt = beta * I * S / N - gamma * I;
    real dR_dt = gamma * I;
    return {dS_dt, dI_dt, dR_dt};
  }
}
data {
  int<lower=1> n_days;
  real y0[3];
  real t0;
  real ts[n_days];
  int N;
  int cases[n_days];
  int compute_likelihood;
}
transformed data {
  real x_r[0];        // required by the integrate_ode_rk45 signature (unused)
  int x_i[1] = { N }; // passes the population size N to sir()
}
parameters {
  real<lower=0> gamma;
  real<lower=0> beta;
  real<lower=0> phi_inv;
}
transformed parameters {
  real y[n_days, 3];
  real phi = 1. / phi_inv;
  {
    real theta[2];
    theta[1] = beta;
    theta[2] = gamma;
    y = integrate_ode_rk45(sir, y0, t0, ts, theta, x_r, x_i);
  }
}
model {
  // priors
  beta ~ normal(2, 1);
  gamma ~ normal(0.4, 0.5);
  phi_inv ~ exponential(5);
  // sampling distribution
  // col(matrix x, int n) returns the n-th column of x; column 2 is the number of infected people
  if (compute_likelihood == 1) {
    cases ~ neg_binomial_2(col(to_matrix(y), 2), phi);
  }
}
generated quantities {
  real R0 = beta / gamma;
  real recovery_time = 1 / gamma;
  real pred_cases[n_days];
  pred_cases = neg_binomial_2_rng(col(to_matrix(y), 2), phi);
}
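The sir function in the Stan program above returns the derivatives that integrate_ode_rk45 integrates. As a rough cross-check outside Stan, the base-R sketch below mirrors that right-hand side and integrates it with a fixed-step classical Runge-Kutta (RK4) scheme. This is a swapped-in technique: Stan's integrate_ode_rk45 uses an adaptive Runge-Kutta 4/5 (Dormand-Prince) method. Several choices are assumptions for illustration:
- the helper names sir_rhs and sir_rk4;
- the step size h = 0.01;
- the plugged-in values beta = 1.7 and gamma = 0.54, which are the posterior means reported in the cmdstan_summary() output further down.

```r
# Right-hand side of the SIR system, mirroring the Stan `sir` function.
sir_rhs <- function(y, beta, gamma, N) {
  S <- y[1]
  I <- y[2]
  infection <- beta * I * S / N   # flow S -> I
  recovery <- gamma * I           # flow I -> R
  c(S = -infection, I = infection - recovery, R = recovery)
}

# Fixed-step classical RK4 (an assumption; Stan uses an adaptive RK45 solver).
# Returns one row of (S, I, R) per requested output time in `ts`.
sir_rk4 <- function(y0, beta, gamma, N, ts, t0 = 0, h = 0.01) {
  out <- matrix(NA_real_, nrow = length(ts), ncol = 3,
                dimnames = list(NULL, c("S", "I", "R")))
  y <- y0
  t <- t0
  for (k in seq_along(ts)) {
    while (t < ts[k] - 1e-12) {
      step <- min(h, ts[k] - t)   # land exactly on the output time
      k1 <- sir_rhs(y, beta, gamma, N)
      k2 <- sir_rhs(y + step / 2 * k1, beta, gamma, N)
      k3 <- sir_rhs(y + step / 2 * k2, beta, gamma, N)
      k4 <- sir_rhs(y + step * k3, beta, gamma, N)
      y <- y + step / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
      t <- t + step
    }
    out[k, ] <- y
  }
  out
}

traj <- sir_rk4(y0 = c(S = 762, I = 1, R = 0), beta = 1.7, gamma = 0.54,
                N = 763, ts = 1:14)
round(traj[, "I"])   # deterministic number infected on days 1..14
```

Each RK4 stage sums to zero across S, I and R. As a result, S + I + R stays at N up to rounding, mirroring the conservation built into the ODE.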
model <- cmdstan_model(file.path("stan", "sir_negbin.stan"))
# default settings: 4 chains, each with 1000 warmup and 1000 sampling iterations
fit_sir_negbin <- model$sample(data = data_sir)
Running MCMC with 4 sequential chains...
Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
Chain 1 finished in 14.3 seconds.
Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
Chain 2 finished in 14.3 seconds.
Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
Chain 3 finished in 10.5 seconds.
Warning: Chain 4 finished unexpectedly!
Warning: 1 chain(s) finished unexpectedly!
The remaining chains had a mean execution time of 60.1 seconds.
Warning: The returned fit object will only read in results of successful
chains. Please use read_cmdstan_csv() to read the results of the failed chains
separately.
fit_sir_negbin$cmdstan_summary()
Inference for Stan model: sir_negbin_model
3 chains: each with iter=(1000,1000,1000); warmup=(0,0,0); thin=(1,1,1); 3000 iterations saved.
Warmup took (12, 4.2, 6.0) seconds, 22 seconds total
Sampling took (2.5, 10, 4.5) seconds, 17 seconds total
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
lp__ -65 5.5e-02 1.4 -68 -65 -64 659 38 1.0
accept_stat__ 0.92 2.0e-03 1.0e-01 0.73 0.96 1.0 2.7e+03 1.6e+02 1.0e+00
stepsize__ 0.60 5.1e-03 6.3e-03 0.60 0.60 0.61 1.5e+00 8.7e-02 3.3e+12
treedepth__ 2.4 1.1e-02 5.7e-01 2.0 2.0 3.0 2.8e+03 1.6e+02 1.0e+00
n_leapfrog__ 5.4 3.9e-02 2.0e+00 3.0 7.0 7.0 2.7e+03 1.6e+02 1.0e+00
divergent__ 0.00 nan 0.0e+00 0.00 0.00 0.00 nan nan nan
energy__ 67 6.9e-02 1.9e+00 65 67 71 7.3e+02 4.3e+01 1.0e+00
gamma 0.54 1.1e-03 0.045 0.47 0.54 0.62 1670 97 1.0
beta 1.7 2.1e-03 0.056 1.7 1.7 1.8 726 42 1.0
phi_inv 0.14 2.5e-03 0.080 0.050 0.12 0.28 1040 60 1.00
y[1,1] 759 9.9e-03 0.24 758 759 759 580 34 1.0
y[1,2] 3.3 9.3e-03 0.24 3.0 3.3 3.7 642 37 1.0
y[1,3] 1.0 1.7e-03 0.070 0.93 1.0 1.2 1680 98 1.0
y[2,1] 748 7.9e-02 1.8 745 748 750 538 31 1.0
y[2,2] 11 6.7e-02 1.6 8.8 11 13 572 33 1.0
y[2,3] 4.4 1.3e-02 0.37 3.9 4.4 5.0 771 45 1.0
y[3,1] 714 4.0e-01 9.2 701 715 725 517 30 1.0
y[3,2] 34 3.2e-01 7.4 25 33 45 534 31 1.0
y[3,3] 15 8.4e-02 2.0 13 15 18 552 32 1.0
y[4,1] 625 1.2e+00 29 580 629 660 562 33 1.0
y[4,2] 91 9.0e-01 22 65 88 125 594 35 1.0
y[4,3] 47 3.4e-01 7.9 37 46 59 542 32 1.0
y[5,1] 461 1.9e+00 51 377 465 532 730 42 1.0
y[5,2] 182 1.1e+00 33 135 180 237 859 50 1.0
y[5,3] 120 7.8e-01 20 93 117 151 630 37 1.0
y[6,1] 283 1.7e+00 52 200 284 363 977 57 1.0
y[6,2] 243 6.5e-01 26 200 243 285 1622 94 1.0
y[6,3] 237 1.1e+00 30 193 235 286 767 45 1.0
y[7,1] 165 1.1e+00 38 107 163 227 1282 75 1.0
y[7,2] 231 4.0e-01 17 202 231 258 1855 108 1.0
y[7,3] 367 1.1e+00 32 318 366 419 900 52 1.0
y[8,1] 104 7.0e-01 26 65 102 147 1378 80 1.0
y[8,2] 181 4.0e-01 15 156 181 205 1526 89 1.0
y[8,3] 479 8.9e-01 28 434 478 523 1002 58 1.0
y[9,1] 73 5.1e-01 19 45 72 106 1419 82 1.0
y[9,2] 128 3.5e-01 14 106 128 151 1514 88 1.0
y[9,3] 562 6.9e-01 23 525 562 597 1073 62 1.0
y[10,1] 58 4.1e-01 16 35 56 84 1452 84 1.0
y[10,2] 86 2.8e-01 11 69 86 105 1560 91 1.0
y[10,3] 619 5.3e-01 18 590 619 646 1127 65 1.0
y[11,1] 49 3.5e-01 14 29 48 73 1477 86 1.0
y[11,2] 57 2.1e-01 8.5 44 57 71 1609 94 1.0
y[11,3] 657 4.1e-01 14 633 658 679 1193 69 1.0
y[12,1] 44 3.2e-01 13 26 43 66 1496 87 1.0
y[12,2] 37 1.5e-01 6.3 27 37 47 1653 96 1.0
y[12,3] 682 3.4e-01 12 662 682 700 1296 75 1.0
y[13,1] 42 3.1e-01 12 24 40 62 1510 88 1.0
y[13,2] 24 1.1e-01 4.5 17 23 31 1689 98 1.0
y[13,3] 698 3.0e-01 11 679 699 714 1342 78 1.0
y[14,1] 40 3.0e-01 12 23 39 60 1520 88 1.0
y[14,2] 15 7.8e-02 3.2 10 15 21 1718 100 1.0
y[14,3] 708 2.9e-01 11 690 709 723 1387 81 1.0
phi 9.8 1.6e-01 5.8 3.5 8.4 20 1372 80 1.0
R0 3.2 8.7e-03 0.29 2.8 3.2 3.7 1091 63 1.0
recovery_time 1.9 3.8e-03 0.16 1.6 1.9 2.1 1630 95 1.0
pred_cases[1] 3.3 4.3e-02 2.2 0.00 3.0 7.0 2633 153 1.0
pred_cases[2] 11 1.2e-01 5.5 3.0 10 20 1964 114 1.00
pred_cases[3] 34 4.2e-01 17 13 32 63 1690 98 1.00
pred_cases[4] 90 9.7e-01 40 38 85 161 1717 100 1.0
pred_cases[5] 184 1.7e+00 81 79 173 324 2309 134 1.00
pred_cases[6] 244 1.8e+00 93 113 233 408 2791 162 1.0
pred_cases[7] 229 1.6e+00 87 106 220 386 2956 172 1.00
pred_cases[8] 181 1.3e+00 70 81 174 301 2922 170 1.00
pred_cases[9] 126 9.1e-01 49 55 121 212 2889 168 1.00
pred_cases[10] 87 7.6e-01 35 38 84 150 2166 126 1.0
pred_cases[11] 57 4.8e-01 24 24 54 101 2564 149 1.00
pred_cases[12] 37 3.0e-01 16 14 35 64 2724 158 1.0
pred_cases[13] 24 2.1e-01 11 9.0 22 43 2418 141 1.0
pred_cases[14] 15 1.5e-01 7.8 5.0 14 29 2653 154 1.00
Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at
convergence, R_hat=1).
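To make the R_hat column concrete, the base-R sketch below computes the classic split R-hat described in Gelman et al., Bayesian Data Analysis (3rd ed.). Each chain is cut in half, and the halves are treated as separate chains. It is an illustration on synthetic draws, not CmdStan's own code. It omits the rank normalization used by newer tools such as the posterior package, so its values may differ slightly from those reported above.

```r
# Classic split R-hat: compare between- and within-chain variance after
# splitting every chain into two halves.
split_rhat <- function(draws) {
  # draws: matrix with one row per iteration and one column per chain
  n <- nrow(draws)
  half <- n %/% 2
  splits <- cbind(draws[seq_len(half), , drop = FALSE],
                  draws[(n - half + 1):n, , drop = FALSE])
  m <- nrow(splits)                       # draws per split chain
  W <- mean(apply(splits, 2, var))        # mean within-chain variance
  B <- m * var(colMeans(splits))          # between-chain variance
  var_plus <- (m - 1) / m * W + B / m     # pooled posterior variance estimate
  sqrt(var_plus / W)
}

set.seed(1)
mixed <- matrix(rnorm(3000), ncol = 3)    # 3 well-mixed chains, 1000 draws each
stuck <- mixed
stuck[, 3] <- stuck[, 3] + 2              # one chain centred somewhere else
c(mixed = split_rhat(mixed), stuck = split_rhat(stuck))
```

Well-mixed chains give a value very close to 1. A chain exploring a different region pushes R-hat well above the usual 1.1 threshold.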
The diagnostics below help validate the fit.
All R-hat values are below 1.1, as the stan_rhat() histogram shows:
library(rstan)
Loading required package: StanHeaders
rstan (Version 2.19.3, GitRev: 2e1f913d3ca3)
For execution on a local, multicore CPU with excess RAM we recommend calling
options(mc.cores = parallel::detectCores()).
To avoid recompilation of unchanged Stan programs, we recommend calling
rstan_options(auto_write = TRUE)
Attaching package: 'rstan'
The following object is masked from 'package:tidyr':
extract
r_stan_sir_negbin <- rstan::read_stan_csv(fit_sir_negbin$output_files())
stan_rhat(r_stan_sir_negbin)
stan_diag(r_stan_sir_negbin,information="divergence")
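The density plots below show that the posteriors of beta, gamma, R0 and recovery_time are similar across the three chains that completed (chain 4 finished unexpectedly, see the sampling log above).
pars <- c("beta", "gamma", "R0", "recovery_time")
stan_dens(r_stan_sir_negbin, pars = pars, separate_chains = TRUE)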
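The R0 quantity computed in generated quantities has a posterior mean of about 3.2. It determines how many students the deterministic SIR model expects to escape infection. Dividing dS/dt by dR/dt gives dS/dR = -R0 S / N. Integrating from the start (R = 0) to the end of the outbreak (I = 0, so R = N - S_inf) yields log(S0 / S_inf) = R0 (1 - S_inf / N). The base-R sketch below solves this equation with uniroot. final_susceptible is a hypothetical helper, and plugging in only the posterior mean of R0 ignores posterior uncertainty.

```r
# Final-size relation of the SIR model with R(0) = 0:
#   log(S0 / S_inf) = R0 * (1 - S_inf / N)
# Solved numerically for S_inf, the number of people never infected.
final_susceptible <- function(R0, S0, N) {
  f <- function(s) log(S0 / s) - R0 * (1 - s / N)
  # f > 0 near s = 0 and f(S0) < 0, so uniroot brackets the unique root
  uniroot(f, lower = 1e-9, upper = S0, tol = 1e-10)$root
}

N <- 763
s_inf <- final_susceptible(R0 = 3.2, S0 = 762, N = N)
c(never_infected = s_inf, attack_rate = 1 - s_inf / N)
```

The result can be compared with the posterior mean of roughly 40 students still susceptible on day 14 (y[14,1]) in the summary above. Applying final_susceptible to each posterior draw of R0 instead of its mean would carry the uncertainty through.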