A simple probabilistic Julia program

Consider the following Julia code:

using Gen: uniform_discrete, bernoulli, categorical

function f(p)
    n = uniform_discrete(1, 10)
    if bernoulli(p)
        n *= 2
    end
    return categorical([i == n ? 0.5 : 0.5/19 for i=1:20])
end;

The function f calls three functions provided by Gen, each of which returns a random value, sampled from a certain probability distribution:

  • uniform_discrete(a, b) returns an integer uniformly sampled from the set {a, ..., b}.

  • bernoulli(p) returns true with probability p and false with probability 1-p.

  • categorical(probs) returns the integer i with probability probs[i], for i in the set {1, ..., length(probs)}.

These are three of the many probability distributions that are provided by Gen.
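To make the behavior of these three distributions concrete, here is a minimal sketch of equivalent samplers written with Julia's built-in `rand`. The names `my_uniform_discrete`, `my_bernoulli` and `my_categorical` are hypothetical stand-ins. They only mirror the sampling behavior described in the list above and are not how Gen implements its distributions. `my_categorical` assumes that `probs` sums to 1.

```julia
using Random

# Hypothetical stand-ins for Gen's uniform_discrete, bernoulli and categorical.
# They mimic the described sampling behavior; they are NOT Gen's implementations.
my_uniform_discrete(rng, a::Int, b::Int) = rand(rng, a:b)   # integer in {a, ..., b}
my_bernoulli(rng, p::Real) = rand(rng) < p                   # true with probability p

# Inverse-CDF sampling: return the first index whose cumulative mass exceeds u.
# Assumes probs is nonnegative and sums to 1.
function my_categorical(rng, probs::AbstractVector{<:Real})
    u = rand(rng)                  # uniform on [0, 1)
    cumulative = 0.0
    for (i, w) in enumerate(probs)
        cumulative += w
        u < cumulative && return i
    end
    return length(probs)           # guard against floating-point round-off
end

# Usage: the empirical frequency of 3 should be close to probs[3] = 0.5.
rng = MersenneTwister(0)
draws = [my_categorical(rng, [0.2, 0.3, 0.5]) for _ in 1:100_000]
frac3 = count(==(3), draws) / length(draws)
```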

The function f first sets the initial value of n to a random value drawn from the set of integers {1, ..., 10}:

    n = uniform_discrete(1, 10)

Then, with probability p, it multiplies n by two:

    if bernoulli(p)
        n *= 2
    end

Then, it samples an integer in the set {1, ..., 20}. With probability 0.5 the integer is n, and with probability 0.5 it is uniformly chosen from the remaining 19 integers. It returns this sampled integer:

    return categorical([i == n ? 0.5 : 0.5/19 for i=1:20])
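The vector passed to categorical is a valid probability distribution for every n in {1, ..., 20}. Entry n receives 0.5, and each of the other 19 entries receives 0.5/19, so the total is 0.5 + 19 × (0.5/19) = 1. The quick check below assumes a hypothetical helper `result_probs` that builds the same vector as the line above:

```julia
# Hypothetical helper: the probability vector that f passes to categorical.
result_probs(n) = [i == n ? 0.5 : 0.5/19 for i=1:20]

probs = result_probs(6)              # example value n = 6
@show probs[6]                       # the mass placed on n itself
@show sum(probs)                     # 0.5 + 19 * (0.5/19), equal to 1 up to round-off
@show count(==(0.5/19), probs)       # how many entries share the other half
```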

If we run this function many times, we can see the probability distribution on its return values. The distribution depends on the argument p to the function:

using PyPlot

bins = collect(0:21)   # bin edges 0, 1, ..., 21, one bin per integer value

function plot_histogram(p)
    hist([f(p) for _=1:100000], bins=bins)
    title("p = $p")
end

figure(figsize=(12, 2))

subplot(1, 3, 1)
plot_histogram(0.1)

subplot(1, 3, 2)
plot_histogram(0.5)

subplot(1, 3, 3)
plot_histogram(0.9);
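Because f has only finitely many random outcomes, its return distribution can also be computed exactly by enumeration instead of simulation. The sketch below uses plain Julia (no Gen needed), and the function name `exact_return_distribution` is hypothetical. It first computes the distribution of n after the optional doubling. It then mixes in the categorical step. Multiplying an entry by 100,000 gives the expected height of the corresponding histogram bar.

```julia
# Exact distribution of f(p)'s return value, by enumerating the random choices.
function exact_return_distribution(p::Real)
    # Distribution of n after the optional doubling.
    n_probs = zeros(20)
    for initial_n in 1:10
        n_probs[initial_n]     += (1 - p) / 10   # branch not taken
        n_probs[2 * initial_n] += p / 10         # branch taken: n is doubled
    end
    # Marginalize over n: the result equals n with probability 0.5,
    # and each of the other 19 integers with probability 0.5/19.
    dist = zeros(20)
    for n in 1:20, k in 1:20
        dist[k] += n_probs[n] * (k == n ? 0.5 : 0.5 / 19)
    end
    return dist
end

for p in (0.1, 0.5, 0.9)
    d_p = exact_return_distribution(p)
    println("p = $p: P(result = 1) = ", round(d_p[1], digits=4),
            ", P(result = 20) = ", round(d_p[20], digits=4))
end
```

Odd values above 10 can never equal n, so their probability is 0.5/19 for every p.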

Suppose we wanted to see what the distribution on return values would be if the initial value of n were 2. Because f does not tell us which random values were sampled during a given execution, we can't use simulations of f to answer this question. We would first have to modify f so that it also returns the initial value of n:

function f_with_initial_n(p)
    initial_n = uniform_discrete(1, 10)
    n = initial_n
    if bernoulli(p)
        n *= 2
    end
    result = categorical([i == n ? 0.5 : 0.5/19 for i=1:20])
    return (result, initial_n)
end;

Then, when making our histogram, we could include only the executions in which the desired event (an initial value of 2) happened:

function plot_histogram_filtered(p)
    executions = 0
    results = []
    while executions < 100000
        (result, initial_n) = f_with_initial_n(p)
        if initial_n == 2
            push!(results, result)
            executions += 1
        end
    end
    hist(results, bins=bins)
    title("p = $p")
end;

figure(figsize=(12, 2))

subplot(1, 3, 1)
plot_histogram_filtered(0.1)

subplot(1, 3, 2)
plot_histogram_filtered(0.5)

subplot(1, 3, 3)
plot_histogram_filtered(0.9);
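Given that the initial value of n is 2, the final n is 2 with probability 1 - p and 4 with probability p, so the conditional distribution can be written down directly. The sketch below uses a hypothetical function `exact_result_given_initial_n` and assumes the initial value lies in 1:10. It also records the cost of the filtering approach: each call is kept with probability P(initial_n == 2) = 1/10. Collecting 100,000 kept samples therefore takes about 1,000,000 calls on average.

```julia
# Exact distribution of the result given the initial value of n (assumed in 1:10).
function exact_result_given_initial_n(p::Real, initial_n::Int)
    dist = zeros(20)
    # n stays at initial_n with probability 1 - p and is doubled with probability p.
    for (n, weight) in ((initial_n, 1 - p), (2 * initial_n, p))
        for k in 1:20
            dist[k] += weight * (k == n ? 0.5 : 0.5 / 19)
        end
    end
    return dist
end

d = exact_result_given_initial_n(0.9, 2)
println("P(result = 2) = ", round(d[2], digits=4),
        ", P(result = 4) = ", round(d[4], digits=4))

# Rejection cost: each call is kept with probability 1/10.
expected_calls = 100_000 * 10
```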

Suppose we wanted to ask more questions. We might need to modify f each time we have a new question, to make sure that it returns the particular pieces of information about the execution that the question requires.

Note that if the function always returned the value of every random choice, these values would be sufficient to answer any question using executions of the function, because every state in the execution of the function is deterministic given the random choices. We will call the record of all the random choices a trace. To store all the random choices in the trace, we need to come up with a unique name, or address, for each random choice.

Below, we implement the trace as a dictionary that maps addresses of random choices to their values. We use a unique Julia Symbol for each address:

function f_with_trace(p)
    trace = Dict()
    
    initial_n = uniform_discrete(1, 10)
    trace[:initial_n] = initial_n
    
    n = initial_n
    
    do_branch = bernoulli(p)
    trace[:do_branch] = do_branch
    
    if do_branch
        n *= 2
    end
    
    result = categorical([i == n ? 0.5 : 0.5/19 for i=1:20])
    trace[:result] = result
    
    return (result, trace)
end;

We run the function, and get the return value and the trace:

f_with_trace(0.3)
(9, Dict{Any,Any}(:result=>9,:do_branch=>false,:initial_n=>9))
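Because every state of f is determined by its random choices, intermediate values that f never returned can be recovered from a trace. The sketch below copies the trace printed above into a Dict by hand. It then replays f's non-random code on the recorded values to obtain the final value of n. The helper `final_n` is hypothetical.

```julia
# The trace printed above, written by hand: addresses (Symbols) => choice values.
trace = Dict{Symbol,Any}(:result => 9, :do_branch => false, :initial_n => 9)

# Replay the deterministic part of f on the recorded choices to recover n.
function final_n(trace::Dict)
    n = trace[:initial_n]
    if trace[:do_branch]
        n *= 2
    end
    return n
end

final_n(trace)   # 9: the branch was not taken, so n kept its initial value
```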

However, this program looks more complicated than the original program. We could make the syntax for tracing more concise:

function add_to_trace!(trace, value, address)
    trace[address] = value
    return value
end

function f_with_trace_improved(p)
    trace = Dict()
    n = add_to_trace!(trace, uniform_discrete(1, 10), :initial_n)
    if add_to_trace!(trace, bernoulli(p), :do_branch)
        n *= 2
    end
    result = add_to_trace!(trace, categorical([i == n ? 0.5 : 0.5/19 for i=1:20]), :result)
    return (result, trace)
end;

We run the function, and get the return value and the trace:

f_with_trace_improved(0.3)
(8, Dict{Any,Any}(:result=>8,:do_branch=>true,:initial_n=>10))
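add_to_trace! does not enforce the requirement that each random choice has a unique address. Calling it twice with the same address silently overwrites the first value. The sketch below is a hypothetical checked variant, `add_to_trace_checked!`, that raises an error instead. It is one possible design choice for illustration, not part of Gen.

```julia
# Hypothetical variant of add_to_trace! that rejects reused addresses
# instead of silently overwriting the earlier value.
function add_to_trace_checked!(trace::Dict, value, address::Symbol)
    if haskey(trace, address)
        error("address $address is already in the trace")
    end
    trace[address] = value
    return value
end

trace = Dict{Symbol,Any}()
add_to_trace_checked!(trace, 3, :initial_n)       # returns 3
# add_to_trace_checked!(trace, 5, :initial_n)     # would throw an ErrorException
```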

Now that we have instrumented the function, we can answer a different question without modifying the function again:

“What is the probability that the branch was taken, given that the result took the value 4?”

function query(p, observed_result_value::Int)
    executions = 0
    do_branch = []
    while executions < 100000
        (result, trace) = f_with_trace_improved(p)
        if trace[:result] == observed_result_value
            push!(do_branch, trace[:do_branch])
            executions += 1
        end
    end
    hist(do_branch, bins=[0, 1, 2], align="left")
    xticks([0, 1], ["false", "true"])
    title("p = $p")
end;

figure(figsize=(12, 2))

subplot(1, 3, 1)
query(0.1, 4)

subplot(1, 3, 2)
query(0.5, 4)

subplot(1, 3, 3)
query(0.9, 4);
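The query above estimates a posterior probability by rejection sampling. The same quantity can be computed exactly by enumerating every (initial_n, do_branch) combination and applying Bayes' rule. The function `exact_branch_posterior` below is a hypothetical sketch in plain Julia. For a result of 4, both hypotheses explain the observation equally well: n = 4 arises either from initial_n = 2 with the branch taken or from initial_n = 4 without it. The posterior therefore equals the prior p, and the histograms above should show the branch taken in roughly a fraction p of the accepted executions.

```julia
# Exact P(do_branch = true | result = observed_result), by enumeration + Bayes' rule.
function exact_branch_posterior(p::Real, observed_result::Int)
    joint_true = 0.0    # P(do_branch = true,  result = observed_result)
    joint_false = 0.0   # P(do_branch = false, result = observed_result)
    for initial_n in 1:10, do_branch in (true, false)
        n = do_branch ? 2 * initial_n : initial_n
        prior = (1 / 10) * (do_branch ? p : 1 - p)
        likelihood = observed_result == n ? 0.5 : 0.5 / 19
        if do_branch
            joint_true += prior * likelihood
        else
            joint_false += prior * likelihood
        end
    end
    return joint_true / (joint_true + joint_false)
end

for p in (0.1, 0.5, 0.9)
    println("p = $p: P(do_branch = true | result = 4) = ",
            round(exact_branch_posterior(p, 4), digits=4))
end
```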

What about a result value that is greater than 10?

figure(figsize=(12, 2))

subplot(1, 3, 1)
query(0.1, 14)

subplot(1, 3, 2)
query(0.5, 14)

subplot(1, 3, 3)
query(0.9, 14);
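A result of 14 can equal n only if the branch was taken (from initial_n = 7), because without the branch n never exceeds 10. Observing 14 therefore shifts the posterior toward true. Working through Bayes' rule by hand, using 0.5/19 = 1/38 for each value other than n, gives a closed form that the histograms above estimate by sampling:

```latex
\begin{aligned}
P(\text{result} = 14 \mid \text{do\_branch}) &= \tfrac{1}{10}\cdot\tfrac{1}{2} + \tfrac{9}{10}\cdot\tfrac{1}{38} = \tfrac{7}{95} \\
P(\text{result} = 14 \mid \lnot\,\text{do\_branch}) &= \tfrac{1}{38} \\
P(\text{do\_branch} \mid \text{result} = 14) &= \frac{p \cdot \tfrac{7}{95}}{p \cdot \tfrac{7}{95} + (1 - p)\cdot\tfrac{1}{38}} = \frac{14p}{5 + 9p}
\end{aligned}
```

For p = 0.1, 0.5 and 0.9 this gives approximately 0.24, 0.74 and 0.96.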
