Initial Issue
The Mixture node's out marginalisation rule was failing with a cryptic error:
MethodError: no method matching getlogscale(::Nothing)
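As far as I can tell, the error comes from the rule calling getlogscale on message data that carries no log-scale information, so it receives nothing. A minimal defensive fallback (the same one I ended up using in the full script further down) would be:

import ReactiveMP: getlogscale

# Hedged fallback: treat a missing log-scale as 0.0 instead of erroring
getlogscale(::Nothing) = 0.0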
I was a bit surprised by that, and I haven't found any tests for the Mixture node's out rules (maybe I'm missing something here), so I wrote them on my own. With the current ReactiveMP they fail. Could it be that I am misunderstanding how this node is supposed to be used?
@testset "Marginalisation: (m_switch::Categorical, m_inputs::ManyOf)" begin
@test_rules [check_type_promotion = false] Mixture(:out, Marginalisation) [
# Test case 1: Equal weights
(
input = (
m_switch = Categorical([0.5, 0.5]),
m_inputs = ManyOf(
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
)
),
output = MixtureDistribution([
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
], [0.5, 0.5])
),
# Test case 2: Unequal weights
(
input = (
m_switch = Categorical([0.8, 0.2]),
m_inputs = ManyOf(
NormalMeanVariance(1.0, 1.0),
NormalMeanVariance(5.0, 2.0)
)
),
output = MixtureDistribution([
NormalMeanVariance(1.0, 1.0),
NormalMeanVariance(5.0, 2.0)
], [0.8, 0.2])
)
]
end
So I decided to rewrite this rule myself and came up with the following implementation:
@rule Mixture(:out, Marginalisation) (m_switch::Any, m_inputs::ManyOf{N, Any}) where {N} = begin
# Get logscales, defaulting to 0.0 if Nothing
logscales_inputs = map(msg -> getlogscale(getdata(msg)) === nothing ? 0.0 : getlogscale(getdata(msg)), messages[2])
logscale_switch = getlogscale(getdata(messages[1])) === nothing ? 0.0 : getlogscale(getdata(messages[1]))
# compute logscales of individual components
logscales = logscales_inputs .+ logscale_switch
@logscale logsumexp(logscales)
# Use probabilities directly from m_switch
w = probvec(m_switch)
T = promote_type(eltype(w), map(x -> eltype(mean(x)), m_inputs)...)
# Convert inputs to the promoted type
typed_inputs = map(x -> convert_paramfloattype(T, x), m_inputs)
# return mixture with type-preserved components
return MixtureDistribution(collect(typed_inputs), collect(w))
end
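In math terms, what I intend this rule to compute is the following, with the switch probabilities acting as mixture weights and the overall scale factor obtained via log-sum-exp:

\[
m_{\text{out}}(x) \propto \sum_{k=1}^{N} \pi_k \, m_k(x),
\qquad
\log Z_{\text{out}} = \log \sum_{k=1}^{N} \exp\left(\ell_k + \ell_{\text{switch}}\right),
\]

where \(\pi = \operatorname{probvec}(m_{\text{switch}})\), \(\ell_k\) is the log-scale of the \(k\)-th input message, and \(\ell_{\text{switch}}\) is the log-scale of the switch message (both taken as 0 when missing).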
Also, to make the rule work I had to write quite a few helper methods, so to run the tests now you need to use the following code:
@testitem "rules:Mixture:out" begin
using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions
import ReactiveMP: @test_rules
import ReactiveMP: getlogscale
import BayesBase: paramfloattype
import Base: isapprox
using ExponentialFamily: NormalMeanVariance
function getlogscale(d::NormalMeanVariance{T}) where {T}
μ, τ = mean_precision(d)
# Log of normalization constant for Normal(μ, 1/√τ)
return -0.5 * (log(2π) - log(τ))
end
function getlogscale(d::Categorical{T}) where {T}
# Categorical distribution is already normalized
return 0.0
end
# Add paramfloattype for MixtureDistribution
function paramfloattype(d::MixtureDistribution{D, T}) where {D, T}
# The float type should be the promoted type of both the component distributions and weights
return promote_type(T, paramfloattype(first(d.components)))
end
# Add isapprox for MixtureDistribution
function isapprox(x::MixtureDistribution, y::MixtureDistribution; kwargs...)
# Check if components and weights match approximately
return length(x.components) == length(y.components) &&
all(isapprox.(x.components, y.components; kwargs...)) &&
isapprox(x.weights, y.weights; kwargs...)
end
@testset "Marginalisation: (m_switch::Categorical, m_inputs::ManyOf)" begin
@test_rules [check_type_promotion = false] Mixture(:out, Marginalisation) [
# Test case 1: Equal weights
(
input = (
m_switch = Categorical([0.5, 0.5]),
m_inputs = ManyOf(
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
)
),
output = MixtureDistribution([
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
], [0.5, 0.5])
),
# Test case 2: Unequal weights
(
input = (
m_switch = Categorical([0.8, 0.2]),
m_inputs = ManyOf(
NormalMeanVariance(1.0, 1.0),
NormalMeanVariance(5.0, 2.0)
)
),
output = MixtureDistribution([
NormalMeanVariance(1.0, 1.0),
NormalMeanVariance(5.0, 2.0)
], [0.8, 0.2])
)
]
end
@testset "Marginalisation: (m_inputs::ManyOf, q_switch::PointMass)" begin
@test_rules [check_type_promotion = false] Mixture(:out, Marginalisation) [
# Test case for one-hot encoded switch
(
input = (
m_inputs = ManyOf(
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
),
q_switch = PointMass([1.0, 0.0])
),
output = NormalMeanVariance(0.0, 1.0)
),
(
input = (
m_inputs = ManyOf(
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
),
q_switch = PointMass([0.0, 1.0])
),
output = NormalMeanVariance(2.0, 1.0)
)
]
end
end
Initially, my goal was to run the following model:
@model function mixture_model(y)
# Use fixed mixing proportions
s = [1/3, 2/3]
# Parameters for the two Gaussian components
m[1] ~ Normal(mean = -20.0, variance = 1e2)
w[1] ~ InverseGamma(2.0, 1.0) # InverseGamma prior on the component variance
m[2] ~ Normal(mean = 20.0, variance = 1e2)
w[2] ~ InverseGamma(2.0, 1.0) # InverseGamma prior on the component variance
obs_precision ~ InverseGamma(2.0, 1.0) # despite the name, used as the observation variance below
# Generate mixture assignments and observations
for i in eachindex(y)
z[i] ~ Categorical(s)
comp1[i] ~ NormalMeanVariance(m[1], w[1]) # Use variance parameterization
comp2[i] ~ NormalMeanVariance(m[2], w[2])
μ[i] ~ Mixture(switch = z[i], inputs = (comp1[i], comp2[i]))
y[i] ~ NormalMeanVariance(μ[i], obs_precision)
end
end
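Written out generatively, the model above amounts to:

\[
z_i \sim \operatorname{Cat}(s), \qquad
\mu_i \mid z_i = k \sim \mathcal{N}(m_k, w_k), \qquad
y_i \mid \mu_i \sim \mathcal{N}(\mu_i, v_{\text{obs}}),
\]

with fixed mixing proportions \(s = (1/3, 2/3)\), priors \(m_1 \sim \mathcal{N}(-20, 10^2)\), \(m_2 \sim \mathcal{N}(20, 10^2)\), \(w_k \sim \text{InverseGamma}(2, 1)\) on the component variances, and \(v_{\text{obs}} \sim \text{InverseGamma}(2, 1)\) on the observation variance.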
So the end-to-end script to run is the following one:
using RxInfer, Distributions
using Random
using Plots
using StatsPlots
# Extend getlogscale with methods for the distribution types used below
import ReactiveMP: getlogscale
function getlogscale(d::NormalMeanPrecision{T}) where {T}
μ, τ = mean_precision(d)
# Log of normalization constant for Normal(μ, 1/√τ)
return -0.5 * (log(2π) - log(τ))
end
function getlogscale(d::NormalWeightedMeanPrecision{T}) where {T}
μ, τ = mean_precision(d)
# Log of normalization constant for Normal(μ, 1/√τ)
return -0.5 * (log(2π) - log(τ))
end
function getlogscale(d::NormalMeanVariance{T}) where {T}
μ, τ = mean_precision(d)
# Log of normalization constant for Normal(μ, 1/√τ)
return -0.5 * (log(2π) - log(τ))
end
function getlogscale(d::Categorical{T}) where {T}
# Categorical distribution is already normalized
return 0.0
end
function getlogscale(_::Nothing)
return 0.0
end
# include("mixture_entropy.jl") # include this to use free energy
@model function mixture_model(y)
# Use fixed mixing proportions
s = [1/3, 2/3]
# Parameters for the two Gaussian components
m[1] ~ Normal(mean = -20.0, variance = 1e2)
w[1] ~ InverseGamma(2.0, 1.0) # InverseGamma prior on the component variance
m[2] ~ Normal(mean = 20.0, variance = 1e2)
w[2] ~ InverseGamma(2.0, 1.0) # InverseGamma prior on the component variance
obs_precision ~ InverseGamma(2.0, 1.0) # despite the name, used as the observation variance below
# Generate mixture assignments and observations
for i in eachindex(y)
z[i] ~ Categorical(s)
comp1[i] ~ NormalMeanVariance(m[1], w[1]) # Use variance parameterization
comp2[i] ~ NormalMeanVariance(m[2], w[2])
μ[i] ~ Mixture(switch = z[i], inputs = (comp1[i], comp2[i]))
y[i] ~ NormalMeanVariance(μ[i], obs_precision)
end
end
# Update constraints
@constraints function mixture_constraints()
q(z, m, w, μ, comp1, comp2, obs_precision) = q(z)q(m)q(w)q(μ)q(comp1)q(comp2)q(obs_precision)
q(m) = q(m[1])q(m[2])
q(w) = q(w[1])q(w[2])
q(z) = q(z[1]) .. q(z[end])
q(μ) = q(μ[1]) .. q(μ[end])
q(comp1) = q(comp1[1]) .. q(comp1[end])
q(comp2) = q(comp2[1]) .. q(comp2[end])
end
# Generate synthetic data with unequal proportions
rng = MersenneTwister(42)
true_means = [-20.0, 20.0]
true_precisions = [1.0, 1.0]
N = 1000
switch = [1/3, 2/3] # Unequal proportions as in test
z = rand(rng, Categorical(switch), N)
data = zeros(N)
for i in 1:N
data[i] = randn(rng)/sqrt(true_precisions[z[i]]) + true_means[z[i]]
end
# Update initialization with better starting points
init = @initialization begin
# Initialize means further apart
q(m[1]) = NormalMeanVariance(-30.0, 10.0) # More uncertainty in initial means
q(m[2]) = NormalMeanVariance(30.0, 10.0)
# Initialize variances with more informative priors
q(w[1]) = InverseGamma(3.0, 2.0) # Mode around 1.0
q(w[2]) = InverseGamma(3.0, 2.0)
q(obs_precision) = InverseGamma(3.0, 2.0)
for i in 1:N
# Initialize assignments closer to true proportions
q(z[i]) = Categorical([0.4, 0.6])
# Initialize components with wider separation
q(comp1[i]) = NormalMeanVariance(-20.0, 5.0)
q(comp2[i]) = NormalMeanVariance(20.0, 5.0)
q(μ[i]) = NormalMeanVariance(0.0, 100.0) # Very uncertain about mixture means
end
end
result = infer(
model = mixture_model(),
constraints = mixture_constraints(),
initialization = init,
data = (y = data,),
iterations = 10,
allow_node_contraction = true,
options = (limit_stack_depth = 100,),
# free_energy = true
)
# Create a range for plotting
x_range = range(minimum(data) - 1, maximum(data) + 1, length=200)
# Get the final parameters
m1 = mean(result.posteriors[:m][1][end])
m2 = mean(result.posteriors[:m][2][end])
v1 = mean(result.posteriors[:w][1][end]) # This is variance now
v2 = mean(result.posteriors[:w][2][end])
# Create the plot
p = histogram(data, normalize=true, alpha=0.3, label="Data", bins=50)
plot!(x_range,
x -> pdf(Normal(m1, sqrt(v1)), x),
label="Component 1", linestyle=:dash)
plot!(x_range,
x -> pdf(Normal(m2, sqrt(v2)), x),
label="Component 2", linestyle=:dash)
title!("Gaussian Mixture Model Fit")
xlabel!("x")
ylabel!("Density")
# Save the plot
savefig(p, "mixture_fit.png")
# Print the fitted parameters
println("\nFitted Parameters:")
println("Mean 1: ", round(m1, digits=3))
println("Mean 2: ", round(m2, digits=3))
println("Variance 1: ", round(v1, digits=3))
println("Variance 2: ", round(v2, digits=3))
println("Observation variance: ", round(mean(result.posteriors[:obs_precision][end]), digits=3))
# plot(1:10, result.free_energy)
Interestingly, this shows different (and, I would say, more interesting) behavior compared with NormalMixture.

The NormalMixture model (at least the following one) shows collapsing behavior:
using RxInfer, Distributions
using Random
using Plots
using StatsPlots
using ReactiveMP: getlogscale
@model function mixture_model_normal_mixture(y)
# Use fixed mixing proportions
s = [1/3, 2/3]
# Parameters for the two Gaussian components
m[1] ~ Normal(mean = -20.0, variance = 1e2)
w[1] ~ GammaShapeRate(2.0, 1.0)
m[2] ~ Normal(mean = 20.0, variance = 1e2)
w[2] ~ GammaShapeRate(2.0, 1.0)
# Generate mixture assignments and observations
for i in eachindex(y)
z[i] ~ Categorical(s)
# Using p for precision interface as shown in the tests
y[i] ~ NormalMixture(
switch = z[i],
m = (m[1], m[2]),
p = (w[1], w[2]) # Changed from v to p to match test code
)
end
end
@constraints function mixture_constraints()
q(z, m, w) = q(z)q(m)q(w)
q(m) = q(m[1])q(m[2])
q(w) = q(w[1])q(w[2])
q(z) = q(z[1]) .. q(z[end])
end
# Use same data generation
rng = MersenneTwister(42)
true_means = [-20.0, 20.0]
true_precisions = [1.0, 1.0]
N = 1000
switch = [1/3, 2/3]
z = rand(rng, Categorical(switch), N)
data = zeros(N)
for i in 1:N
data[i] = randn(rng)/sqrt(true_precisions[z[i]]) + true_means[z[i]]
end
# Use same initialization strategy
init = @initialization begin
q(m[1]) = NormalMeanVariance(-20.0, 5.0)
q(m[2]) = NormalMeanVariance(20.0, 5.0)
q(w[1]) = GammaShapeRate(3.0, 2.0)
q(w[2]) = GammaShapeRate(3.0, 2.0)
for i in 1:N
q(z[i]) = Categorical([0.4, 0.6])
end
end
result = infer(
model = mixture_model_normal_mixture(),
constraints = mixture_constraints(),
initialization = init,
data = (y = data,),
iterations = 10,
allow_node_contraction = true,
options = (limit_stack_depth = 100,),
free_energy = true
)
# Plotting
x_range = range(minimum(data) - 1, maximum(data) + 1, length=200)
m1 = mean(result.posteriors[:m][1][end])
m2 = mean(result.posteriors[:m][2][end])
v1 = 1 / mean(result.posteriors[:w][1][end]) # w is a precision here, so invert its posterior mean to get a variance estimate
v2 = 1 / mean(result.posteriors[:w][2][end])
p = histogram(data, normalize=true, alpha=0.3, label="Data", bins=50)
plot!(x_range,
x -> pdf(Normal(m1, sqrt(v1)), x),
label="Component 1", linestyle=:dash)
plot!(x_range,
x -> pdf(Normal(m2, sqrt(v2)), x),
label="Component 2", linestyle=:dash)
title!("Gaussian Mixture Model Fit (NormalMixture)")
xlabel!("x")
ylabel!("Density")
savefig(p, "normal_mixture_fit.png")
println("\nFitted Parameters (NormalMixture):")
println("Mean 1: ", round(m1, digits=3))
println("Mean 2: ", round(m2, digits=3))
println("Variance 1: ", round(v1, digits=3))
println("Variance 2: ", round(v2, digits=3))
@show result.free_energy[end]
plot(1:10, result.free_energy)
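To make the collapse easier to inspect, a quick sanity check is to look at the average responsibilities under q(z) from the last iteration. This is only a sketch: it assumes result.posteriors[:z][end] holds the per-datapoint Categorical marginals of the final iteration, and that probvec is available from BayesBase.

using BayesBase: probvec

# Average responsibilities per component across all data points (last iteration)
z_post = result.posteriors[:z][end]
avg_resp = mean(probvec.(z_post))
println("Average responsibilities: ", round.(avg_resp, digits = 3))
# If one entry is close to 1.0 and the other close to 0.0, the components have
# collapsed onto a single mode instead of splitting around -20 and +20.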
