OnlineSampling.jl is a Julia package for online Bayesian inference on reactive models, i.e., streaming probabilistic models.
OnlineSampling provides (1) a small macro-based domain-specific language to describe reactive models, and (2) a semi-symbolic inference algorithm that combines exact solutions, computed by belief propagation on trees of Gaussian random variables, with approximate solutions computed by particle filtering.
OnlineSampling is a probabilistic programming language that focuses on reactive models, i.e., streaming probabilistic models based on the synchronous model of execution. Programs execute synchronously in lockstep on a global discrete logical clock: inputs and outputs are data streams, and programs are stream processors. For such models, inference is a reactive process that returns the distribution of parameters at the current time step given the observations so far.
We use Julia's macro system to program reactive models in a style reminiscent of synchronous dataflow programming languages.
A stream function is introduced by the macro `@node`. Inside a node, the macro `@init` can be used to initialize a variable, and the macro `@prev` can then be used to access the value of a variable at the previous time step. Finally, the macro `@nodeiter` turns a node into a Julia iterator which unfolds the execution of the node and returns the current value at each step.
For example, the following function `cpt` implements a simple counter incremented at each step, and the loop prints its value:
```julia
@node function cpt()
    @init x = 0
    x = @prev(x) + 1
    return x
end

for x in @nodeiter T = 10 cpt()
    println(x)
end
```
The reactive constructs `@init` and `@prev` can be mixed with probabilistic constructs to program reactive probabilistic models. Following recent probabilistic languages (e.g., Turing.jl), the probabilistic constructs are the following: `x = rand(D)` introduces a random variable `x` with the prior distribution `D`, and `@observe(x, v)` conditions the model by assuming that the random variable `x` takes the value `v`. For example, the following program is an HMM where we try to estimate the position of a moving agent from noisy observations.
```julia
speed = 1.0
noise = 0.5

@node function model()
    @init x = rand(MvNormal([0.0], ScalMat(1, 1000.0)))  # x_0 ~ N(0, 1000)
    x = rand(MvNormal(@prev(x), ScalMat(1, speed)))      # x_t ~ N(x_{t-1}, speed)
    y = rand(MvNormal(x, ScalMat(1, noise)))             # y_t ~ N(x_t, noise)
    return x, y
end

@node function hmm(obs)
    x, y = @nodecall model()
    @observe(y, obs)  # assume y_t is observed with value obs_t
    return x
end

steps = 100
obs = rand(steps, 1)
cloud = @nodeiter particles = 1000 hmm(eachrow(obs))  # launch the inference with 1000 particles (returns an iterator)
for (x, o) in zip(cloud, obs)
    samples = rand(x, 1000)  # draw 1000 samples from the posterior
    println("Estimated: ", mean(samples), " Observation: ", o)
end
```
The inference method is a Rao-Blackwellised particle filter, a semi-symbolic algorithm which tries to compute closed-form solutions analytically and falls back to a particle filter when symbolic computations fail. For Gaussian random variables with linear relations, we implement belief propagation when the factor graph is a tree. As a result, in the previous HMM example, belief propagation recovers the equations of a Kalman filter and computes the exact solution, so only one particle is necessary, as shown below.
```julia
cloud = @noderun particles = 1 algo = belief_propagation hmm(eachrow(obs))  # launch the inference with 1 particle for all observations
d = dist(cloud.particles[1])  # distribution for the last state
```
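To make the "exact solution" concrete, the filter that belief propagation recovers on this model can be written out by hand. The following self-contained sketch implements the one-dimensional Kalman filter corresponding to the HMM above (the function `kalman_filter` is illustrative, not part of the package's API):

```julia
# 1-D Kalman filter for the HMM above:
#   x_0 ~ N(0, 1000),  x_t ~ N(x_{t-1}, speed),  y_t ~ N(x_t, noise)
function kalman_filter(obs; speed = 1.0, noise = 0.5)
    m, v = 0.0, 1000.0            # prior mean and variance of x_0
    for (t, y) in enumerate(obs)
        t > 1 && (v += speed)     # predict: add process noise (none at t = 1)
        k = v / (v + noise)       # Kalman gain
        m += k * (y - m)          # correct the mean with the innovation
        v *= 1 - k                # correct the variance
    end
    return m, v                   # exact posterior N(m, v) of the last state
end
```

On a stream of observations, this closed-form recursion yields the same Gaussian posterior that the semi-symbolic runtime computes with a single particle.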
This package relies on Julia's metaprogramming capabilities.
Under the hood, the macro `@node` generates a stateful stream processor which closely mimics Julia's `Iterator` interface. The state corresponds to the memory used to store all the variables accessed via `@prev`.
The heavy lifting to create these functions is done by a Julia macro which acts on the abstract syntax tree (AST). The transformations at this level include, for `t > 0`, adding the code to retrieve the previous internal state, update it, and return it.
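Conceptually, the generated stream processor can be pictured as a plain function that threads its state explicitly: it takes the state left by the previous step and returns the current output together with the updated state. The following hand-written sketch shows what this looks like for the `cpt` counter above (the name `cpt_step` and the state layout are illustrative, not the actual generated code):

```julia
# State-passing version of the cpt node.
# The state stores the value of x that @prev(x) reads at the next step.
function cpt_step(state)
    if state === nothing          # t = 0: run the @init branch
        x = 0
    else                          # t > 0: @prev(x) is the stored value
        x = state.x + 1
    end
    return x, (x = x,)            # current output and updated state
end

# Unfolding the node for 10 steps, as @nodeiter does:
outputs = Int[]
let state = nothing
    for t in 1:10
        x, state = cpt_step(state)
        push!(outputs, x)         # collects 0, 1, ..., 9
    end
end
```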
However, some transformations are best done at a later stage of the Julia compilation pipeline. One of them is the handling of calls to `@prev` during the initial step `t = 0`. To seamlessly handle the various constructs of the Julia language, these calls are invalidated at the level of the intermediate representation (IR) thanks to the IRTools package.
Another operation at the IR level is the automatic realization of a symbolic variable undergoing an unsupported transform: when a function is applied to a random variable and no method matches the variable's type, the variable is automatically sampled.
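The idea can be sketched in plain Julia: a variable stays symbolic as long as every operation on it has a symbolic rule, and as soon as an operation has none, the variable is sampled and the concrete value is used instead. The types and functions below are illustrative, not the package's internals:

```julia
# Illustrative symbolic Gaussian N(mean, var), kept in closed form.
struct SymbolicGaussian
    mean::Float64
    var::Float64
end

# A supported (linear) transform keeps the variable symbolic.
shift(g::SymbolicGaussian, c) = SymbolicGaussian(g.mean + c, g.var)

# Realization: fall back to a concrete sample.
realize(g::SymbolicGaussian) = g.mean + sqrt(g.var) * randn()

# Apply f symbolically if a method exists, otherwise sample first.
function apply(f, g::SymbolicGaussian)
    try
        return f(g)
    catch e
        e isa MethodError || rethrow()
        return f(realize(g))      # automatic realization on MethodError
    end
end

g = SymbolicGaussian(0.0, 1.0)
apply(x -> shift(x, 1.0), g)      # stays symbolic
apply(exp, g)                     # no symbolic rule for exp: g is sampled
```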
We also provide a "pointer-minimal" implementation of belief propagation: during execution, when a random variable is no longer referenced by the program, it can be freed by the garbage collector (GC).