Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hybrid bug, double performance #1857

Merged
merged 22 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ namespace gtsam {
return a / b;
}
static inline double id(const double& x) { return x; }
static inline double negate(const double& x) { return -x; }
};

AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
Expand Down Expand Up @@ -186,6 +187,16 @@ namespace gtsam {
return this->apply(g, &Ring::add);
}

/** negation */
AlgebraicDecisionTree operator-() const {
return this->apply(&Ring::negate);
}

/** subtract */
AlgebraicDecisionTree operator-(const AlgebraicDecisionTree& g) const {
return *this + (-g);
}

/** product */
AlgebraicDecisionTree operator*(const AlgebraicDecisionTree& g) const {
return this->apply(g, &Ring::mul);
Expand Down
4 changes: 2 additions & 2 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ namespace gtsam {

/// Calculate probability for given values `x`,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const DiscreteValues& values) const {
double evaluate(const Assignment<Key>& values) const {
return ADT::operator()(values);
}

Expand All @@ -155,7 +155,7 @@ namespace gtsam {
return apply(f, safe_div);
}

/// Convert into a decisiontree
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }

/// Create new factor by summing all values with the same separator values
Expand Down
54 changes: 33 additions & 21 deletions gtsam/discrete/tests/testAlgebraicDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
* -------------------------------------------------------------------------- */

/*
* @file testDecisionTree.cpp
* @brief Develop DecisionTree
* @author Frank Dellaert
* @date Mar 6, 2011
* @file testAlgebraicDecisionTree.cpp
* @brief Unit tests for Algebraic decision tree
* @author Frank Dellaert
* @date Mar 6, 2011
*/

#include <gtsam/base/Testable.h>
Expand Down Expand Up @@ -46,23 +46,35 @@ void dot(const T& f, const string& filename) {
#endif
}

/** I can't get this to work !
class Mul: std::function<double(const double&, const double&)> {
inline double operator()(const double& a, const double& b) {
return a * b;
}
};

// If second argument of binary op is Leaf
template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op));
return Unique(cache, h);
}
*/
/* ************************************************************************** */
// Test arithmetic:
TEST(ADT, arithmetic) {
DiscreteKey A(0, 2), B(1, 2);
ADT zero{0}, one{1};
ADT a(A, 1, 2);
ADT b(B, 3, 4);

// Addition
CHECK(assert_equal(a, zero + a));

// Negate and subtraction
CHECK(assert_equal(-a, zero - a));
CHECK(assert_equal({zero}, a - a));
CHECK(assert_equal(a + b, b + a));
CHECK(assert_equal({A, 3, 4}, a + 2));
CHECK(assert_equal({B, 1, 2}, b - 2));

// Multiplication
CHECK(assert_equal(zero, zero * a));
CHECK(assert_equal(zero, a * zero));
CHECK(assert_equal(a, one * a));
CHECK(assert_equal(a, a * one));
CHECK(assert_equal(a * b, b * a));

// division
// CHECK(assert_equal(a, (a * b) / b)); // not true because no pruning
CHECK(assert_equal(b, (a * b) / a));
}

/* ************************************************************************** */
// instrumented operators
Expand Down
237 changes: 45 additions & 192 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
*/

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

#include <memory>

// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);

Expand All @@ -38,135 +41,26 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
}

/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
* @param conditional Conditional to prune. Used to get full assignment.
* @return std::function<double(const Assignment<Key> &, double)>
*/
std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &prunedDiscreteProbs,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the hybrid Gaussian conditional.
std::set<DiscreteKey> discreteProbsKeySet =
DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys());

auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// This corresponds to 0 probability
double pruned_prob = 0.0;

// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the hybrid Gaussian conditional has the same
// discrete keys as the decision tree.
if (conditionalKeySet == discreteProbsKeySet) {
if (prunedDiscreteProbs(values) == 0) {
return pruned_prob;
} else {
return probability;
}
} else {
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
// get a `values` which doesn't have the full set of keys.
std::set<Key> valuesKeys;
for (auto kvp : values) {
valuesKeys.insert(kvp.first);
}
std::set<Key> conditionalKeys;
for (auto kvp : conditionalKeySet) {
conditionalKeys.insert(kvp.first);
}
// If true, then values is missing some keys
if (conditionalKeys != valuesKeys) {
// Get the keys present in conditionalKeys but not in valuesKeys
std::vector<Key> missing_keys;
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
valuesKeys.begin(), valuesKeys.end(),
std::back_inserter(missing_keys));
// Insert missing keys with a default assignment.
for (auto missing_key : missing_keys) {
values[missing_key] = 0;
}
}
// The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal();

