Merge pull request #1857 from borglab/feature/posteriors
varunagrawal authored Oct 6, 2024
2 parents 23f4282 + acccef8 commit b89e9c9
Showing 21 changed files with 429 additions and 563 deletions.
11 changes: 11 additions & 0 deletions gtsam/discrete/AlgebraicDecisionTree.h
@@ -70,6 +70,7 @@ namespace gtsam {
return a / b;
}
static inline double id(const double& x) { return x; }
static inline double negate(const double& x) { return -x; }
};

AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
@@ -186,6 +187,16 @@ namespace gtsam {
return this->apply(g, &Ring::add);
}

/** negation */
AlgebraicDecisionTree operator-() const {
return this->apply(&Ring::negate);
}

/** subtract */
AlgebraicDecisionTree operator-(const AlgebraicDecisionTree& g) const {
return *this + (-g);
}

/** product */
AlgebraicDecisionTree operator*(const AlgebraicDecisionTree& g) const {
return this->apply(g, &Ring::mul);
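For orientation, the two new operators just compose the existing `apply` machinery with the added `Ring::negate`. A minimal usage sketch, not part of the diff, assuming a GTSAM build that includes this change:

```cpp
// Sketch: exercising the new unary minus and subtraction on an
// AlgebraicDecisionTree over one binary key. Only this example program
// itself is hypothetical; the API calls are the ones shown in the diff.
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>

using namespace gtsam;

int main() {
  DiscreteKey A(0, 2);  // variable with key 0 and cardinality 2
  AlgebraicDecisionTree<Key> a(A, 1.0, 2.0);  // leaves: 1 (A=0), 2 (A=1)

  AlgebraicDecisionTree<Key> minusA = -a;     // leaves become -1 and -2
  AlgebraicDecisionTree<Key> zeroed = a - a;  // a + (-a): the all-zero tree

  minusA.print("-a");
  zeroed.print("a - a");
  return 0;
}
```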
4 changes: 2 additions & 2 deletions gtsam/discrete/DecisionTreeFactor.h
@@ -131,7 +131,7 @@ namespace gtsam {

  /// Calculate the probability for the given values `x`;
  /// this is just a lookup in the AlgebraicDecisionTree.
- double evaluate(const DiscreteValues& values) const {
+ double evaluate(const Assignment<Key>& values) const {
    return ADT::operator()(values);
  }

@@ -155,7 +155,7 @@ namespace gtsam {
return apply(f, safe_div);
}

- /// Convert into a decisiontree
+ /// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }

/// Create new factor by summing all values with the same separator values
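The signature change above is a widening rather than a behavior change: `DiscreteValues` derives from `Assignment<Key>`, so existing call sites keep compiling while callers holding a plain assignment no longer need a conversion. A minimal sketch, not part of the diff; the factor and keys are illustrative only:

```cpp
// Sketch: evaluating a DecisionTreeFactor through the widened
// Assignment<Key> interface.
#include <gtsam/discrete/DecisionTreeFactor.h>

using namespace gtsam;

int main() {
  DiscreteKey B(1, 2);                  // key 1, cardinality 2
  DecisionTreeFactor f(B, "0.4 0.6");   // table: f(B=0)=0.4, f(B=1)=0.6

  Assignment<Key> values;
  values[B.first] = 1;                  // choose B = 1
  const double p = f.evaluate(values);  // looks up 0.6
  return p > 0.5 ? 0 : 1;
}
```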
54 changes: 33 additions & 21 deletions gtsam/discrete/tests/testAlgebraicDecisionTree.cpp
@@ -10,10 +10,10 @@
* -------------------------------------------------------------------------- */

/*
- * @file testDecisionTree.cpp
- * @brief Develop DecisionTree
- * @author Frank Dellaert
- * @date Mar 6, 2011
+ * @file testAlgebraicDecisionTree.cpp
+ * @brief Unit tests for Algebraic decision tree
+ * @author Frank Dellaert
+ * @date Mar 6, 2011
 */

#include <gtsam/base/Testable.h>
@@ -46,23 +46,35 @@ void dot(const T& f, const string& filename) {
#endif
}

-/** I can't get this to work !
-class Mul: std::function<double(const double&, const double&)> {
-  inline double operator()(const double& a, const double& b) {
-    return a * b;
-  }
-};
-// If second argument of binary op is Leaf
-template<typename L>
-typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
-    double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
-  Ptr h(new Choice(label(), cardinality()));
-  for(const NodePtr& branch: branches_)
-    h->push_back(branch->apply_f_op_g(cache, gL, op));
-  return Unique(cache, h);
-}
-*/
/* ************************************************************************** */
// Test arithmetic:
TEST(ADT, arithmetic) {
DiscreteKey A(0, 2), B(1, 2);
ADT zero{0}, one{1};
ADT a(A, 1, 2);
ADT b(B, 3, 4);

// Addition
CHECK(assert_equal(a, zero + a));

// Negate and subtraction
CHECK(assert_equal(-a, zero - a));
CHECK(assert_equal({zero}, a - a));
CHECK(assert_equal(a + b, b + a));
CHECK(assert_equal({A, 3, 4}, a + 2));
CHECK(assert_equal({B, 1, 2}, b - 2));

// Multiplication
CHECK(assert_equal(zero, zero * a));
CHECK(assert_equal(zero, a * zero));
CHECK(assert_equal(a, one * a));
CHECK(assert_equal(a, a * one));
CHECK(assert_equal(a * b, b * a));

// division
// CHECK(assert_equal(a, (a * b) / b)); // not true because no pruning
CHECK(assert_equal(b, (a * b) / a));
}
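A note on the commented-out division check: the "no pruning" remark means `(a * b) / b` is computed leaf-wise, so the result carries the same values as `a` but still branches on both `A` and `B`; without a merge of identical branches the trees compare unequal even though every leaf value matches. A small sketch of that numeric claim, not part of the diff:

```cpp
// Hypothetical illustration (not in the diff): (a * b) / b agrees with `a`
// at every assignment; only the tree structure differs, which is why the
// commented-out CHECK above fails without pruning/merging.
#include <cassert>
#include <cmath>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>

using namespace gtsam;
using ADT = AlgebraicDecisionTree<Key>;

void divisionIsLeafwise() {
  DiscreteKey A(0, 2), B(1, 2);
  ADT a(A, 1, 2), b(B, 3, 4);
  ADT q = (a * b) / b;  // branches on both A and B, leaves equal a's
  for (size_t i = 0; i < 2; i++)
    for (size_t j = 0; j < 2; j++) {
      DiscreteValues vals;
      vals[A.first] = i;
      vals[B.first] = j;
      assert(std::fabs(q(vals) - a(vals)) < 1e-9);  // values agree leaf-wise
    }
}
```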

/* ************************************************************************** */
// instrumented operators
237 changes: 45 additions & 192 deletions gtsam/hybrid/HybridBayesNet.cpp
@@ -17,10 +17,13 @@
*/

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

#include <memory>

// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);

@@ -38,135 +41,26 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
}

/* ************************************************************************* */
-/**
- * @brief Helper function to get the pruner functional.
- *
- * @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
- * @param conditional Conditional to prune. Used to get full assignment.
- * @return std::function<double(const Assignment<Key> &, double)>
- */
-std::function<double(const Assignment<Key> &, double)> prunerFunc(
-    const DecisionTreeFactor &prunedDiscreteProbs,
-    const HybridConditional &conditional) {
-  // Get the discrete keys as sets for the decision tree
-  // and the hybrid Gaussian conditional.
-  std::set<DiscreteKey> discreteProbsKeySet =
-      DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
-  std::set<DiscreteKey> conditionalKeySet =
-      DiscreteKeysAsSet(conditional.discreteKeys());
-
-  auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
-                    const Assignment<Key> &choices,
-                    double probability) -> double {
-    // This corresponds to 0 probability
-    double pruned_prob = 0.0;
-
-    // typecast so we can use this to get probability value
-    DiscreteValues values(choices);
-    // Case where the hybrid Gaussian conditional has the same
-    // discrete keys as the decision tree.
-    if (conditionalKeySet == discreteProbsKeySet) {
-      if (prunedDiscreteProbs(values) == 0) {
-        return pruned_prob;
-      } else {
-        return probability;
-      }
-    } else {
-      // Due to branch merging (aka pruning) in DecisionTree, it is possible we
-      // get a `values` which doesn't have the full set of keys.
-      std::set<Key> valuesKeys;
-      for (auto kvp : values) {
-        valuesKeys.insert(kvp.first);
-      }
-      std::set<Key> conditionalKeys;
-      for (auto kvp : conditionalKeySet) {
-        conditionalKeys.insert(kvp.first);
-      }
-      // If true, then values is missing some keys
-      if (conditionalKeys != valuesKeys) {
-        // Get the keys present in conditionalKeys but not in valuesKeys
-        std::vector<Key> missing_keys;
-        std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
-                            valuesKeys.begin(), valuesKeys.end(),
-                            std::back_inserter(missing_keys));
-        // Insert missing keys with a default assignment.
-        for (auto missing_key : missing_keys) {
-          values[missing_key] = 0;
-        }
-      }
-
-      // Now we generate the full assignment by enumerating
-      // over all keys in the prunedDiscreteProbs.
-      // First we find the differing keys
-      std::vector<DiscreteKey> set_diff;
-      std::set_difference(discreteProbsKeySet.begin(),
-                          discreteProbsKeySet.end(), conditionalKeySet.begin(),
-                          conditionalKeySet.end(),
-                          std::back_inserter(set_diff));
-
-      // Now enumerate over all assignments of the differing keys
-      const std::vector<DiscreteValues> assignments =
-          DiscreteValues::CartesianProduct(set_diff);
-      for (const DiscreteValues &assignment : assignments) {
-        DiscreteValues augmented_values(values);
-        augmented_values.insert(assignment);
-
-        // If any one of the sub-branches are non-zero,
-        // we need this probability.
-        if (prunedDiscreteProbs(augmented_values) > 0.0) {
-          return probability;
-        }
-      }
-      // If we are here, it means that all the sub-branches are 0,
-      // so we prune.
-      return pruned_prob;
-    }
-  };
-  return pruner;
-}
-
-/* ************************************************************************* */
-DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
-    size_t maxNrLeaves) {
-  // Get the joint distribution of only the discrete keys
-  // The joint discrete probability.
-  DiscreteConditional discreteProbs;
-
-  std::vector<size_t> discrete_factor_idxs;
-  // Record frontal keys so we can maintain ordering
-  Ordering discrete_frontals;
-
-  for (size_t i = 0; i < this->size(); i++) {
-    auto conditional = this->at(i);
-    if (conditional->isDiscrete()) {
-      discreteProbs = discreteProbs * (*conditional->asDiscrete());
-
-      Ordering conditional_keys(conditional->frontals());
-      discrete_frontals += conditional_keys;
-      discrete_factor_idxs.push_back(i);
-    }
-  }
-
-  const DecisionTreeFactor prunedDiscreteProbs =
-      discreteProbs.prune(maxNrLeaves);
-
-  // Eliminate joint probability back into conditionals
-  DiscreteFactorGraph dfg{prunedDiscreteProbs};
-  DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
-
-  // Assign pruned discrete conditionals back at the correct indices.
-  for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
-    size_t idx = discrete_factor_idxs.at(i);
-    this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
-  }
-
-  return prunedDiscreteProbs;
-}
-
-/* ************************************************************************* */
-HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
-  DecisionTreeFactor prunedDiscreteProbs =
-      this->pruneDiscreteConditionals(maxNrLeaves);
+// The implementation is: build the entire joint into one factor and then prune.
+// TODO(Frank): This can be quite expensive *unless* the factors have already
+// been pruned before. Another, possibly faster approach is branch and bound
+// search to find the K-best leaves and then create a single pruned conditional.
+HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
+  // Collect all the discrete conditionals. Could be small if already pruned.
+  const DiscreteBayesNet marginal = discreteMarginal();
+
+  // Multiply into one big conditional. NOTE: possibly quite expensive.
+  DiscreteConditional joint;
+  for (auto &&conditional : marginal) {
+    joint = joint * (*conditional);
+  }
+
+  // Prune the joint. NOTE: again, possibly quite expensive.
+  const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);
+
+  // Create the result, starting with the pruned joint.
+  HybridBayesNet result;
+  result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);

  /* To prune, we visitWith every leaf in the HybridGaussianConditional.
   * For each leaf, using the assignment we can check the discrete decision tree
@@ -175,28 +69,34 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   * We can later check the HybridGaussianConditional for just nullptrs.
   */

-  HybridBayesNet prunedBayesNetFragment;
-
-  // Go through all the conditionals in the
-  // Bayes Net and prune them as per prunedDiscreteProbs.
+  // Go through all the Gaussian conditionals in the Bayes Net and prune them
+  // as per the pruned discrete joint.
  for (auto &&conditional : *this) {
-    if (auto gm = conditional->asHybrid()) {
+    if (auto hgc = conditional->asHybrid()) {
      // Make a copy of the hybrid Gaussian conditional and prune it!
-      auto prunedHybridGaussianConditional =
-          std::make_shared<HybridGaussianConditional>(*gm);
-      prunedHybridGaussianConditional->prune(
-          prunedDiscreteProbs);  // imperative :-(
+      auto prunedHybridGaussianConditional = hgc->prune(pruned);

      // Type-erase and add to the pruned Bayes Net fragment.
-      prunedBayesNetFragment.push_back(prunedHybridGaussianConditional);
-
-    } else {
+      result.push_back(prunedHybridGaussianConditional);
+    } else if (auto gc = conditional->asGaussian()) {
      // Add the non-HybridGaussianConditional conditional
-      prunedBayesNetFragment.push_back(conditional);
+      result.push_back(gc);
    }
+    // We ignore DiscreteConditional as they are already pruned and added.
  }

-  return prunedBayesNetFragment;
+  return result;
}
+
+/* ************************************************************************* */
+DiscreteBayesNet HybridBayesNet::discreteMarginal() const {
+  DiscreteBayesNet result;
+  for (auto &&conditional : *this) {
+    if (auto dc = conditional->asDiscrete()) {
+      result.push_back(dc);
+    }
+  }
+  return result;
+}
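Together, the new `prune` and `discreteMarginal` make the intended workflow explicit. A minimal usage sketch, not part of the diff; the input Bayes net is assumed to come from hybrid elimination elsewhere:

```cpp
// Sketch: prune a hybrid Bayes net down to the 4 most likely discrete
// leaves, then inspect the discrete part. `bayesNet` is assumed given.
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/hybrid/HybridBayesNet.h>

using namespace gtsam;

HybridBayesNet pruneToFour(const HybridBayesNet& bayesNet) {
  // Keep at most 4 leaves in the pruned discrete joint; hybrid Gaussian
  // components whose assignments were pruned away are dropped with it.
  const HybridBayesNet pruned = bayesNet.prune(4);

  // The discrete conditionals alone, as a DiscreteBayesNet.
  pruned.discreteMarginal().print("pruned discrete marginal");

  return pruned;
}
```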

/* ************************************************************************* */
@@ -291,66 +191,19 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(

  // Iterate over each conditional.
  for (auto &&conditional : *this) {
-    if (auto gm = conditional->asHybrid()) {
-      // If conditional is hybrid, compute error for all assignments.
-      result = result + gm->errorTree(continuousValues);
-
-    } else if (auto gc = conditional->asGaussian()) {
-      // If continuous, get the error and add it to the result
-      double error = gc->error(continuousValues);
-      // Add the computed error to every leaf of the result tree.
-      result = result.apply(
-          [error](double leaf_value) { return leaf_value + error; });
-
-    } else if (auto dc = conditional->asDiscrete()) {
-      // If discrete, add the discrete error in the right branch
-      result = result.apply(
-          [dc](const Assignment<Key> &assignment, double leaf_value) {
-            return leaf_value + dc->error(DiscreteValues(assignment));
-          });
-    }
+    result = result + conditional->errorTree(continuousValues);
  }

  return result;
}

-/* ************************************************************************* */
-AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
-    const VectorValues &continuousValues) const {
-  AlgebraicDecisionTree<Key> result(0.0);
-
-  // Iterate over each conditional.
-  for (auto &&conditional : *this) {
-    if (auto gm = conditional->asHybrid()) {
-      // If conditional is hybrid, select based on assignment and compute
-      // logProbability.
-      result = result + gm->logProbability(continuousValues);
-    } else if (auto gc = conditional->asGaussian()) {
-      // If continuous, get the (double) logProbability and add it to the
-      // result
-      double logProbability = gc->logProbability(continuousValues);
-      // Add the computed logProbability to every leaf of the logProbability
-      // tree.
-      result = result.apply([logProbability](double leaf_value) {
-        return leaf_value + logProbability;
-      });
-    } else if (auto dc = conditional->asDiscrete()) {
-      // If discrete, add the discrete logProbability in the right branch
-      result = result.apply(
-          [dc](const Assignment<Key> &assignment, double leaf_value) {
-            return leaf_value + dc->logProbability(DiscreteValues(assignment));
-          });
-    }
-  }
-
-  return result;
-}
-
/* ************************************************************************* */
-AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
+AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
    const VectorValues &continuousValues) const {
-  AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
-  return tree.apply([](double log) { return exp(log); });
+  AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
+  AlgebraicDecisionTree<Key> p =
+      errors.apply([](double error) { return exp(-error); });
+  return p / p.sum();
}
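For reference, `discretePosterior` is a softmax over negative errors. Writing $E(M, x)$ for the error tree computed by `errorTree` at discrete assignment $M$ and continuous values $x$, the returned tree is (restating the code above, not an additional API):

$$
P(M \mid x) = \frac{\exp\big(-E(M, x)\big)}{\sum_{M'} \exp\big(-E(M', x)\big)}
$$

Since the unnormalized density satisfies $p(M, x) \propto \exp(-E(M, x))$, dividing the exponentiated error tree by its sum over all leaves yields a proper posterior over the discrete assignments.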

/* ************************************************************************* */