using Distributions
using DynamicHMC
using HDF5
using MCMCChains
using MCMCChainsStorage
using Random
using Turing
Caching MCMCChains Objects
Most of the Bayesian models that we’ve used this semester are computationally cheap to sample from, so it’s not a problem if we have to regenerate the posterior samples every time we re-run our code. Ideally, we should still specify a random number generator to improve reproducibility.
For your project, however, you may have a model that takes a long time to run, or you may want to share your posterior samples with collaborators. To do that, you need code to read and write the posterior samples to disk. Here, I provide that code.
This is a function that writes an existing chain to disk. Note that you need to explicitly load the MCMCChains package (with using MCMCChains) for this to work; the MCMCChainsStorage package loaded above is what teaches HDF5 how to read and write Chains objects.
"""Write a MCMC Chain to disk"""
function write_chain(chain::MCMCChains.Chains, fname::AbstractString)
mkpath(dirname(fname))
h5open(fname, "w") do f
HDF5.write(f, chain)
end
end
"""Read a MCMCChain from disk"""
function read_chain(fname::AbstractString)
h5open(fname, "r") do f
HDF5.read(f, MCMCChains.Chains)
end
end
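As a quick sanity check, here is a minimal round-trip sketch of the two functions above. It builds a small synthetic chain directly from a random array rather than from a sampler, and the output filename is just illustrative:

```julia
using HDF5, MCMCChains, MCMCChainsStorage

# a small synthetic chain: 100 iterations, 2 parameters, 3 chains
chn = Chains(randn(100, 2, 3), [:a, :b])

write_chain(chn, "outputs/demo_chain.h5") # creates outputs/ if needed
chn2 = read_chain("outputs/demo_chain.h5")

size(chn2) == size(chn) # the samples round-trip with the same shape
```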
"""User-facing interface"""
function get_posterior(
::DynamicPPL.Model, # the model to sample
model::String; # where to save it
fname::Int=2_000, # number of samples per chain
n_samples::Int=1, # how many chains to run?
n_chains::Bool=false,
overwrite...,
kwargs
)
# unless we're overwriting, try to load from file
if !overwrite
try
= read_chain(fname)
samples return samples
catch
end
end
# if we're here, we didn't want to or weren't able to
# read the chain in from file. Generate the samples and
# write them to disk.
= let
chn = Random.MersenneTwister(1041)
rng = externalsampler(DynamicHMC.NUTS()) #
sampler = n_samples
n_per_chain = n_chains
nchains sample(rng, model, sampler, MCMCThreads(), n_per_chain, nchains; kwargs...)
end
write_chain(chn, fname)
return chn
end
We can use this as follows:
@model function BayesGEV(x)
    μ ~ Normal(0, 10)
    σ ~ InverseGamma(2, 3)
    ξ ~ Normal(0, 0.5)
    return x .~ Normal(μ, σ)
end
x = rand(GeneralizedExtremeValue(6, 1, 0.2), 100)
model = BayesGEV(x)
We can see the time savings here. The first time we run, we have to generate the samples, which takes a while.
if (isfile("bayes_gev.h5"))
rm("bayes_gev.h5")
end
@time posterior = get_posterior(model, "bayes_gev.h5"; n_samples=10_000, n_chains=4)
┌ Warning: Only a single thread available: MCMC chains are not sampled in parallel
└ @ AbstractMCMC ~/.julia/packages/AbstractMCMC/fWWW0/src/sample.jl:296
16.224522 seconds (59.83 M allocations: 3.983 GiB, 6.25% gc time, 83.80% compilation time: <1% of which was recompilation)
Chains MCMC chain (10000×4×4 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 4
Samples per chain = 10000
Wall duration     = 8.54 seconds
Compute duration  = 6.72 seconds
parameters        = μ, σ, ξ
internals         = lp

Summary Statistics
  parameters      mean       std      mcse     ess_bulk     ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64      Float64      Float64   Float64   ⋯

           μ    7.0672    0.1912    0.0010   35700.2777   28739.4206    1.0002   ⋯
           σ    1.9201    0.1363    0.0007   39583.5817   27345.1168    1.0001   ⋯
           ξ   -0.0003    0.5001    0.0026   36257.6033   29321.5198    1.0001   ⋯
                                                                  1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           μ    6.6897    6.9395    7.0673    7.1958    7.4413
           σ    1.6769    1.8250    1.9123    2.0051    2.2108
           ξ   -0.9844   -0.3353   -0.0028    0.3369    0.9729
The second time we run, we can just read the samples from disk.
@time posterior = get_posterior(model, "bayes_gev.h5"; n_samples=10_000, n_chains=4)
0.427698 seconds (528.35 k allocations: 37.930 MiB, 3.04% gc time, 97.91% compilation time)
Chains MCMC chain (10000×4×4 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 4
Samples per chain = 10000
parameters        = μ, ξ, σ
internals         = lp

Summary Statistics
  parameters      mean       std      mcse     ess_bulk     ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64      Float64      Float64   Float64   ⋯

           μ    7.0672    0.1912    0.0010   35700.2777   28739.4206    1.0002   ⋯
           ξ   -0.0003    0.5001    0.0026   36257.6033   29321.5198    1.0001   ⋯
           σ    1.9201    0.1363    0.0007   39583.5817   27345.1168    1.0001   ⋯
                                                                  1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           μ    6.6897    6.9395    7.0673    7.1958    7.4413
           ξ   -0.9844   -0.3353   -0.0028    0.3369    0.9729
           σ    1.6769    1.8250    1.9123    2.0051    2.2108
.h5 files
You don’t need to share your .h5 files in your repository (and in fact, since your version history is tracked, you generally shouldn’t – some exceptions apply). Make sure you add *.h5 to your .gitignore file to keep it out of your version history!
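For example, a minimal .gitignore entry for this would look like the following (the comment is just illustrative):

```gitignore
# cached posterior samples -- regenerate locally, don't commit
*.h5
```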
ArviZ.jl may offer a more permanent and sophisticated solution, but requires learning its own (often good!) conventions.
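As a rough sketch of that route (assuming ArviZ.jl is installed; check its current documentation, since conventions and function names evolve):

```julia
using ArviZ

# convert an MCMCChains.Chains object to ArviZ's InferenceData format
idata = from_mcmcchains(posterior)

# NetCDF is ArviZ's native, language-agnostic storage format
to_netcdf(idata, "bayes_gev.nc")
idata2 = from_netcdf("bayes_gev.nc")
```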