// Now we generate the full assignment by enumerating
// over all keys in the prunedDiscreteProbs.
// First we find the differing keys
std::vector<DiscreteKey> set_diff;
std::set_difference(discreteProbsKeySet.begin(),
discreteProbsKeySet.end(), conditionalKeySet.begin(),
conditionalKeySet.end(),
std::back_inserter(set_diff));

// Now enumerate over all assignments of the differing keys
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment);

// If any one of the sub-branches are non-zero,
// we need this probability.
if (prunedDiscreteProbs(augmented_values) > 0.0) {
return probability;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return pruned_prob;
}
};
return pruner;
}

/* ************************************************************************* */
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
// The joint discrete probability.
DiscreteConditional discreteProbs;

std::vector<size_t> discrete_factor_idxs;
// Record frontal keys so we can maintain ordering
Ordering discrete_frontals;

for (size_t i = 0; i < this->size(); i++) {
auto conditional = this->at(i);
if (conditional->isDiscrete()) {
discreteProbs = discreteProbs * (*conditional->asDiscrete());

Ordering conditional_keys(conditional->frontals());
discrete_frontals += conditional_keys;
discrete_factor_idxs.push_back(i);
}
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (auto &&conditional : marginal) {
joint = joint * (*conditional);
}

const DecisionTreeFactor prunedDiscreteProbs =
discreteProbs.prune(maxNrLeaves);

// Eliminate joint probability back into conditionals
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);

// Assign pruned discrete conditionals back at the correct indices.
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}

return prunedDiscreteProbs;
}
// Prune the joint. NOTE: again, possibly quite expensive.
const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
DecisionTreeFactor prunedDiscreteProbs =
this->pruneDiscreteConditionals(maxNrLeaves);
// Create a the result starting with the pruned joint.
HybridBayesNet result;
result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);

/* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree
Expand All @@ -175,28 +69,34 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
* We can later check the HybridGaussianConditional for just nullptrs.
*/

HybridBayesNet prunedBayesNetFragment;

// Go through all the conditionals in the
// Bayes Net and prune them as per prunedDiscreteProbs.
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
// per pruned Discrete joint.
for (auto &&conditional : *this) {
if (auto gm = conditional->asHybrid()) {
if (auto hgc = conditional->asHybrid()) {
// Make a copy of the hybrid Gaussian conditional and prune it!
auto prunedHybridGaussianConditional =
std::make_shared<HybridGaussianConditional>(*gm);
prunedHybridGaussianConditional->prune(
prunedDiscreteProbs); // imperative :-(
auto prunedHybridGaussianConditional = hgc->prune(pruned);

// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(prunedHybridGaussianConditional);

} else {
result.push_back(prunedHybridGaussianConditional);
} else if (auto gc = conditional->asGaussian()) {
// Add the non-HybridGaussianConditional conditional
prunedBayesNetFragment.push_back(conditional);
result.push_back(gc);
}
// We ignore DiscreteConditional as they are already pruned and added.
}

return prunedBayesNetFragment;
return result;
}

/* ************************************************************************* */
DiscreteBayesNet HybridBayesNet::discreteMarginal() const {
DiscreteBayesNet result;
for (auto &&conditional : *this) {
if (auto dc = conditional->asDiscrete()) {
result.push_back(dc);
}
}
return result;
}

/* ************************************************************************* */
Expand Down Expand Up @@ -291,66 +191,19 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asHybrid()) {
// If conditional is hybrid, compute error for all assignments.
result = result + gm->errorTree(continuousValues);

} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the error and add it to the result
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the result tree.
result = result.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete error in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->error(DiscreteValues(assignment));
});
}
}

return result;
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asHybrid()) {
// If conditional is hybrid, select based on assignment and compute
// logProbability.
result = result + gm->logProbability(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the (double) logProbability and add it to the
// result
double logProbability = gc->logProbability(continuousValues);
// Add the computed logProbability to every leaf of the logProbability
// tree.
result = result.apply([logProbability](double leaf_value) {
return leaf_value + logProbability;
});
} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete logProbability in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->logProbability(DiscreteValues(assignment));
});
}
result = result + conditional->errorTree(continuousValues);
}

return result;
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
return tree.apply([](double log) { return exp(log); });
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> p =
errors.apply([](double error) { return exp(-error); });
return p / p.sum();
}

/* ************************************************************************* */
Expand Down
Loading
Loading