diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index b555f6bd9f..79979ac83a 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -62,32 +62,117 @@ TEST(HybridBayesNet, Add) { } /* ****************************************************************************/ -// Test evaluate for a pure discrete Bayes net P(Asia). +// Test API for a pure discrete Bayes net P(Asia). TEST(HybridBayesNet, EvaluatePureDiscrete) { HybridBayesNet bayesNet; - bayesNet.emplace_shared(Asia, "4/6"); - HybridValues values; - values.insert(asiaKey, 0); - EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9); + const auto pAsia = std::make_shared(Asia, "4/6"); + bayesNet.push_back(pAsia); + HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}}; + + // choose + GaussianBayesNet empty; + EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9)); + + // evaluate + EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9); + + // optimize + EXPECT(assert_equal(one, bayesNet.optimize())); + EXPECT(assert_equal(VectorValues{}, bayesNet.optimize(one.discrete()))); + + // sample + std::mt19937_64 rng(42); + EXPECT(assert_equal(zero, bayesNet.sample(&rng))); + EXPECT(assert_equal(one, bayesNet.sample(one, &rng))); + EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng))); + + // error + EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); + + // logProbability + EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); + + // toFactorGraph + HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({}); + EXPECT(assert_equal(expectedFG, fg)); + + // prune, imperative :-( + EXPECT(assert_equal(bayesNet, bayesNet.prune(2))); + EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size()); } /* ****************************************************************************/ // Test creation of a tiny hybrid Bayes net. TEST(HybridBayesNet, Tiny) { - auto bn = tiny::createHybridBayesNet(); - EXPECT_LONGS_EQUAL(3, bn.size()); + auto bayesNet = tiny::createHybridBayesNet(); // P(z|x,mode)P(x)P(mode) + EXPECT_LONGS_EQUAL(3, bayesNet.size()); const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}}; - auto fg = bn.toFactorGraph(vv); + HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}}; + + // Check Invariants for components + HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid(); + GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()), + gc1 = hgc->choose(one.discrete()); + GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian(); + GaussianConditional::CheckInvariants(*gc0, vv); + GaussianConditional::CheckInvariants(*gc1, vv); + GaussianConditional::CheckInvariants(*px, vv); + HybridGaussianConditional::CheckInvariants(*hgc, zero); + HybridGaussianConditional::CheckInvariants(*hgc, one); + + // choose + GaussianBayesNet expectedChosen; + expectedChosen.push_back(gc0); + expectedChosen.push_back(px); + auto chosen0 = bayesNet.choose(zero.discrete()); + auto chosen1 = bayesNet.choose(one.discrete()); + EXPECT(assert_equal(expectedChosen, chosen0, 1e-9)); + + // logProbability + const double logP0 = chosen0.logProbability(vv) + log(0.4); // 0.4 is prior + const double logP1 = chosen1.logProbability(vv) + log(0.6); // 0.6 is prior + EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9); + + // evaluate + EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9); + + // optimize + EXPECT(assert_equal(one, bayesNet.optimize())); + EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete()))); + + // sample + std::mt19937_64 rng(42); + EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete())); + + // error + const double error0 = chosen0.error(vv) + gc0->negLogConstant() - + px->negLogConstant() - log(0.4); + const double error1 = chosen1.error(vv) + gc1->negLogConstant() - + px->negLogConstant() - log(0.6); + EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9); + EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9); + + // toFactorGraph + auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}}); EXPECT_LONGS_EQUAL(3, fg.size()); // Check that the ratio of probPrime to evaluate is the same for all modes. std::vector ratio(2); - for (size_t mode : {0, 1}) { - const HybridValues hv{vv, {{M(0), mode}}}; - ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv); - } + ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero); + ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one); EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); + + // prune, imperative :-( + auto pruned = bayesNet.prune(1); + EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); + EXPECT(!pruned.equals(bayesNet)); + } /* ****************************************************************************/ @@ -223,12 +308,15 @@ TEST(HybridBayesNet, Optimize) { /* ****************************************************************************/ // Test Bayes net error TEST(HybridBayesNet, Pruning) { + // Create switching network with three continuous variables and two discrete: + // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) Switching s(3); HybridBayesNet::shared_ptr posterior = s.linearizedFactorGraph.eliminateSequential(); EXPECT_LONGS_EQUAL(5, posterior->size()); + // Optimize HybridValues delta = posterior->optimize(); auto actualTree = posterior->evaluate(delta.continuous()); @@ -254,7 +342,6 @@ TEST(HybridBayesNet, Pruning) { logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues); logProbability += posterior->at(1)->asHybrid()->logProbability(hybridValues); logProbability += posterior->at(2)->asHybrid()->logProbability(hybridValues); - // NOTE(dellaert): the discrete errors were not added in logProbability tree! logProbability += posterior->at(3)->asDiscrete()->logProbability(hybridValues); logProbability += @@ -316,10 +403,9 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { #endif // regression - DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}}; DecisionTreeFactor::ADT potentials( - dkeys, std::vector{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577}); - DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials); + s.modes, std::vector{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577}); + DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials); // Prune! posterior->prune(maxNrLeaves);