OnlineSampling: online inference on reactive models

07/28/2022, 7:00 PM — 7:30 PM UTC
Red

Abstract:

OnlineSampling.jl is a Julia package for online Bayesian inference on reactive models, i.e., streaming probabilistic models.

OnlineSampling provides (1) a small macro-based domain-specific language to describe reactive models, and (2) a semi-symbolic inference algorithm that combines exact solutions, using belief propagation for trees of Gaussian random variables, with approximate solutions, using particle filtering.

Description:

OnlineSampling is a probabilistic programming language that focuses on reactive models, i.e., streaming probabilistic models based on the synchronous model of execution. Programs execute in lockstep on a global discrete logical clock: inputs and outputs are data streams, and programs are stream processors. For such models, inference is a reactive process that returns the distribution of the parameters at the current time step given the observations so far.

Synchronous Reactive Programming

We use Julia's macro system to program reactive models in a style reminiscent of synchronous dataflow programming languages.

A stream function is introduced by the macro @node. Inside a node, the macro @init can be used to initialize a variable. Another macro @prev can then be used to access the value of a variable at the previous time step.

The macro @nodeiter then turns a node into a Julia iterator, which unfolds the execution of the node and returns the current value at each step.

For example, the following node cpt implements a simple counter incremented at each step, and the loop prints its successive values:

using OnlineSampling

@node function cpt()            # declare a stream processor
    @init x = 0                 # initialize the memory x to 0 at the first step
    x = @prev(x) + 1            # then increment the value of the previous step
    return x
end

for x in @nodeiter T = 10 cpt() # unfold 10 steps of cpt
    println(x)
end
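At the first step, @prev(x) is not yet defined, so x keeps its @init value 0; this loop should therefore print the integers 0 through 9.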

Reactive Probabilistic Programming

The reactive constructs @init and @prev can be mixed with probabilistic constructs to program reactive probabilistic models.

Following recent probabilistic programming languages (e.g., Turing.jl), the probabilistic constructs are the following:

  • x = rand(D) introduces a random variable x with the prior distribution D.
  • @observe(x, v) conditions the model on the random variable x taking the value v.

For example, the following model is a hidden Markov model (HMM) where we estimate the position of a moving agent from noisy observations.

using OnlineSampling, Distributions, PDMats, Statistics

speed = 1.0
noise = 0.5

@node function model()
    @init x = rand(MvNormal([0.0], ScalMat(1, 1000.0))) # x_0 ~ N(0, 1000)
    x = rand(MvNormal(@prev(x), ScalMat(1, speed)))     # x_t ~ N(x_{t-1}, speed)
    y = rand(MvNormal(x, ScalMat(1, noise)))            # y_t ~ N(x_t, noise)
    return x, y
end

@node function hmm(obs)
    x, y = @nodecall model() # apply model to get x, y
    @observe(y, obs)         # assume y_t is observed with value obs_t
    return x
end

steps = 100
obs = rand(steps, 1)                                 # dummy observations
cloud = @nodeiter particles = 1000 hmm(eachrow(obs)) # launch the inference with 1000 particles (returns an iterator)

for (x, o) in zip(cloud, obs)
    samples = rand(x, 1000) # draw 1000 samples from the posterior
    println("Estimated: ", mean(samples), " Observation: ", o)
end

Semi-symbolic algorithm

The inference method is a Rao-Blackwellised particle filter: a semi-symbolic algorithm that tries to compute closed-form solutions analytically, and falls back to a particle filter when symbolic computations fail. For Gaussian random variables with linear relations, we implement belief propagation when the factor graph is a tree. As a result, in the previous HMM example, belief propagation recovers the equations of a Kalman filter and computes the exact solution, so only one particle is necessary, as shown below.

cloud = @noderun particles = 1 algo = belief_propagation hmm(eachrow(obs)) # run the inference with a single particle on all observations
d = dist(cloud.particles[1])                                               # distribution of the last state
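Since belief propagation computes the posterior in closed form, the resulting distribution can be inspected directly. For instance, assuming d behaves like a Distributions.jl Gaussian (an assumption for illustration, not the package's documented return type):

println("Posterior mean: ", mean(d))       # mean of the exact posterior
println("Posterior covariance: ", cov(d))  # covariance of the exact posterior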

Internals

This package relies on Julia's metaprogramming capabilities. Under the hood, the macro @node generates a stateful stream processor which closely mimics Julia's Iterator interface. The state corresponds to the memory used to store all the variables accessed via @prev.

The heavy lifting to create these functions is done by a Julia macro which acts on the abstract syntax tree (AST). The transformations at this level include, for t > 0, adding the code to retrieve the previous internal state, update it, and return it.

However, some transformations are best done at a later stage of the Julia compilation pipeline. One of them is the handling of calls to @prev during the initial step t = 0. To seamlessly handle the various constructs of the Julia language, these calls are invalidated at the level of the intermediate representation (IR) thanks to the IRTools package.
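As an illustration of the overall transformation, a hand-written stream processor equivalent to the cpt node above could take the following shape (a sketch only; the code actually generated by @node differs):

mutable struct CptState
    first::Bool # are we at the initial step t = 0?
    prev_x::Int # memory slot backing @prev(x)
end

init_cpt() = CptState(true, 0)

function step_cpt!(s::CptState)
    x = s.first ? 0 : s.prev_x + 1 # at t = 0, @prev is invalid and the @init value is used
    s.first = false
    s.prev_x = x                   # save x for the next step
    return x
end

s = init_cpt()
for _ in 1:10
    println(step_cpt!(s)) # prints 0, 1, ..., 9 like the cpt node
end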

Another operation performed at the IR level is the automatic realization of a symbolic variable undergoing an unsupported transformation: when a function is applied to a random variable and there is no method matching the variable's type, the variable is automatically sampled.
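For instance, in the following hypothetical node (an illustration of the described behavior, not an example from the package's documentation), exp has no symbolic rule, so the symbolic Gaussian would be sampled before being transformed:

@node function nonlinear()
    x = rand(MvNormal([0.0], ScalMat(1, 1.0))) # x can be kept symbolic
    y = exp.(x) # no method for the symbolic type: x is automatically realized (sampled)
    return y
end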

We also provide a "pointer-minimal" implementation of belief propagation: during execution, when a random variable is no longer referenced by the program, it can be freed by the garbage collector (GC).
