From e15c44ec5c711f9efae695c0d03fa3b8ed24bbb8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 22:56:41 -0700 Subject: [PATCH 01/22] Make prune functional --- gtsam/hybrid/HybridGaussianConditional.cpp | 50 ++++++++++++++++++++-- gtsam/hybrid/HybridGaussianConditional.h | 4 +- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 1db13e95b3..3c5130f428 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -342,13 +342,57 @@ HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) { } /* *******************************************************************************/ -void HybridGaussianConditional::prune(const DecisionTreeFactor &discreteProbs) { +HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( + const DecisionTreeFactor &discreteProbs) const { + auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); + auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys()); + // Functional which loops over all assignments and create a set of // GaussianConditionals - auto pruner = prunerFunc(discreteProbs); + auto pruner = [&](const Assignment &choices, + const GaussianConditional::shared_ptr &conditional) + -> GaussianConditional::shared_ptr { + // typecast so we can use this to get probability value + const DiscreteValues values(choices); + + // Case where the hybrid gaussian conditional has the same + // discrete keys as the decision tree. + if (hybridGaussianCondKeySet == discreteProbsKeySet) { + if (discreteProbs(values) == 0.0) { + // empty aka null pointer + std::shared_ptr null; + return null; + } else { + return conditional; + } + } else { + std::vector set_diff; + std::set_difference( + discreteProbsKeySet.begin(), discreteProbsKeySet.end(), + hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(), + std::back_inserter(set_diff)); + + const std::vector assignments = + DiscreteValues::CartesianProduct(set_diff); + for (const DiscreteValues &assignment : assignments) { + DiscreteValues augmented_values(values); + augmented_values.insert(assignment); + + // If any one of the sub-branches are non-zero, + // we need this conditional. + if (discreteProbs(augmented_values) > 0.0) { + return conditional; + } + } + // If we are here, it means that all the sub-branches are 0, + // so we prune. + return nullptr; + } + }; auto pruned_conditionals = conditionals_.apply(pruner); - conditionals_.root_ = pruned_conditionals.root_; + return std::make_shared(discreteKeys(), + pruned_conditionals); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 68c63e7bd7..ede748b162 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -225,8 +225,10 @@ class GTSAM_EXPORT HybridGaussianConditional * `discreteProbs`. * * @param discreteProbs A pruned set of probabilities for the discrete keys. + * @return Shared pointer to possibly a pruned HybridGaussianConditional */ - void prune(const DecisionTreeFactor &discreteProbs); + HybridGaussianConditional::shared_ptr prune( + const DecisionTreeFactor &discreteProbs) const; /// @} From d2880e991396967851f84be83b8ed56f3af8207f Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 23:07:36 -0700 Subject: [PATCH 02/22] Kill obsolete prunerFunc --- gtsam/hybrid/HybridBayesNet.cpp | 112 ++++---------------------------- gtsam/hybrid/HybridBayesNet.h | 9 ++- 2 files changed, 20 insertions(+), 101 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 3c77e3f9aa..703c657cfe 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -37,94 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { return Base::equals(bn, tol); } -/* ************************************************************************* */ -/** - * @brief Helper function to get the pruner functional. - * - * @param prunedDiscreteProbs The prob. decision tree of only discrete keys. - * @param conditional Conditional to prune. Used to get full assignment. - * @return std::function &, double)> - */ -std::function &, double)> prunerFunc( - const DecisionTreeFactor &prunedDiscreteProbs, - const HybridConditional &conditional) { - // Get the discrete keys as sets for the decision tree - // and the hybrid Gaussian conditional. - std::set discreteProbsKeySet = - DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys()); - std::set conditionalKeySet = - DiscreteKeysAsSet(conditional.discreteKeys()); - - auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet]( - const Assignment &choices, - double probability) -> double { - // This corresponds to 0 probability - double pruned_prob = 0.0; - - // typecast so we can use this to get probability value - DiscreteValues values(choices); - // Case where the hybrid Gaussian conditional has the same - // discrete keys as the decision tree. - if (conditionalKeySet == discreteProbsKeySet) { - if (prunedDiscreteProbs(values) == 0) { - return pruned_prob; - } else { - return probability; - } - } else { - // Due to branch merging (aka pruning) in DecisionTree, it is possible we - // get a `values` which doesn't have the full set of keys. - std::set valuesKeys; - for (auto kvp : values) { - valuesKeys.insert(kvp.first); - } - std::set conditionalKeys; - for (auto kvp : conditionalKeySet) { - conditionalKeys.insert(kvp.first); - } - // If true, then values is missing some keys - if (conditionalKeys != valuesKeys) { - // Get the keys present in conditionalKeys but not in valuesKeys - std::vector missing_keys; - std::set_difference(conditionalKeys.begin(), conditionalKeys.end(), - valuesKeys.begin(), valuesKeys.end(), - std::back_inserter(missing_keys)); - // Insert missing keys with a default assignment. - for (auto missing_key : missing_keys) { - values[missing_key] = 0; - } - } - - // Now we generate the full assignment by enumerating - // over all keys in the prunedDiscreteProbs. - // First we find the differing keys - std::vector set_diff; - std::set_difference(discreteProbsKeySet.begin(), - discreteProbsKeySet.end(), conditionalKeySet.begin(), - conditionalKeySet.end(), - std::back_inserter(set_diff)); - - // Now enumerate over all assignments of the differing keys - const std::vector assignments = - DiscreteValues::CartesianProduct(set_diff); - for (const DiscreteValues &assignment : assignments) { - DiscreteValues augmented_values(values); - augmented_values.insert(assignment); - - // If any one of the sub-branches are non-zero, - // we need this probability. - if (prunedDiscreteProbs(augmented_values) > 0.0) { - return probability; - } - } - // If we are here, it means that all the sub-branches are 0, - // so we prune. - return pruned_prob; - } - }; - return pruner; -} - /* ************************************************************************* */ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( size_t maxNrLeaves) { @@ -164,9 +76,10 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( } /* ************************************************************************* */ -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { + HybridBayesNet copy(*this); DecisionTreeFactor prunedDiscreteProbs = - this->pruneDiscreteConditionals(maxNrLeaves); + copy.pruneDiscreteConditionals(maxNrLeaves); /* To prune, we visitWith every leaf in the HybridGaussianConditional. * For each leaf, using the assignment we can check the discrete decision tree @@ -179,13 +92,10 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { // Go through all the conditionals in the // Bayes Net and prune them as per prunedDiscreteProbs. - for (auto &&conditional : *this) { + for (auto &&conditional : copy) { if (auto gm = conditional->asHybrid()) { // Make a copy of the hybrid Gaussian conditional and prune it! - auto prunedHybridGaussianConditional = - std::make_shared(*gm); - prunedHybridGaussianConditional->prune( - prunedDiscreteProbs); // imperative :-( + auto prunedHybridGaussianConditional = gm->prune(prunedDiscreteProbs); // Type-erase and add to the pruned Bayes Net fragment. prunedBayesNetFragment.push_back(prunedHybridGaussianConditional); @@ -336,10 +246,14 @@ AlgebraicDecisionTree HybridBayesNet::logProbability( }); } else if (auto dc = conditional->asDiscrete()) { // If discrete, add the discrete logProbability in the right branch - result = result.apply( - [dc](const Assignment &assignment, double leaf_value) { - return leaf_value + dc->logProbability(DiscreteValues(assignment)); - }); + if (result.nrLeaves() == 1) { + result = dc->errorTree().apply([](double error) { return -error; }); + } else { + result = result.apply([dc](const Assignment &assignment, + double leaf_value) { + return leaf_value + dc->logProbability(DiscreteValues(assignment)); + }); + } } } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 62688e8b20..9052a7a167 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -201,8 +201,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ HybridValues sample() const; - /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. - HybridBayesNet prune(size_t maxNrLeaves); + /** + * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. + * + * @param maxNrLeaves Continuous values at which to compute the error. + * @return A pruned HybridBayesNet + */ + HybridBayesNet prune(size_t maxNrLeaves) const; /** * @brief Compute conditional error for each discrete assignment, From 28f5ed0a6edca9e00b8b74054f7f34b2a2da0caf Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 23:08:00 -0700 Subject: [PATCH 03/22] Inline lambda --- gtsam/hybrid/HybridGaussianConditional.cpp | 66 ++-------------------- gtsam/hybrid/HybridGaussianConditional.h | 11 ---- 2 files changed, 4 insertions(+), 73 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 3c5130f428..478f94f180 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -288,59 +288,6 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { return s; } -/* ************************************************************************* */ -std::function &, const GaussianConditional::shared_ptr &)> -HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) { - // Get the discrete keys as sets for the decision tree - // and the hybrid gaussian conditional. - auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); - auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys()); - - auto pruner = [discreteProbs, discreteProbsKeySet, hybridGaussianCondKeySet]( - const Assignment &choices, - const GaussianConditional::shared_ptr &conditional) - -> GaussianConditional::shared_ptr { - // typecast so we can use this to get probability value - const DiscreteValues values(choices); - - // Case where the hybrid gaussian conditional has the same - // discrete keys as the decision tree. - if (hybridGaussianCondKeySet == discreteProbsKeySet) { - if (discreteProbs(values) == 0.0) { - // empty aka null pointer - std::shared_ptr null; - return null; - } else { - return conditional; - } - } else { - std::vector set_diff; - std::set_difference( - discreteProbsKeySet.begin(), discreteProbsKeySet.end(), - hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(), - std::back_inserter(set_diff)); - - const std::vector assignments = - DiscreteValues::CartesianProduct(set_diff); - for (const DiscreteValues &assignment : assignments) { - DiscreteValues augmented_values(values); - augmented_values.insert(assignment); - - // If any one of the sub-branches are non-zero, - // we need this conditional. - if (discreteProbs(augmented_values) > 0.0) { - return conditional; - } - } - // If we are here, it means that all the sub-branches are 0, - // so we prune. - return nullptr; - } - }; - return pruner; -} - /* *******************************************************************************/ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( const DecisionTreeFactor &discreteProbs) const { @@ -358,14 +305,10 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( // Case where the hybrid gaussian conditional has the same // discrete keys as the decision tree. if (hybridGaussianCondKeySet == discreteProbsKeySet) { - if (discreteProbs(values) == 0.0) { - // empty aka null pointer - std::shared_ptr null; - return null; - } else { - return conditional; - } + return (discreteProbs(values) == 0.0) ? nullptr : conditional; } else { + // TODO(Frank): It might be faster to "choose" based on values + // and then check whether the resulting tree has non-nullptrs. std::vector set_diff; std::set_difference( discreteProbsKeySet.begin(), discreteProbsKeySet.end(), @@ -384,8 +327,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( return conditional; } } - // If we are here, it means that all the sub-branches are 0, - // so we prune. + // If we are here, it means that all the sub-branches are 0, so we prune. return nullptr; } }; diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index ede748b162..8f3aa67788 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -243,17 +243,6 @@ class GTSAM_EXPORT HybridGaussianConditional /// Convert to a DecisionTree of Gaussian factor graphs. GaussianFactorGraphTree asGaussianFactorGraphTree() const; - /** - * @brief Get the pruner function from discrete probabilities. - * - * @param discreteProbs The probabilities of only discrete keys. - * @return std::function &, const GaussianConditional::shared_ptr &)> - */ - std::function &, const GaussianConditional::shared_ptr &)> - prunerFunc(const DecisionTreeFactor &prunedProbabilities); - /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const; From 38ed6096145337853f5ccbd80a76d77fab9b64a4 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 23:15:19 -0700 Subject: [PATCH 04/22] Fix pruning in iSAM --- gtsam/hybrid/HybridBayesTree.cpp | 8 +++++++- gtsam/inference/BayesTreeCliqueBase.h | 3 +++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 9aee6dcf81..0766f452b3 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -26,6 +26,10 @@ #include #include +#include + +#include "gtsam/hybrid/HybridConditional.h" + namespace gtsam { // Instantiate base class @@ -207,7 +211,9 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { if (conditional->isHybrid()) { auto hybridGaussianCond = conditional->asHybrid(); - hybridGaussianCond->prune(parentData.prunedDiscreteProbs); + // Imperative + clique->conditional() = std::make_shared( + hybridGaussianCond->prune(parentData.prunedDiscreteProbs)); } return parentData; } diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index a7b1cf06c1..adffa2f146 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -140,6 +140,9 @@ namespace gtsam { /** Access the conditional */ const sharedConditional& conditional() const { return conditional_; } + /** Write access to the conditional */ + sharedConditional& conditional() { return conditional_; } + /// Return true if this clique is the root of a Bayes tree. inline bool isRoot() const { return parent_.expired(); } From a898ad3661e14c76d338476bfe46dfa94c7de8ed Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 11:54:54 -0700 Subject: [PATCH 05/22] discretePosterior --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 25 +++++++++---- gtsam/hybrid/HybridGaussianFactorGraph.h | 28 ++++++++++----- .../tests/testHybridGaussianFactorGraph.cpp | 35 +++++++++++-------- 3 files changed, 58 insertions(+), 30 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8a2a7fd158..8e6e95c17d 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -42,7 +43,6 @@ #include #include #include -#include #include #include #include @@ -342,14 +342,20 @@ static std::shared_ptr createHybridGaussianFactor( return std::make_shared(discreteSeparator, newFactors); } +/* *******************************************************************************/ +/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys. +static auto GetDiscreteKeys = + [](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys { + const std::set discreteKeySet = hfg.discreteKeys(); + return {discreteKeySet.begin(), discreteKeySet.end()}; +}; + /* *******************************************************************************/ std::pair> HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { // Since we eliminate all continuous variables first, // the discrete separator will be *all* the discrete keys. - const std::set keysForDiscreteVariables = discreteKeys(); - DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(), - keysForDiscreteVariables.end()); + DiscreteKeys discreteSeparator = GetDiscreteKeys(*this); // Collect all the factors to create a set of Gaussian factor graphs in a // decision tree indexed by all discrete keys involved. @@ -525,14 +531,21 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { } /* ************************************************************************ */ -AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( +DecisionTreeFactor HybridGaussianFactorGraph::probPrime( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree = this->errorTree(continuousValues); AlgebraicDecisionTree prob_tree = error_tree.apply([](double error) { // NOTE: The 0.5 term is handled by each factor return exp(-error); }); - return prob_tree; + return {GetDiscreteKeys(*this), prob_tree}; +} + +/* ************************************************************************ */ +DiscreteConditional HybridGaussianFactorGraph::discretePosterior( + const VectorValues &continuousValues) const { + auto p = probPrime(continuousValues); + return {p.size(), p}; } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 7e3aac663d..5d19b4f837 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -187,24 +188,33 @@ class GTSAM_EXPORT HybridGaussianFactorGraph AlgebraicDecisionTree errorTree( const VectorValues& continuousValues) const; + /** + * @brief Compute the unnormalized posterior probability for a continuous + * vector values given a specific assignment. + * + * @return double + */ + double probPrime(const HybridValues& values) const; + /** * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ * for each discrete assignment, and return as a tree. * - * @param continuousValues Continuous values at which to compute the - * probability. - * @return AlgebraicDecisionTree + * @param continuousValues Continuous values at which to compute probability. + * @return DecisionTreeFactor */ - AlgebraicDecisionTree probPrime( - const VectorValues& continuousValues) const; + DecisionTreeFactor probPrime(const VectorValues& continuousValues) const; /** - * @brief Compute the unnormalized posterior probability for a continuous - * vector values given a specific assignment. + * @brief Computer posterior P(M|X=x) when all continuous values X are given. + * This is very efficient as this simply probPrime normalized into a + * conditional. * - * @return double + * @param continuousValues Continuous values x to condition on. + * @return DecisionTreeFactor */ - double probPrime(const HybridValues& values) const; + DiscreteConditional discretePosterior( + const VectorValues& continuousValues) const; /** * @brief Create a decision tree of factor graphs out of this hybrid factor diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 6aef603868..8ba1eb7622 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -603,29 +603,34 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { /* ****************************************************************************/ // Test hybrid gaussian factor graph error and unnormalized probabilities TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { + // Create switching network with three continuous variables and two discrete: + // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) Switching s(3); - HybridGaussianFactorGraph graph = s.linearizedFactorGraph; + const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph; - HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(); + const HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(); - HybridValues delta = hybridBayesNet->optimize(); - auto error_tree = graph.errorTree(delta.continuous()); + const HybridValues delta = hybridBayesNet->optimize(); - std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; + // regression test for errorTree std::vector leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568}; - AlgebraicDecisionTree expected_error(discrete_keys, leaves); - - // regression - EXPECT(assert_equal(expected_error, error_tree, 1e-7)); + AlgebraicDecisionTree expectedErrors(s.modes, leaves); + const auto error_tree = graph.errorTree(delta.continuous()); + EXPECT(assert_equal(expectedErrors, error_tree, 1e-7)); + // regression test for probPrime + const DecisionTreeFactor expectedFactor( + s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064}); auto probabilities = graph.probPrime(delta.continuous()); - std::vector prob_leaves = {0.36793249, 0.61247742, 0.59489556, - 0.99029064}; - AlgebraicDecisionTree expected_probabilities(discrete_keys, prob_leaves); - - // regression - EXPECT(assert_equal(expected_probabilities, probabilities, 1e-7)); + EXPECT(assert_equal(expectedFactor, probabilities, 1e-7)); + + // regression test for discretePosterior + const DecisionTreeFactor normalized( + s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852}); + DiscreteConditional expectedPosterior(2, normalized); + auto posterior = graph.discretePosterior(delta.continuous()); + EXPECT(assert_equal(expectedPosterior, posterior, 1e-7)); } /* ****************************************************************************/ From 2abb41059253dd177ad54d2611c9c30dc2d833f6 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 16:36:31 -0700 Subject: [PATCH 06/22] Removed ill-named and confusing method logProbability --- gtsam/hybrid/HybridGaussianConditional.cpp | 18 ------------------ gtsam/hybrid/HybridGaussianConditional.h | 10 ---------- .../tests/testHybridGaussianConditional.cpp | 11 ----------- 3 files changed, 39 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 478f94f180..1c3a69ce7d 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -337,24 +337,6 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( pruned_conditionals); } -/* *******************************************************************************/ -AlgebraicDecisionTree HybridGaussianConditional::logProbability( - const VectorValues &continuousValues) const { - // functor to calculate (double) logProbability value from - // GaussianConditional. - auto probFunc = - [continuousValues](const GaussianConditional::shared_ptr &conditional) { - if (conditional) { - return conditional->logProbability(continuousValues); - } else { - // Return arbitrarily small logProbability if conditional is null - // Conditional is null if it is pruned out. - return -1e20; - } - }; - return DecisionTree(conditionals_, probFunc); -} - /* *******************************************************************************/ double HybridGaussianConditional::logProbability( const HybridValues &values) const { diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 8f3aa67788..f3bf4d839e 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -194,16 +194,6 @@ class GTSAM_EXPORT HybridGaussianConditional /// Getter for the underlying Conditionals DecisionTree const Conditionals &conditionals() const; - /** - * @brief Compute logProbability of the HybridGaussianConditional as a tree. - * - * @param continuousValues The continuous VectorValues. - * @return AlgebraicDecisionTree A decision tree with the same keys - * as the conditionals, and leaf values as the logProbability. - */ - AlgebraicDecisionTree logProbability( - const VectorValues &continuousValues) const; - /** * @brief Compute the logProbability of this hybrid Gaussian conditional. * diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 803d42f034..24eb409a1c 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -74,17 +74,6 @@ TEST(HybridGaussianConditional, Invariants) { /// Check LogProbability. TEST(HybridGaussianConditional, LogProbability) { using namespace equal_constants; - auto actual = hybrid_conditional.logProbability(vv); - - // Check result. - std::vector discrete_keys = {mode}; - std::vector leaves = {conditionals[0]->logProbability(vv), - conditionals[1]->logProbability(vv)}; - AlgebraicDecisionTree expected(discrete_keys, leaves); - - EXPECT(assert_equal(expected, actual, 1e-6)); - - // Check for non-tree version. for (size_t mode : {0, 1}) { const HybridValues hv{vv, {{M(0), mode}}}; EXPECT_DOUBLES_EQUAL(conditionals[mode]->logProbability(vv), From 64513eb6d9a8d46c415d837e363389ecd7295c6d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 16:37:02 -0700 Subject: [PATCH 07/22] discretePosterior for graphs --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 23 +++++++------------ gtsam/hybrid/HybridGaussianFactorGraph.h | 17 ++++---------- .../tests/testHybridGaussianFactorGraph.cpp | 9 +------- 3 files changed, 14 insertions(+), 35 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8e6e95c17d..0e5a34359e 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -505,22 +505,22 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( const VectorValues &continuousValues) const { - AlgebraicDecisionTree error_tree(0.0); + AlgebraicDecisionTree result(0.0); // Iterate over each factor. for (auto &factor : factors_) { if (auto f = std::dynamic_pointer_cast(factor)) { // Check for HybridFactor, and call errorTree - error_tree = error_tree + f->errorTree(continuousValues); + result = result + f->errorTree(continuousValues); } else if (auto f = std::dynamic_pointer_cast(factor)) { // Skip discrete factors continue; } else { // Everything else is a continuous only factor HybridValues hv(continuousValues, DiscreteValues()); - error_tree = error_tree + AlgebraicDecisionTree(factor->error(hv)); + result = result + AlgebraicDecisionTree(factor->error(hv)); } } - return error_tree; + return result; } /* ************************************************************************ */ @@ -531,21 +531,14 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { } /* ************************************************************************ */ -DecisionTreeFactor HybridGaussianFactorGraph::probPrime( +AlgebraicDecisionTree HybridGaussianFactorGraph::discretePosterior( const VectorValues &continuousValues) const { - AlgebraicDecisionTree error_tree = this->errorTree(continuousValues); - AlgebraicDecisionTree prob_tree = error_tree.apply([](double error) { + AlgebraicDecisionTree errors = this->errorTree(continuousValues); + AlgebraicDecisionTree p = errors.apply([](double error) { // NOTE: The 0.5 term is handled by each factor return exp(-error); }); - return {GetDiscreteKeys(*this), prob_tree}; -} - -/* ************************************************************************ */ -DiscreteConditional HybridGaussianFactorGraph::discretePosterior( - const VectorValues &continuousValues) const { - auto p = probPrime(continuousValues); - return {p.size(), p}; + return p / p.sum(); } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 5d19b4f837..3ef6218bec 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -196,24 +196,17 @@ class GTSAM_EXPORT HybridGaussianFactorGraph */ double probPrime(const HybridValues& values) const; - /** - * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ - * for each discrete assignment, and return as a tree. - * - * @param continuousValues Continuous values at which to compute probability. - * @return DecisionTreeFactor - */ - DecisionTreeFactor probPrime(const VectorValues& continuousValues) const; - /** * @brief Computer posterior P(M|X=x) when all continuous values X are given. - * This is very efficient as this simply probPrime normalized into a - * conditional. + * This is efficient as this simply probPrime normalized. + * + * @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys, + * which we would need, are hard to recover. * * @param continuousValues Continuous values x to condition on. * @return DecisionTreeFactor */ - DiscreteConditional discretePosterior( + AlgebraicDecisionTree discretePosterior( const VectorValues& continuousValues) const; /** diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 8ba1eb7622..0c5f52e611 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -619,16 +619,9 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { const auto error_tree = graph.errorTree(delta.continuous()); EXPECT(assert_equal(expectedErrors, error_tree, 1e-7)); - // regression test for probPrime - const DecisionTreeFactor expectedFactor( - s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064}); - auto probabilities = graph.probPrime(delta.continuous()); - EXPECT(assert_equal(expectedFactor, probabilities, 1e-7)); - // regression test for discretePosterior - const DecisionTreeFactor normalized( + const AlgebraicDecisionTree expectedPosterior( s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852}); - DiscreteConditional expectedPosterior(2, normalized); auto posterior = graph.discretePosterior(delta.continuous()); EXPECT(assert_equal(expectedPosterior, posterior, 1e-7)); } From 788f4b6a199009c8efb683699c0b6ff7c450625c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 16:37:30 -0700 Subject: [PATCH 08/22] renamed logProbability and added discretePosterior --- gtsam/hybrid/HybridBayesNet.cpp | 38 ++++--- gtsam/hybrid/HybridBayesNet.h | 33 +++--- gtsam/hybrid/tests/testHybridBayesNet.cpp | 132 +++++++--------------- 3 files changed, 85 insertions(+), 118 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 703c657cfe..5f655c9902 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -214,10 +214,14 @@ AlgebraicDecisionTree HybridBayesNet::errorTree( } else if (auto dc = conditional->asDiscrete()) { // If discrete, add the discrete error in the right branch - result = result.apply( - [dc](const Assignment &assignment, double leaf_value) { - return leaf_value + dc->error(DiscreteValues(assignment)); - }); + if (result.nrLeaves() == 1) { + result = dc->errorTree(); + } else { + result = result.apply( + [dc](const Assignment &assignment, double leaf_value) { + return leaf_value + dc->error(DiscreteValues(assignment)); + }); + } } } @@ -225,22 +229,27 @@ AlgebraicDecisionTree HybridBayesNet::errorTree( } /* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::logProbability( +AlgebraicDecisionTree HybridBayesNet::logDiscretePosteriorPrime( const VectorValues &continuousValues) const { AlgebraicDecisionTree result(0.0); + // Get logProbability function for a conditional or arbitrarily small + // logProbability if the conditional was pruned out. + auto probFunc = [continuousValues]( + const GaussianConditional::shared_ptr &conditional) { + return conditional ? conditional->logProbability(continuousValues) : -1e20; + }; + // Iterate over each conditional. for (auto &&conditional : *this) { if (auto gm = conditional->asHybrid()) { // If conditional is hybrid, select based on assignment and compute // logProbability. - result = result + gm->logProbability(continuousValues); + result = result + DecisionTree(gm->conditionals(), probFunc); } else if (auto gc = conditional->asGaussian()) { - // If continuous, get the (double) logProbability and add it to the - // result + // If continuous, get the logProbability and add it to the result double logProbability = gc->logProbability(continuousValues); - // Add the computed logProbability to every leaf of the logProbability - // tree. + // Add the computed logProbability to every leaf of the tree. result = result.apply([logProbability](double leaf_value) { return leaf_value + logProbability; }); @@ -261,10 +270,13 @@ AlgebraicDecisionTree HybridBayesNet::logProbability( } /* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::evaluate( +AlgebraicDecisionTree HybridBayesNet::discretePosterior( const VectorValues &continuousValues) const { - AlgebraicDecisionTree tree = this->logProbability(continuousValues); - return tree.apply([](double log) { return exp(log); }); + AlgebraicDecisionTree log_p = + this->logDiscretePosteriorPrime(continuousValues); + AlgebraicDecisionTree p = + log_p.apply([](double log) { return exp(log); }); + return p / p.sum(); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 9052a7a167..9e621ea205 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -125,12 +125,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /** * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete - * value assignment. + * value assignment. Note this corresponds to the Gaussian posterior p(X|M=m) + * of the continuous variables given the discrete assignment M=m. * * @note Any pure discrete factors are ignored. * * @param assignment The discrete value assignment for the discrete keys. - * @return GaussianBayesNet + * @return Gaussian posterior as a GaussianBayesNet */ GaussianBayesNet choose(const DiscreteValues &assignment) const; @@ -226,29 +227,33 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using Base::error; /** - * @brief Compute log probability for each discrete assignment, - * and return as a tree. + * @brief Compute the log posterior log P'(M|x) of all assignments up to a + * constant, returning the result as an algebraic decision tree. + * + * @note The joint P(X,M) is p(X|M) P(M) + * Then the posterior on M given X=x is is P(M|x) = p(x|M) P(M) / p(x). + * Ideally we want log P(M|x) = log p(x|M) + log P(M) - log P(x), but + * unfortunately log p(x) is expensive, so we compute the log of the + * unnormalized posterior log P'(M|x) = log p(x|M) + log P(M) * - * @param continuousValues Continuous values at which - * to compute the log probability. + * @param continuousValues Continuous values x at which to compute log P'(M|x) * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree logProbability( + AlgebraicDecisionTree logDiscretePosteriorPrime( const VectorValues &continuousValues) const; using BayesNet::logProbability; // expose HybridValues version /** - * @brief Compute unnormalized probability q(μ|M), - * for each discrete assignment, and return as a tree. - * q(μ|M) is the unnormalized probability at the MLE point μ, - * conditioned on the discrete variables. + * @brief Compute normalized posterior P(M|X=x) and return as a tree. + * + * @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys, + * which we would need, are hard to recover. * - * @param continuousValues Continuous values at which to compute the - * probability. + * @param continuousValues Continuous values x to condition P(M|X=x) on. * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree evaluate( + AlgebraicDecisionTree discretePosterior( const VectorValues &continuousValues) const; /** diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 79979ac83a..8988d1e626 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -65,8 +65,7 @@ TEST(HybridBayesNet, Add) { // Test API for a pure discrete Bayes net P(Asia). TEST(HybridBayesNet, EvaluatePureDiscrete) { HybridBayesNet bayesNet; - const auto pAsia = std::make_shared(Asia, "4/6"); - bayesNet.push_back(pAsia); + bayesNet.emplace_shared(Asia, "4/6"); HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}}; // choose @@ -87,92 +86,39 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) { EXPECT(assert_equal(one, bayesNet.sample(one, &rng))); EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng))); + // prune + EXPECT(assert_equal(bayesNet, bayesNet.prune(2))); + // EXPECT(assert_equal(bayesNet, bayesNet.prune(1))); Should fail !!! + // EXPECT(assert_equal(bayesNet, bayesNet.prune(0))); Should fail !!! + + // errorTree + AlgebraicDecisionTree actual = bayesNet.errorTree({}); + AlgebraicDecisionTree expected( + {Asia}, std::vector{-log(0.4), -log(0.6)}); + EXPECT(assert_equal(expected, actual)); + // error EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); - - // logProbability - EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); - EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); - - // toFactorGraph - HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({}); - EXPECT(assert_equal(expectedFG, fg)); - - // prune, imperative :-( - EXPECT(assert_equal(bayesNet, bayesNet.prune(2))); - EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size()); } /* ****************************************************************************/ // Test creation of a tiny hybrid Bayes net. TEST(HybridBayesNet, Tiny) { - auto bayesNet = tiny::createHybridBayesNet(); // P(z|x,mode)P(x)P(mode) - EXPECT_LONGS_EQUAL(3, bayesNet.size()); + auto bn = tiny::createHybridBayesNet(); + EXPECT_LONGS_EQUAL(3, bn.size()); const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}}; - HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}}; - - // Check Invariants for components - HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid(); - GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()), - gc1 = hgc->choose(one.discrete()); - GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian(); - GaussianConditional::CheckInvariants(*gc0, vv); - GaussianConditional::CheckInvariants(*gc1, vv); - GaussianConditional::CheckInvariants(*px, vv); - HybridGaussianConditional::CheckInvariants(*hgc, zero); - HybridGaussianConditional::CheckInvariants(*hgc, one); - - // choose - GaussianBayesNet expectedChosen; - expectedChosen.push_back(gc0); - expectedChosen.push_back(px); - auto chosen0 = bayesNet.choose(zero.discrete()); - auto chosen1 = bayesNet.choose(one.discrete()); - EXPECT(assert_equal(expectedChosen, chosen0, 1e-9)); - - // logProbability - const double logP0 = chosen0.logProbability(vv) + log(0.4); // 0.4 is prior - const double logP1 = chosen1.logProbability(vv) + log(0.6); // 0.6 is prior - EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9); - EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9); - - // evaluate - EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9); - - // optimize - EXPECT(assert_equal(one, bayesNet.optimize())); - EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete()))); - - // sample - std::mt19937_64 rng(42); - EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete())); - - // error - const double error0 = chosen0.error(vv) + gc0->negLogConstant() - - px->negLogConstant() - log(0.4); - const double error1 = chosen1.error(vv) + gc1->negLogConstant() - - px->negLogConstant() - log(0.6); - EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9); - EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9); - EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9); - - // toFactorGraph - auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}}); + auto fg = bn.toFactorGraph(vv); EXPECT_LONGS_EQUAL(3, fg.size()); // Check that the ratio of probPrime to evaluate is the same for all modes. std::vector ratio(2); - ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero); - ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one); + for (size_t mode : {0, 1}) { + const HybridValues hv{vv, {{M(0), mode}}}; + ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv); + } EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); - - // prune, imperative :-( - auto pruned = bayesNet.prune(1); - EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); - EXPECT(!pruned.equals(bayesNet)); - } /* ****************************************************************************/ @@ -318,22 +264,19 @@ TEST(HybridBayesNet, Pruning) { // Optimize HybridValues delta = posterior->optimize(); - auto actualTree = posterior->evaluate(delta.continuous()); - // Regression test on density tree. - std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; - std::vector leaves = {6.1112424, 20.346113, 17.785849, 19.738098}; - AlgebraicDecisionTree expected(discrete_keys, leaves); - EXPECT(assert_equal(expected, actualTree, 1e-6)); + // Verify discrete posterior at optimal value sums to 1. + auto discretePosterior = posterior->discretePosterior(delta.continuous()); + EXPECT_DOUBLES_EQUAL(1.0, discretePosterior.sum(), 1e-9); + + // Regression test on discrete posterior at optimal value. + std::vector leaves = {0.095516068, 0.31800092, 0.27798511, 0.3084979}; + AlgebraicDecisionTree expected(s.modes, leaves); + EXPECT(assert_equal(expected, discretePosterior, 1e-6)); // Prune and get probabilities auto prunedBayesNet = posterior->prune(2); - auto prunedTree = prunedBayesNet.evaluate(delta.continuous()); - - // Regression test on pruned logProbability tree - std::vector pruned_leaves = {0.0, 32.713418, 0.0, 31.735823}; - AlgebraicDecisionTree expected_pruned(discrete_keys, pruned_leaves); - EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); + auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous()); // Verify logProbability computation and check specific logProbability value const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; @@ -346,14 +289,21 @@ TEST(HybridBayesNet, Pruning) { posterior->at(3)->asDiscrete()->logProbability(hybridValues); logProbability += posterior->at(4)->asDiscrete()->logProbability(hybridValues); - - // Regression - double density = exp(logProbability); - EXPECT_DOUBLES_EQUAL(density, - 1.6078460548731697 * actualTree(discrete_values), 1e-6); - EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues), 1e-9); + + // Check agreement with discrete posterior + // double density = exp(logProbability); + // FAILS: EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), + // 1e-6); + + // Regression test on pruned logProbability tree + std::vector pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578}; + AlgebraicDecisionTree expected_pruned(s.modes, pruned_leaves); + EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); + + // Regression + // FAILS: EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9); } /* ****************************************************************************/ From 3d55fe0d378dff654d8520e9d174b93be22948d5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 22:56:32 -0700 Subject: [PATCH 09/22] Finish tests --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 36 ++++++++++++++++++----- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 8988d1e626..ee47a698a5 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -65,13 +65,18 @@ TEST(HybridBayesNet, Add) { // Test API for a pure discrete Bayes net P(Asia). TEST(HybridBayesNet, EvaluatePureDiscrete) { HybridBayesNet bayesNet; - bayesNet.emplace_shared(Asia, "4/6"); + const auto pAsia = std::make_shared(Asia, "4/6"); + bayesNet.push_back(pAsia); HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}}; // choose GaussianBayesNet empty; EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9)); + // logProbability + EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); + // evaluate EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9); EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9); @@ -88,18 +93,35 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) { // prune EXPECT(assert_equal(bayesNet, bayesNet.prune(2))); - // EXPECT(assert_equal(bayesNet, bayesNet.prune(1))); Should fail !!! - // EXPECT(assert_equal(bayesNet, bayesNet.prune(0))); Should fail !!! + EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size()); // errorTree AlgebraicDecisionTree actual = bayesNet.errorTree({}); - AlgebraicDecisionTree expected( + AlgebraicDecisionTree expectedErrorTree( {Asia}, std::vector{-log(0.4), -log(0.6)}); - EXPECT(assert_equal(expected, actual)); + EXPECT(assert_equal(expectedErrorTree, actual)); // error EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); + + // logDiscretePosteriorPrime, TODO: useless as -errorTree? + AlgebraicDecisionTree expected({Asia}, + std::vector{log(0.4), log(0.6)}); + EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime({}))); + + // logProbability + EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); + + // discretePosterior + AlgebraicDecisionTree expectedPosterior({Asia}, + std::vector{0.4, 0.6}); + EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({}))); + + // toFactorGraph + HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({}); + EXPECT(assert_equal(expectedFG, fg)); } /* ****************************************************************************/ @@ -358,7 +380,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials); // Prune! - posterior->prune(maxNrLeaves); + auto pruned = posterior->prune(maxNrLeaves); // Functor to verify values against the expected_discrete_conditionals auto checker = [&](const Assignment& assignment, @@ -375,7 +397,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { }; // Get the pruned discrete conditionals as an AlgebraicDecisionTree - auto pruned_discrete_conditionals = posterior->at(4)->asDiscrete(); + auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete(); auto discrete_conditional_tree = std::dynamic_pointer_cast( pruned_discrete_conditionals); From 50809001e161e2886e7acf9c776b79b94b3aa9a5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 00:20:10 -0700 Subject: [PATCH 10/22] Got rid of HBN::errorTree. Weird semantics and not used unless in regression tests. --- gtsam/hybrid/HybridBayesNet.cpp | 34 ------- gtsam/hybrid/HybridBayesNet.h | 10 -- gtsam/hybrid/tests/testHybridBayesNet.cpp | 98 ++++++++++++------- .../hybrid/tests/testHybridGaussianFactor.cpp | 7 -- .../tests/testHybridNonlinearFactorGraph.cpp | 7 -- 5 files changed, 65 insertions(+), 91 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 5f655c9902..9df0012c7e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -194,40 +194,6 @@ HybridValues HybridBayesNet::sample() const { return sample(&kRandomNumberGenerator); } -/* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::errorTree( - const VectorValues &continuousValues) const { - AlgebraicDecisionTree result(0.0); - - // Iterate over each conditional. - for (auto &&conditional : *this) { - if (auto gm = conditional->asHybrid()) { - // If conditional is hybrid, compute error for all assignments. - result = result + gm->errorTree(continuousValues); - - } else if (auto gc = conditional->asGaussian()) { - // If continuous, get the error and add it to the result - double error = gc->error(continuousValues); - // Add the computed error to every leaf of the result tree. - result = result.apply( - [error](double leaf_value) { return leaf_value + error; }); - - } else if (auto dc = conditional->asDiscrete()) { - // If discrete, add the discrete error in the right branch - if (result.nrLeaves() == 1) { - result = dc->errorTree(); - } else { - result = result.apply( - [dc](const Assignment &assignment, double leaf_value) { - return leaf_value + dc->error(DiscreteValues(assignment)); - }); - } - } - } - - return result; -} - /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::logDiscretePosteriorPrime( const VectorValues &continuousValues) const { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 9e621ea205..fba6bb6aa8 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -210,16 +210,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ HybridBayesNet prune(size_t maxNrLeaves) const; - /** - * @brief Compute conditional error for each discrete assignment, - * and return as a tree. - * - * @param continuousValues Continuous values at which to compute the error. - * @return AlgebraicDecisionTree - */ - AlgebraicDecisionTree errorTree( - const VectorValues &continuousValues) const; - /** * @brief Error method using HybridValues which returns specific error for * assignment. diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index ee47a698a5..9974827e80 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -95,12 +95,6 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) { EXPECT(assert_equal(bayesNet, bayesNet.prune(2))); EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size()); - // errorTree - AlgebraicDecisionTree actual = bayesNet.errorTree({}); - AlgebraicDecisionTree expectedErrorTree( - {Asia}, std::vector{-log(0.4), -log(0.6)}); - EXPECT(assert_equal(expectedErrorTree, actual)); - // error EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); @@ -127,20 +121,73 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) { /* ****************************************************************************/ // Test creation of a tiny hybrid Bayes net. TEST(HybridBayesNet, Tiny) { - auto bn = tiny::createHybridBayesNet(); - EXPECT_LONGS_EQUAL(3, bn.size()); + auto bayesNet = tiny::createHybridBayesNet(); // P(z|x,mode)P(x)P(mode) + EXPECT_LONGS_EQUAL(3, bayesNet.size()); const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}}; - auto fg = bn.toFactorGraph(vv); - EXPECT_LONGS_EQUAL(3, fg.size()); - - // Check that the ratio of probPrime to evaluate is the same for all modes. - std::vector ratio(2); - for (size_t mode : {0, 1}) { - const HybridValues hv{vv, {{M(0), mode}}}; - ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv); - } - EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); + HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}}; + + // choose + HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid(); + GaussianBayesNet chosen; + chosen.push_back(hgc->choose(zero.discrete())); + chosen.push_back(bayesNet.at(1)->asGaussian()); + EXPECT(assert_equal(chosen, bayesNet.choose(zero.discrete()), 1e-9)); + + // logProbability + const double logP0 = chosen.logProbability(vv) + log(0.4); // 0.4 is prior + const double logP1 = bayesNet.choose(one.discrete()).logProbability(vv) + log(0.6); // 0.6 is prior + EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9); + + // evaluate + EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9); + + // optimize + EXPECT(assert_equal(one, bayesNet.optimize())); + EXPECT(assert_equal(chosen.optimize(), bayesNet.optimize(zero.discrete()))); + + // sample + std::mt19937_64 rng(42); + EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete())); + + // prune + auto pruned = bayesNet.prune(1); + EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); + EXPECT(!pruned.equals(bayesNet)); + + // // error + // EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); + // EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); + + // logDiscretePosteriorPrime, TODO: useless as -errorTree? + AlgebraicDecisionTree expected(M(0), logP0, logP1); + EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime(vv))); + + // // logProbability + // EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); + // EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); + + // // discretePosterior + // AlgebraicDecisionTree expectedPosterior({Asia}, + // std::vector{0.4, + // 0.6}); + // EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({}))); + + // // toFactorGraph + // HybridGaussianFactorGraph expectedFG{}; + + // auto fg = bayesNet.toFactorGraph(vv); + // EXPECT_LONGS_EQUAL(3, fg.size()); + // EXPECT(assert_equal(expectedFG, fg)); + + // // Check that the ratio of probPrime to evaluate is the same for all modes. + // std::vector ratio(2); + // ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero); + // ratio[0] = std::exp(-fg.error(one)) / bayesNet.evaluate(one); + // EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); + + // TODO: better test: check if discretePosteriors agree ! } /* ****************************************************************************/ @@ -174,21 +221,6 @@ TEST(HybridBayesNet, evaluateHybrid) { bayesNet.evaluate(values), 1e-9); } -/* ****************************************************************************/ -// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia). -TEST(HybridBayesNet, Error) { - using namespace different_sigmas; - - AlgebraicDecisionTree actual = bayesNet.errorTree(values.continuous()); - - // Regression. - // Manually added all the error values from the 3 conditional types. - AlgebraicDecisionTree expected( - {Asia}, std::vector{2.33005033585, 5.38619084965}); - - EXPECT(assert_equal(expected, actual)); -} - /* ****************************************************************************/ // Test choosing an assignment of conditionals TEST(HybridBayesNet, Choose) { diff --git a/gtsam/hybrid/tests/testHybridGaussianFactor.cpp b/gtsam/hybrid/tests/testHybridGaussianFactor.cpp index c2ffe24c89..5ff8c14781 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactor.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactor.cpp @@ -357,16 +357,9 @@ TEST(HybridGaussianFactor, DifferentCovariancesFG) { cv.insert(X(0), Vector1(0.0)); cv.insert(X(1), Vector1(0.0)); - // Check that the error values at the MLE point μ. - AlgebraicDecisionTree errorTree = hbn->errorTree(cv); - DiscreteValues dv0{{M(1), 0}}; DiscreteValues dv1{{M(1), 1}}; - // regression - EXPECT_DOUBLES_EQUAL(9.90348755254, errorTree(dv0), 1e-9); - EXPECT_DOUBLES_EQUAL(0.69314718056, errorTree(dv1), 1e-9); - DiscreteConditional expected_m1(m1, "0.5/0.5"); DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete()); diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 4735c16573..3a26f44869 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -994,16 +994,9 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) { cv.insert(X(0), Vector1(0.0)); cv.insert(X(1), Vector1(0.0)); - // Check that the error values at the MLE point μ. - AlgebraicDecisionTree errorTree = hbn->errorTree(cv); - DiscreteValues dv0{{M(1), 0}}; DiscreteValues dv1{{M(1), 1}}; - // regression - EXPECT_DOUBLES_EQUAL(9.90348755254, errorTree(dv0), 1e-9); - EXPECT_DOUBLES_EQUAL(0.69314718056, errorTree(dv1), 1e-9); - DiscreteConditional expected_m1(m1, "0.5/0.5"); DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete()); From 78b47770c0c0b783de6081ea5047fe275ea851df Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 01:11:18 -0700 Subject: [PATCH 11/22] All tests for tiny work --- gtsam/hybrid/HybridGaussianConditional.h | 1 + gtsam/hybrid/tests/testHybridBayesNet.cpp | 79 ++++++++++++----------- 2 files changed, 42 insertions(+), 38 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index f3bf4d839e..27e31d7670 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -14,6 +14,7 @@ * @brief A hybrid conditional in the Conditional Linear Gaussian scheme * @author Fan Jiang * @author Varun Agrawal + * @author Frank Dellaert * @date Mar 12, 2022 */ diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 9974827e80..ea9ca8285e 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -73,10 +73,6 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) { GaussianBayesNet empty; EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9)); - // logProbability - EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); - EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); - // evaluate EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9); EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9); @@ -127,16 +123,28 @@ TEST(HybridBayesNet, Tiny) { const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}}; HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}}; - // choose + // Check Invariants for components HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid(); - GaussianBayesNet chosen; - chosen.push_back(hgc->choose(zero.discrete())); - chosen.push_back(bayesNet.at(1)->asGaussian()); - EXPECT(assert_equal(chosen, bayesNet.choose(zero.discrete()), 1e-9)); + GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()), + gc1 = hgc->choose(one.discrete()); + GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian(); + GaussianConditional::CheckInvariants(*gc0, vv); + GaussianConditional::CheckInvariants(*gc1, vv); + GaussianConditional::CheckInvariants(*px, vv); + HybridGaussianConditional::CheckInvariants(*hgc, zero); + HybridGaussianConditional::CheckInvariants(*hgc, one); + + // choose + GaussianBayesNet expectedChosen; + expectedChosen.push_back(gc0); + expectedChosen.push_back(px); + auto chosen0 = bayesNet.choose(zero.discrete()); + auto chosen1 = bayesNet.choose(one.discrete()); + EXPECT(assert_equal(expectedChosen, chosen0, 1e-9)); // logProbability - const double logP0 = chosen.logProbability(vv) + log(0.4); // 0.4 is prior - const double logP1 = bayesNet.choose(one.discrete()).logProbability(vv) + log(0.6); // 0.6 is prior + const double logP0 = chosen0.logProbability(vv) + log(0.4); // 0.4 is prior + const double logP1 = chosen1.logProbability(vv) + log(0.6); // 0.6 is prior EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9); EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9); @@ -145,7 +153,7 @@ TEST(HybridBayesNet, Tiny) { // optimize EXPECT(assert_equal(one, bayesNet.optimize())); - EXPECT(assert_equal(chosen.optimize(), bayesNet.optimize(zero.discrete()))); + EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete()))); // sample std::mt19937_64 rng(42); @@ -156,38 +164,33 @@ TEST(HybridBayesNet, Tiny) { EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); EXPECT(!pruned.equals(bayesNet)); - // // error - // EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); - // EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); + // error + const double error0 = chosen0.error(vv) + gc0->negLogConstant() - + px->negLogConstant() - log(0.4); + const double error1 = chosen1.error(vv) + gc1->negLogConstant() - + px->negLogConstant() - log(0.6); + EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9); + EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9); // logDiscretePosteriorPrime, TODO: useless as -errorTree? AlgebraicDecisionTree expected(M(0), logP0, logP1); EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime(vv))); - // // logProbability - // EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); - // EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); - - // // discretePosterior - // AlgebraicDecisionTree expectedPosterior({Asia}, - // std::vector{0.4, - // 0.6}); - // EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({}))); - - // // toFactorGraph - // HybridGaussianFactorGraph expectedFG{}; - - // auto fg = bayesNet.toFactorGraph(vv); - // EXPECT_LONGS_EQUAL(3, fg.size()); - // EXPECT(assert_equal(expectedFG, fg)); - - // // Check that the ratio of probPrime to evaluate is the same for all modes. - // std::vector ratio(2); - // ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero); - // ratio[0] = std::exp(-fg.error(one)) / bayesNet.evaluate(one); - // EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); + // discretePosterior + double q0 = std::exp(logP0), q1 = std::exp(logP1), sum = q0 + q1; + AlgebraicDecisionTree expectedPosterior(M(0), q0 / sum, q1 / sum); + EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior(vv))); - // TODO: better test: check if discretePosteriors agree ! + // toFactorGraph + auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}}); + EXPECT_LONGS_EQUAL(3, fg.size()); + + // Check that the ratio of probPrime to evaluate is the same for all modes. + std::vector ratio(2); + ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero); + ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one); + EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); } /* ****************************************************************************/ From d054a041ed446f4d28e250befdebb9e4e37cb6e1 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 14:42:33 -0700 Subject: [PATCH 12/22] choose docs --- gtsam/hybrid/HybridBayesNet.h | 2 +- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 2 +- gtsam/hybrid/HybridGaussianFactorGraph.h | 19 +++++++++++++++++-- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index fba6bb6aa8..94a0762def 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -128,7 +128,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * value assignment. Note this corresponds to the Gaussian posterior p(X|M=m) * of the continuous variables given the discrete assignment M=m. * - * @note Any pure discrete factors are ignored. + * @note Be careful, as any factors not Gaussian are ignored. * * @param assignment The discrete value assignment for the discrete keys. * @return Gaussian posterior as a GaussianBayesNet diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 0e5a34359e..0c4e9c4897 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -542,7 +542,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::discretePosterior( } /* ************************************************************************ */ -GaussianFactorGraph HybridGaussianFactorGraph::operator()( +GaussianFactorGraph HybridGaussianFactorGraph::choose( const DiscreteValues &assignment) const { GaussianFactorGraph gfg; for (auto &&f : *this) { diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 3ef6218bec..a5130ca086 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -230,8 +230,23 @@ class GTSAM_EXPORT HybridGaussianFactorGraph eliminate(const Ordering& keys) const; /// @} - /// Get the GaussianFactorGraph at a given discrete assignment. - GaussianFactorGraph operator()(const DiscreteValues& assignment) const; + /** + @brief Get the GaussianFactorGraph at a given discrete assignment. Note this + * corresponds to the Gaussian posterior p(X|M=m, Z=z) of the continuous + * variables X given the discrete assignment M=m and whatever measurements z + * where assumed in the creation of the factor Graph. + * + * @note Be careful, as any factors not Gaussian are ignored. + * + * @param assignment The discrete value assignment for the discrete keys. + * @return Gaussian factors as a GaussianFactorGraph + */ + GaussianFactorGraph choose(const DiscreteValues& assignment) const; + + /// Syntactic sugar for choose + GaussianFactorGraph operator()(const DiscreteValues& assignment) const { + return choose(assignment); + } }; // traits From 5fb3b377718253fb6151e40dda7def95bfe2bb37 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 15:19:05 -0700 Subject: [PATCH 13/22] Additional arithmetic --- gtsam/discrete/AlgebraicDecisionTree.h | 11 ++++ .../tests/testAlgebraicDecisionTree.cpp | 54 +++++++++++-------- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 17385a975d..6001b1983d 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -70,6 +70,7 @@ namespace gtsam { return a / b; } static inline double id(const double& x) { return x; } + static inline double negate(const double& x) { return -x; } }; AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {} @@ -186,6 +187,16 @@ namespace gtsam { return this->apply(g, &Ring::add); } + /** negation */ + AlgebraicDecisionTree operator-() const { + return this->apply(&Ring::negate); + } + + /** subtract */ + AlgebraicDecisionTree operator-(const AlgebraicDecisionTree& g) const { + return *this + (-g); + } + /** product */ AlgebraicDecisionTree operator*(const AlgebraicDecisionTree& g) const { return this->apply(g, &Ring::mul); diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index ffb1f0b5ac..bf728695c9 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -10,10 +10,10 @@ * -------------------------------------------------------------------------- */ /* - * @file testDecisionTree.cpp - * @brief Develop DecisionTree - * @author Frank Dellaert - * @date Mar 6, 2011 + * @file testAlgebraicDecisionTree.cpp + * @brief Unit tests for Algebraic decision tree + * @author Frank Dellaert + * @date Mar 6, 2011 */ #include @@ -46,23 +46,35 @@ void dot(const T& f, const string& filename) { #endif } -/** I can't get this to work ! - class Mul: std::function { - inline double operator()(const double& a, const double& b) { - return a * b; - } - }; - - // If second argument of binary op is Leaf - template - typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const { - Ptr h(new Choice(label(), cardinality())); - for(const NodePtr& branch: branches_) - h->push_back(branch->apply_f_op_g(cache, gL, op)); - return Unique(cache, h); - } - */ +/* ************************************************************************** */ +// Test arithmetic: +TEST(ADT, arithmetic) { + DiscreteKey A(0, 2), B(1, 2); + ADT zero{0}, one{1}; + ADT a(A, 1, 2); + ADT b(B, 3, 4); + + // Addition + CHECK(assert_equal(a, zero + a)); + + // Negate and subtraction + CHECK(assert_equal(-a, zero - a)); + CHECK(assert_equal({zero}, a - a)); + CHECK(assert_equal(a + b, b + a)); + CHECK(assert_equal({A, 3, 4}, a + 2)); + CHECK(assert_equal({B, 1, 2}, b - 2)); + + // Multiplication + CHECK(assert_equal(zero, zero * a)); + CHECK(assert_equal(zero, a * zero)); + CHECK(assert_equal(a, one * a)); + CHECK(assert_equal(a, a * one)); + CHECK(assert_equal(a * b, b * a)); + + // division + // CHECK(assert_equal(a, (a * b) / b)); // not true because no pruning + CHECK(assert_equal(b, (a * b) / a)); +} /* ************************************************************************** */ // instrumented operators From 53599969ad20b73bfa13771a04d64a59625c88ca Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 16:20:50 -0700 Subject: [PATCH 14/22] FIX BUG in errorTree --- gtsam/hybrid/HybridConditional.cpp | 68 ++++++++++++------------------ 1 file changed, 28 insertions(+), 40 deletions(-) diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index cac2adcf87..175aec30c7 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -64,7 +64,6 @@ void HybridConditional::print(const std::string &s, if (inner_) { inner_->print("", formatter); - } else { if (isContinuous()) std::cout << "Continuous "; if (isDiscrete()) std::cout << "Discrete "; @@ -100,79 +99,68 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const { if (auto gm = asHybrid()) { auto other = e->asHybrid(); return other != nullptr && gm->equals(*other, tol); - } - if (auto gc = asGaussian()) { + } else if (auto gc = asGaussian()) { auto other = e->asGaussian(); return other != nullptr && gc->equals(*other, tol); - } - if (auto dc = asDiscrete()) { + } else if (auto dc = asDiscrete()) { auto other = e->asDiscrete(); return other != nullptr && dc->equals(*other, tol); - } - - return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false) - : !(e->inner_); + } else + return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false) + : !(e->inner_); } /* ************************************************************************ */ double HybridConditional::error(const HybridValues &values) const { if (auto gc = asGaussian()) { return gc->error(values.continuous()); - } - if (auto gm = asHybrid()) { + } else if (auto gm = asHybrid()) { return gm->error(values); - } - if (auto dc = asDiscrete()) { + } else if (auto dc = asDiscrete()) { return dc->error(values.discrete()); - } - throw std::runtime_error( - "HybridConditional::error: conditional type not handled"); + } else + throw std::runtime_error( + "HybridConditional::error: conditional type not handled"); } /* ************************************************************************ */ AlgebraicDecisionTree HybridConditional::errorTree( const VectorValues &values) const { if (auto gc = asGaussian()) { - return AlgebraicDecisionTree(gc->error(values)); - } - if (auto gm = asHybrid()) { + return {gc->error(values)}; // NOTE: a "constant" tree + } else if (auto gm = asHybrid()) { return gm->errorTree(values); - } - if (auto dc = asDiscrete()) { - return AlgebraicDecisionTree(0.0); - } - throw std::runtime_error( - "HybridConditional::error: conditional type not handled"); + } else if (auto dc = asDiscrete()) { + return dc->errorTree(); + } else + throw std::runtime_error( + "HybridConditional::error: conditional type not handled"); } /* ************************************************************************ */ double HybridConditional::logProbability(const HybridValues &values) const { if (auto gc = asGaussian()) { return gc->logProbability(values.continuous()); - } - if (auto gm = asHybrid()) { + } else if (auto gm = asHybrid()) { return gm->logProbability(values); - } - if (auto dc = asDiscrete()) { + } else if (auto dc = asDiscrete()) { return dc->logProbability(values.discrete()); - } - throw std::runtime_error( - "HybridConditional::logProbability: conditional type not handled"); + } else + throw std::runtime_error( + "HybridConditional::logProbability: conditional type not handled"); } /* ************************************************************************ */ double HybridConditional::negLogConstant() const { if (auto gc = asGaussian()) { return gc->negLogConstant(); - } - if (auto gm = asHybrid()) { - return gm->negLogConstant(); // 0.0! - } - if (auto dc = asDiscrete()) { + } else if (auto gm = asHybrid()) { + return gm->negLogConstant(); + } else if (auto dc = asDiscrete()) { return dc->negLogConstant(); // 0.0! - } - throw std::runtime_error( - "HybridConditional::negLogConstant: conditional type not handled"); + } else + throw std::runtime_error( + "HybridConditional::negLogConstant: conditional type not handled"); } /* ************************************************************************ */ From 3b50ba9895b132db2e02ff41220ef29241dcc21f Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 16:21:24 -0700 Subject: [PATCH 15/22] FIX BUG: don't skip discrete factors! --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 14 +++++++------- .../tests/testHybridGaussianFactorGraph.cpp | 16 ++++++---------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 0c4e9c4897..7dfa56e77d 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -508,16 +508,16 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( AlgebraicDecisionTree result(0.0); // Iterate over each factor. for (auto &factor : factors_) { - if (auto f = std::dynamic_pointer_cast(factor)) { - // Check for HybridFactor, and call errorTree - result = result + f->errorTree(continuousValues); - } else if (auto f = std::dynamic_pointer_cast(factor)) { - // Skip discrete factors - continue; + if (auto hf = std::dynamic_pointer_cast(factor)) { + // Add errorTree for hybrid factors, includes HybridGaussianConditionals! + result = result + hf->errorTree(continuousValues); + } else if (auto df = std::dynamic_pointer_cast(factor)) { + // If discrete, just add its errorTree as well + result = result + df->errorTree(); } else { // Everything else is a continuous only factor HybridValues hv(continuousValues, DiscreteValues()); - result = result + AlgebraicDecisionTree(factor->error(hv)); + result = result + factor->error(hv); // NOTE: yes, you can add constants } } return result; diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 0c5f52e611..f30085f020 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -614,21 +614,20 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { const HybridValues delta = hybridBayesNet->optimize(); // regression test for errorTree - std::vector leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568}; + std::vector leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947}; AlgebraicDecisionTree expectedErrors(s.modes, leaves); const auto error_tree = graph.errorTree(delta.continuous()); EXPECT(assert_equal(expectedErrors, error_tree, 1e-7)); // regression test for discretePosterior const AlgebraicDecisionTree expectedPosterior( - s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852}); + s.modes, std::vector{0.095516068, 0.31800092, 0.27798511, 0.3084979}); auto posterior = graph.discretePosterior(delta.continuous()); EXPECT(assert_equal(expectedPosterior, posterior, 1e-7)); } /* ****************************************************************************/ -// Test hybrid gaussian factor graph errorTree during -// incremental operation +// Test hybrid gaussian factor graph errorTree during incremental operation TEST(HybridGaussianFactorGraph, IncrementalErrorTree) { Switching s(4); @@ -648,8 +647,7 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) { auto error_tree = graph.errorTree(delta.continuous()); std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; - std::vector leaves = {0.99985581, 0.4902432, 0.51936941, - 0.0097568009}; + std::vector leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947}; AlgebraicDecisionTree expected_error(discrete_keys, leaves); // regression @@ -666,12 +664,10 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) { delta = hybridBayesNet->optimize(); auto error_tree2 = graph.errorTree(delta.continuous()); - discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}}; + // regression leaves = {0.50985198, 0.0097577296, 0.50009425, 0, 0.52922138, 0.029127133, 0.50985105, 0.0097567964}; - AlgebraicDecisionTree expected_error2(discrete_keys, leaves); - - // regression + AlgebraicDecisionTree expected_error2(s.modes, leaves); EXPECT(assert_equal(expected_error, error_tree, 1e-7)); } From d77efb0f51824db728b18b0030e53c3ad7edabf7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 16:22:02 -0700 Subject: [PATCH 16/22] Drastically simplify errorTree --- gtsam/hybrid/HybridBayesNet.cpp | 37 +++------------------ gtsam/hybrid/HybridBayesNet.h | 6 ++-- gtsam/hybrid/tests/testHybridBayesNet.cpp | 39 ++++++++++++++++------- 3 files changed, 35 insertions(+), 47 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 9df0012c7e..b4441f15a7 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -195,41 +195,13 @@ HybridValues HybridBayesNet::sample() const { } /* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::logDiscretePosteriorPrime( +AlgebraicDecisionTree HybridBayesNet::errorTree( const VectorValues &continuousValues) const { AlgebraicDecisionTree result(0.0); - // Get logProbability function for a conditional or arbitrarily small - // logProbability if the conditional was pruned out. - auto probFunc = [continuousValues]( - const GaussianConditional::shared_ptr &conditional) { - return conditional ? conditional->logProbability(continuousValues) : -1e20; - }; - // Iterate over each conditional. for (auto &&conditional : *this) { - if (auto gm = conditional->asHybrid()) { - // If conditional is hybrid, select based on assignment and compute - // logProbability. - result = result + DecisionTree(gm->conditionals(), probFunc); - } else if (auto gc = conditional->asGaussian()) { - // If continuous, get the logProbability and add it to the result - double logProbability = gc->logProbability(continuousValues); - // Add the computed logProbability to every leaf of the tree. - result = result.apply([logProbability](double leaf_value) { - return leaf_value + logProbability; - }); - } else if (auto dc = conditional->asDiscrete()) { - // If discrete, add the discrete logProbability in the right branch - if (result.nrLeaves() == 1) { - result = dc->errorTree().apply([](double error) { return -error; }); - } else { - result = result.apply([dc](const Assignment &assignment, - double leaf_value) { - return leaf_value + dc->logProbability(DiscreteValues(assignment)); - }); - } - } + result = result + conditional->errorTree(continuousValues); } return result; @@ -238,10 +210,9 @@ AlgebraicDecisionTree HybridBayesNet::logDiscretePosteriorPrime( /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::discretePosterior( const VectorValues &continuousValues) const { - AlgebraicDecisionTree log_p = - this->logDiscretePosteriorPrime(continuousValues); + AlgebraicDecisionTree errors = this->errorTree(continuousValues); AlgebraicDecisionTree p = - log_p.apply([](double log) { return exp(log); }); + errors.apply([](double error) { return exp(-error); }); return p / p.sum(); } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 94a0762def..a997174ecc 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,8 +217,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using Base::error; /** - * @brief Compute the log posterior log P'(M|x) of all assignments up to a - * constant, returning the result as an algebraic decision tree. + * @brief Compute the negative log posterior log P'(M|x) of all assignments up + * to a constant, returning the result as an algebraic decision tree. * * @note The joint P(X,M) is p(X|M) P(M) * Then the posterior on M given X=x is is P(M|x) = p(x|M) P(M) / p(x). @@ -229,7 +229,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @param continuousValues Continuous values x at which to compute log P'(M|x) * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree logDiscretePosteriorPrime( + AlgebraicDecisionTree errorTree( const VectorValues &continuousValues) const; using BayesNet::logProbability; // expose HybridValues version diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index ea9ca8285e..521bca4a77 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -95,18 +95,16 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) { EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); - // logDiscretePosteriorPrime, TODO: useless as -errorTree? - AlgebraicDecisionTree expected({Asia}, - std::vector{log(0.4), log(0.6)}); - EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime({}))); + // errorTree + AlgebraicDecisionTree expected(asiaKey, -log(0.4), -log(0.6)); + EXPECT(assert_equal(expected, bayesNet.errorTree({}))); // logProbability EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); // discretePosterior - AlgebraicDecisionTree expectedPosterior({Asia}, - std::vector{0.4, 0.6}); + AlgebraicDecisionTree expectedPosterior(asiaKey, 0.4, 0.6); EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({}))); // toFactorGraph @@ -169,15 +167,21 @@ TEST(HybridBayesNet, Tiny) { px->negLogConstant() - log(0.4); const double error1 = chosen1.error(vv) + gc1->negLogConstant() - px->negLogConstant() - log(0.6); + // print errors: + std::cout << "error0 = " << error0 << std::endl; + std::cout << "error1 = " << error1 << std::endl; EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9); EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9); - // logDiscretePosteriorPrime, TODO: useless as -errorTree? - AlgebraicDecisionTree expected(M(0), logP0, logP1); - EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime(vv))); + // errorTree + AlgebraicDecisionTree expected(M(0), error0, error1); + EXPECT(assert_equal(expected, bayesNet.errorTree(vv))); // discretePosterior + // We have: P(z|x,mode)P(x)P(mode). When we condition on z and x, we get + // P(mode|z,x) \propto P(z|x,mode)P(x)P(mode) + // Normalizing this yields posterior P(mode|z,x) = {0.8, 0.2} double q0 = std::exp(logP0), q1 = std::exp(logP1), sum = q0 + q1; AlgebraicDecisionTree expectedPosterior(M(0), q0 / sum, q1 / sum); EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior(vv))); @@ -191,6 +195,19 @@ TEST(HybridBayesNet, Tiny) { ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero); ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one); EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); + + // TODO(Frank): Better test: check if discretePosteriors agree ! + // Since ϕ(M, x) \propto P(M,x|z) + // q0 = std::exp(-fg.error(zero)); + // q1 = std::exp(-fg.error(one)); + // sum = q0 + q1; + // AlgebraicDecisionTree fgPosterior(M(0), q0 / sum, q1 / sum); + VectorValues xv{{X(0), Vector1(5.0)}}; + fg.printErrors(zero); + fg.printErrors(one); + GTSAM_PRINT(fg.errorTree(xv)); + auto fgPosterior = fg.discretePosterior(xv); + EXPECT(assert_equal(expectedPosterior, fgPosterior)); } /* ****************************************************************************/ @@ -556,8 +573,8 @@ TEST(HybridBayesNet, ErrorTreeWithConditional) { AlgebraicDecisionTree errorTree = gfg.errorTree(vv); // regression - AlgebraicDecisionTree expected(m1, 59.335390372, 5050.125); - EXPECT(assert_equal(expected, errorTree, 1e-9)); + AlgebraicDecisionTree expected(m1, 60.028538, 5050.8181); + EXPECT(assert_equal(expected, errorTree, 1e-4)); } /* ************************************************************************* */ From 20e5664928297e865dd8610a47a55fe00d207de5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 16:23:39 -0700 Subject: [PATCH 17/22] Fix switching docs --- gtsam/hybrid/tests/Switching.h | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index 547facce9a..1b176ad654 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -46,29 +46,29 @@ using symbol_shorthand::X; * @brief Create a switching system chain. A switching system is a continuous * system which depends on a discrete mode at each time step of the chain. * - * @param n The number of chain elements. + * @param K The number of chain elements. * @param x The functional to help specify the continuous key. * @param m The functional to help specify the discrete key. * @return HybridGaussianFactorGraph::shared_ptr */ inline HybridGaussianFactorGraph::shared_ptr makeSwitchingChain( - size_t n, std::function x = X, std::function m = M) { + size_t K, std::function x = X, std::function m = M) { HybridGaussianFactorGraph hfg; hfg.add(JacobianFactor(x(1), I_3x3, Z_3x1)); // x(1) to x(n+1) - for (size_t t = 1; t < n; t++) { - DiscreteKeys dKeys{{m(t), 2}}; + for (size_t k = 1; k < K; k++) { + DiscreteKeys dKeys{{m(k), 2}}; std::vector components; components.emplace_back( - new JacobianFactor(x(t), I_3x3, x(t + 1), I_3x3, Z_3x1)); + new JacobianFactor(x(k), I_3x3, x(k + 1), I_3x3, Z_3x1)); components.emplace_back( - new JacobianFactor(x(t), I_3x3, x(t + 1), I_3x3, Vector3::Ones())); - hfg.add(HybridGaussianFactor({m(t), 2}, components)); + new JacobianFactor(x(k), I_3x3, x(k + 1), I_3x3, Vector3::Ones())); + hfg.add(HybridGaussianFactor({m(k), 2}, components)); - if (t > 1) { - hfg.add(DecisionTreeFactor({{m(t - 1), 2}, {m(t), 2}}, "0 1 1 3")); + if (k > 1) { + hfg.add(DecisionTreeFactor({{m(k - 1), 2}, {m(k), 2}}, "0 1 1 3")); } } @@ -118,7 +118,7 @@ inline std::pair> makeBinaryOrdering( using MotionModel = BetweenFactor; // Test fixture with switching network. -/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(k),M(k+1)) +/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(K-3),M(K-2)) struct Switching { size_t K; DiscreteKeys modes; @@ -195,7 +195,7 @@ struct Switching { } /** - * @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-1). + * @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2). * E.g. if K=4, we want M0, M1 and M2. * * @param fg The factor graph to which the mode chain is added. From a709a2d750f3a1100650169e35865195293cb52f Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 16:52:41 -0700 Subject: [PATCH 18/22] Remove printing, add one more test --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 521bca4a77..f24c6fcb6d 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -168,8 +168,6 @@ TEST(HybridBayesNet, Tiny) { const double error1 = chosen1.error(vv) + gc1->negLogConstant() - px->negLogConstant() - log(0.6); // print errors: - std::cout << "error0 = " << error0 << std::endl; - std::cout << "error1 = " << error1 << std::endl; EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9); EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9); @@ -196,16 +194,13 @@ TEST(HybridBayesNet, Tiny) { ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one); EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); - // TODO(Frank): Better test: check if discretePosteriors agree ! - // Since ϕ(M, x) \propto P(M,x|z) - // q0 = std::exp(-fg.error(zero)); - // q1 = std::exp(-fg.error(one)); - // sum = q0 + q1; - // AlgebraicDecisionTree fgPosterior(M(0), q0 / sum, q1 / sum); + // Better and more general test: + // Since ϕ(M, x) \propto P(M,x|z) the discretePosteriors should agree + q0 = std::exp(-fg.error(zero)); + q1 = std::exp(-fg.error(one)); + sum = q0 + q1; + EXPECT(assert_equal(expectedPosterior, {M(0), q0 / sum, q1 / sum})); VectorValues xv{{X(0), Vector1(5.0)}}; - fg.printErrors(zero); - fg.printErrors(one); - GTSAM_PRINT(fg.errorTree(xv)); auto fgPosterior = fg.discretePosterior(xv); EXPECT(assert_equal(expectedPosterior, fgPosterior)); } From 5b713032c1afc4772501a8011d5cfffba824700e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 1 Oct 2024 11:31:16 -0700 Subject: [PATCH 19/22] Add test for prune --- .../tests/testHybridGaussianConditional.cpp | 58 ++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 24eb409a1c..cd9c182cd5 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -25,8 +25,12 @@ #include #include +#include #include +#include "gtsam/discrete/DecisionTree.h" +#include "gtsam/discrete/DiscreteKey.h" + // Include for test suite #include @@ -250,8 +254,60 @@ TEST(HybridGaussianConditional, Likelihood2) { } /* ************************************************************************* */ +// Test pruning a HybridGaussianConditional with two discrete keys, based on a +// DecisionTreeFactor with 3 keys: +TEST(HybridGaussianConditional, Prune) { + // Create a two key conditional: + DiscreteKeys modes{{M(1), 2}, {M(2), 2}}; + std::vector gcs; + for (size_t i = 0; i < 4; i++) { + gcs.push_back( + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1)); + } + auto empty = std::make_shared(); + HybridGaussianConditional::Conditionals conditionals(modes, gcs); + HybridGaussianConditional hgc(modes, conditionals); + + DiscreteKeys keys = modes; + keys.push_back({M(3), 2}); + { + for (size_t i = 0; i < 8; i++) { + std::vector potentials{0, 0, 0, 0, 0, 0, 0, 0}; + potentials[i] = 1; + const DecisionTreeFactor decisionTreeFactor(keys, potentials); + // Prune the HybridGaussianConditional + const auto pruned = hgc.prune(decisionTreeFactor); + // Check that the pruned HybridGaussianConditional has 1 conditional + EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); + } + } + { + const std::vector potentials{0, 0, 0.5, 0, // + 0, 0, 0.5, 0}; + const DecisionTreeFactor decisionTreeFactor(keys, potentials); + + const auto pruned = hgc.prune(decisionTreeFactor); + + // Check that the pruned HybridGaussianConditional has 2 conditionals + EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); + } + { + const std::vector potentials{0.2, 0, 0.3, 0, // + 0, 0, 0.5, 0}; + const DecisionTreeFactor decisionTreeFactor(keys, potentials); + + const auto pruned = hgc.prune(decisionTreeFactor); + + // Check that the pruned HybridGaussianConditional has 3 conditionals + EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); + } +} + +/* ************************************************************************* + */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ +/* ************************************************************************* + */ From b70c63ee4cb2be9ec82601ff143af287259d4371 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 1 Oct 2024 13:32:23 -0700 Subject: [PATCH 20/22] Better prune --- gtsam/discrete/DecisionTreeFactor.h | 4 +- gtsam/hybrid/HybridGaussianConditional.cpp | 51 +++++++--------------- 2 files changed, 17 insertions(+), 38 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 784b11e518..7ed1160166 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -131,7 +131,7 @@ namespace gtsam { /// Calculate probability for given values `x`, /// is just look up in AlgebraicDecisionTree. - double evaluate(const DiscreteValues& values) const { + double evaluate(const Assignment& values) const { return ADT::operator()(values); } @@ -155,7 +155,7 @@ namespace gtsam { return apply(f, safe_div); } - /// Convert into a decisiontree + /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 1c3a69ce7d..2c0fb28a40 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -291,45 +291,24 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { /* *******************************************************************************/ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( const DecisionTreeFactor &discreteProbs) const { - auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); - auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys()); - - // Functional which loops over all assignments and create a set of - // GaussianConditionals + // Find keys in discreteProbs.keys() but not in this->keys(): + std::set mine(this->keys().begin(), this->keys().end()); + std::set theirs(discreteProbs.keys().begin(), + discreteProbs.keys().end()); + std::vector diff; + std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), + std::back_inserter(diff)); + + // Find maximum probability value for every combination of our keys. + Ordering keys(diff); + auto max = discreteProbs.max(keys); + + // Check the max value for every combination of our keys. + // If the max value is 0.0, we can prune the corresponding conditional. auto pruner = [&](const Assignment &choices, const GaussianConditional::shared_ptr &conditional) -> GaussianConditional::shared_ptr { - // typecast so we can use this to get probability value - const DiscreteValues values(choices); - - // Case where the hybrid gaussian conditional has the same - // discrete keys as the decision tree. - if (hybridGaussianCondKeySet == discreteProbsKeySet) { - return (discreteProbs(values) == 0.0) ? nullptr : conditional; - } else { - // TODO(Frank): It might be faster to "choose" based on values - // and then check whether the resulting tree has non-nullptrs. - std::vector set_diff; - std::set_difference( - discreteProbsKeySet.begin(), discreteProbsKeySet.end(), - hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(), - std::back_inserter(set_diff)); - - const std::vector assignments = - DiscreteValues::CartesianProduct(set_diff); - for (const DiscreteValues &assignment : assignments) { - DiscreteValues augmented_values(values); - augmented_values.insert(assignment); - - // If any one of the sub-branches are non-zero, - // we need this conditional. - if (discreteProbs(augmented_values) > 0.0) { - return conditional; - } - } - // If we are here, it means that all the sub-branches are 0, so we prune. - return nullptr; - } + return (max->evaluate(choices) == 0.0) ? nullptr : conditional; }; auto pruned_conditionals = conditionals_.apply(pruner); From cf9d38ef4fa81936e2e9dce175d40c8e3b24e8e0 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 1 Oct 2024 13:32:41 -0700 Subject: [PATCH 21/22] better, functional prune --- gtsam/hybrid/HybridBayesNet.cpp | 90 ++++++++++------------- gtsam/hybrid/HybridBayesNet.h | 40 +++++----- gtsam/hybrid/tests/testHybridBayesNet.cpp | 45 ++++++------ 3 files changed, 80 insertions(+), 95 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index b4441f15a7..36503d2eaf 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -17,10 +17,13 @@ */ #include +#include #include #include #include +#include + // In Wrappers we have no access to this so have a default ready static std::mt19937_64 kRandomNumberGenerator(42); @@ -38,48 +41,26 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { } /* ************************************************************************* */ -DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( - size_t maxNrLeaves) { - // Get the joint distribution of only the discrete keys - // The joint discrete probability. - DiscreteConditional discreteProbs; - - std::vector discrete_factor_idxs; - // Record frontal keys so we can maintain ordering - Ordering discrete_frontals; - - for (size_t i = 0; i < this->size(); i++) { - auto conditional = this->at(i); - if (conditional->isDiscrete()) { - discreteProbs = discreteProbs * (*conditional->asDiscrete()); - - Ordering conditional_keys(conditional->frontals()); - discrete_frontals += conditional_keys; - discrete_factor_idxs.push_back(i); - } - } - - const DecisionTreeFactor prunedDiscreteProbs = - discreteProbs.prune(maxNrLeaves); - - // Eliminate joint probability back into conditionals - DiscreteFactorGraph dfg{prunedDiscreteProbs}; - DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals); +// The implementation is: build the entire joint into one factor and then prune. +// TODO(Frank): This can be quite expensive *unless* the factors have already +// been pruned before. Another, possibly faster approach is branch and bound +// search to find the K-best leaves and then create a single pruned conditional. +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { + // Collect all the discrete conditionals. Could be small if already pruned. + const DiscreteBayesNet marginal = discreteMarginal(); - // Assign pruned discrete conditionals back at the correct indices. - for (size_t i = 0; i < discrete_factor_idxs.size(); i++) { - size_t idx = discrete_factor_idxs.at(i); - this->at(idx) = std::make_shared(dbn->at(i)); + // Multiply into one big conditional. NOTE: possibly quite expensive. + DiscreteConditional joint; + for (auto &&conditional : marginal) { + joint = joint * (*conditional); } - return prunedDiscreteProbs; -} + // Prune the joint. NOTE: again, possibly quite expensive. + const DecisionTreeFactor pruned = joint.prune(maxNrLeaves); -/* ************************************************************************* */ -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { - HybridBayesNet copy(*this); - DecisionTreeFactor prunedDiscreteProbs = - copy.pruneDiscreteConditionals(maxNrLeaves); + // Create a the result starting with the pruned joint. + HybridBayesNet result; + result.emplace_shared(pruned.size(), pruned); /* To prune, we visitWith every leaf in the HybridGaussianConditional. * For each leaf, using the assignment we can check the discrete decision tree @@ -88,25 +69,34 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { * We can later check the HybridGaussianConditional for just nullptrs. */ - HybridBayesNet prunedBayesNetFragment; - - // Go through all the conditionals in the - // Bayes Net and prune them as per prunedDiscreteProbs. - for (auto &&conditional : copy) { - if (auto gm = conditional->asHybrid()) { + // Go through all the Gaussian conditionals in the Bayes Net and prune them as + // per pruned Discrete joint. + for (auto &&conditional : *this) { + if (auto hgc = conditional->asHybrid()) { // Make a copy of the hybrid Gaussian conditional and prune it! - auto prunedHybridGaussianConditional = gm->prune(prunedDiscreteProbs); + auto prunedHybridGaussianConditional = hgc->prune(pruned); // Type-erase and add to the pruned Bayes Net fragment. - prunedBayesNetFragment.push_back(prunedHybridGaussianConditional); - - } else { + result.push_back(prunedHybridGaussianConditional); + } else if (auto gc = conditional->asGaussian()) { // Add the non-HybridGaussianConditional conditional - prunedBayesNetFragment.push_back(conditional); + result.push_back(gc); } + // We ignore DiscreteConditional as they are already pruned and added. } - return prunedBayesNetFragment; + return result; +} + +/* ************************************************************************* */ +DiscreteBayesNet HybridBayesNet::discreteMarginal() const { + DiscreteBayesNet result; + for (auto &&conditional : *this) { + if (auto dc = conditional->asDiscrete()) { + result.push_back(dc); + } + } + return result; } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index a997174ecc..bba301be2f 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include #include @@ -77,16 +78,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { } /** - * Add a conditional using a shared_ptr, using implicit conversion to - * a HybridConditional. - * - * This is useful when you create a conditional shared pointer as you need it - * somewhere else. - * + * Move a HybridConditional into a shared pointer and add. + * Example: - * auto shared_ptr_to_a_conditional = - * std::make_shared(...); - * hbn.push_back(shared_ptr_to_a_conditional); + * HybridGaussianConditional conditional(...); + * hbn.push_back(conditional); // loses the original conditional */ void push_back(HybridConditional &&conditional) { factors_.push_back( @@ -124,14 +120,21 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { } /** - * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete - * value assignment. Note this corresponds to the Gaussian posterior p(X|M=m) - * of the continuous variables given the discrete assignment M=m. + * @brief Get the discrete Bayes Net P(M). As the hybrid Bayes net defines + * P(X,M) = P(X|M) P(M), this method returns the marginal distribution on the + * discrete variables. * - * @note Be careful, as any factors not Gaussian are ignored. + * @return discrete marginal as a DiscreteBayesNet. + */ + DiscreteBayesNet discreteMarginal() const; + + /** + * @brief Get the Gaussian Bayes net P(X|M=m) corresponding to a specific + * assignment m for the discrete variables M. As the hybrid Bayes net defines + * P(X,M) = P(X|M) P(M), this method returns the **posterior** p(X|M=m). * * @param assignment The discrete value assignment for the discrete keys. - * @return Gaussian posterior as a GaussianBayesNet + * @return Gaussian posterior P(X|M=m) as a GaussianBayesNet. */ GaussianBayesNet choose(const DiscreteValues &assignment) const; @@ -222,7 +225,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * * @note The joint P(X,M) is p(X|M) P(M) * Then the posterior on M given X=x is is P(M|x) = p(x|M) P(M) / p(x). - * Ideally we want log P(M|x) = log p(x|M) + log P(M) - log P(x), but + * Ideally we want log P(M|x) = log p(x|M) + log P(M) - log p(x), but * unfortunately log p(x) is expensive, so we compute the log of the * unnormalized posterior log P'(M|x) = log p(x|M) + log P(M) * @@ -255,13 +258,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// @} private: - /** - * @brief Prune all the discrete conditionals. - * - * @param maxNrLeaves - */ - DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves); - #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index f24c6fcb6d..1d22b3d73e 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -153,13 +153,14 @@ TEST(HybridBayesNet, Tiny) { EXPECT(assert_equal(one, bayesNet.optimize())); EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete()))); - // sample - std::mt19937_64 rng(42); - EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete())); + // sample. Not deterministic !!! TODO(Frank): figure out why + // std::mt19937_64 rng(42); + // EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete())); // prune auto pruned = bayesNet.prune(1); - EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); + CHECK(pruned.at(1)->asHybrid()); + EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents()); EXPECT(!pruned.equals(bayesNet)); // error @@ -402,49 +403,47 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { s.linearizedFactorGraph.eliminateSequential(); EXPECT_LONGS_EQUAL(7, posterior->size()); - size_t maxNrLeaves = 3; - DiscreteConditional discreteConditionals; - for (auto&& conditional : *posterior) { - if (conditional->isDiscrete()) { - discreteConditionals = - discreteConditionals * (*conditional->asDiscrete()); - } + DiscreteConditional joint; + for (auto&& conditional : posterior->discreteMarginal()) { + joint = joint * (*conditional); } - const DecisionTreeFactor::shared_ptr prunedDecisionTree = - std::make_shared( - discreteConditionals.prune(maxNrLeaves)); + + size_t maxNrLeaves = 3; + auto prunedDecisionTree = joint.prune(maxNrLeaves); #ifdef GTSAM_DT_MERGING EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, - prunedDecisionTree->nrLeaves()); + prunedDecisionTree.nrLeaves()); #else - EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree->nrLeaves()); + EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree.nrLeaves()); #endif // regression + // NOTE(Frank): I had to include *three* non-zeroes here now. DecisionTreeFactor::ADT potentials( - s.modes, std::vector{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577}); - DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials); + s.modes, + std::vector{0, 0, 0, 0.28739288, 0, 0.43106901, 0, 0.2815381}); + DiscreteConditional expectedConditional(3, s.modes, potentials); // Prune! auto pruned = posterior->prune(maxNrLeaves); - // Functor to verify values against the expected_discrete_conditionals + // Functor to verify values against the expectedConditional auto checker = [&](const Assignment& assignment, double probability) -> double { // typecast so we can use this to get probability value DiscreteValues choices(assignment); - if (prunedDecisionTree->operator()(choices) == 0) { + if (prunedDecisionTree(choices) == 0) { EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9); } else { - EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability, - 1e-9); + EXPECT_DOUBLES_EQUAL(expectedConditional(choices), probability, 1e-6); } return 0.0; }; // Get the pruned discrete conditionals as an AlgebraicDecisionTree - auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete(); + CHECK(pruned.at(0)->asDiscrete()); + auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete(); auto discrete_conditional_tree = std::dynamic_pointer_cast( pruned_discrete_conditionals); From acccef8024a86dc58a7f093d20aa67f1b0b6558b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 1 Oct 2024 13:51:09 -0700 Subject: [PATCH 22/22] Fix smoother --- gtsam/hybrid/HybridSmoother.cpp | 10 +++------- gtsam/hybrid/HybridSmoother.h | 2 +- gtsam/hybrid/tests/testHybridEstimation.cpp | 7 ++++++- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index b898c0520c..ca3e272521 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -72,21 +72,17 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph, addConditionals(graph, hybridBayesNet_, ordering); // Eliminate. - HybridBayesNet::shared_ptr bayesNetFragment = - graph.eliminateSequential(ordering); + HybridBayesNet bayesNetFragment = *graph.eliminateSequential(ordering); /// Prune if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - HybridBayesNet prunedBayesNetFragment = - bayesNetFragment->prune(*maxNrLeaves); - // Set the bayes net fragment to the pruned version - bayesNetFragment = std::make_shared(prunedBayesNetFragment); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves); } // Add the partial bayes net to the posterior bayes net. - hybridBayesNet_.add(*bayesNetFragment); + hybridBayesNet_.add(bayesNetFragment); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 267746ab62..66edf86d6b 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -39,7 +39,7 @@ class GTSAM_EXPORT HybridSmoother { * discrete factor on all discrete keys, plus all discrete factors in the * original graph. * - * \note If maxComponents is given, we look at the discrete factor resulting + * \note If maxNrLeaves is given, we look at the discrete factor resulting * from this elimination, and prune it and the Gaussian components * corresponding to the pruned choices. * diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index 58decc695c..88d8be0bc5 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -109,6 +109,7 @@ TEST(HybridEstimation, IncrementalSmoother) { HybridGaussianFactorGraph linearized; + constexpr size_t maxNrLeaves = 3; for (size_t k = 1; k < K; k++) { // Motion Model graph.push_back(switching.nonlinearFactorGraph.at(k)); @@ -120,8 +121,12 @@ TEST(HybridEstimation, IncrementalSmoother) { linearized = *graph.linearize(initial); Ordering ordering = smoother.getOrdering(linearized); - smoother.update(linearized, 3, ordering); + smoother.update(linearized, maxNrLeaves, ordering); graph.resize(0); + + // Uncomment to print out pruned discrete marginal: + // smoother.hybridBayesNet().at(0)->asDiscrete()->dot("smoother_" + + // std::to_string(k)); } HybridValues delta = smoother.hybridBayesNet().optimize();