Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored hybrid inference #1863

Merged
merged 56 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
3797996
Store the values
dellaert Sep 28, 2024
14d1594
implement errorTree for HybridNonlinearFactorGraph
varunagrawal Oct 1, 2024
1bb5b95
discretePosterior in HNFG
dellaert Oct 2, 2024
8b3dfd8
New product factor class
dellaert Oct 2, 2024
ed9a216
Switch to using HybridGaussianProductFactor
dellaert Oct 4, 2024
55ca557
Fix conditional==null bug
dellaert Oct 5, 2024
9254029
Squashed commit
dellaert Oct 6, 2024
584a71f
Product now has scalars
dellaert Oct 6, 2024
e1c0d0e
operator returns pairs, extensive switching test
dellaert Oct 7, 2024
04cfb06
Cherry-pick Varun's bugfix
dellaert Oct 7, 2024
b3c6980
Don't normalize probabilities for a mere DiscreteFactor
dellaert Oct 7, 2024
586b177
Extensive new API test
dellaert Oct 7, 2024
518ea81
No more hiding!
dellaert Oct 7, 2024
f0770c2
Less regression, more API tests for s(3).
dellaert Oct 8, 2024
bcd94e3
Store negLogConstant[i] - negLogConstant_
dellaert Oct 8, 2024
0f48efb
asProductFactor from base class!
dellaert Oct 8, 2024
5241614
Fixed discreteEimination
dellaert Oct 8, 2024
88b8dc9
Merge branch 'develop' into feature/no_hiding
dellaert Oct 8, 2024
9f7ccbb
Minimize formatting changes
dellaert Oct 8, 2024
8e85b68
formatting fixes
varunagrawal Oct 8, 2024
f39f678
add type info
varunagrawal Oct 8, 2024
7603cd4
missing include and formatting
varunagrawal Oct 8, 2024
874ba67
update comment
varunagrawal Oct 8, 2024
21b4c4c
improve HybridGaussianProductFactor
varunagrawal Oct 8, 2024
8b8466e
formatting testHybridGaussianFactorGraph
varunagrawal Oct 8, 2024
02d5421
add GTSAM_EXPORT to HybridGaussianProductFactor
varunagrawal Oct 8, 2024
711a07c
small fix
varunagrawal Oct 8, 2024
d58bd6c
get HybridBayesNet compiling on Windows
varunagrawal Oct 8, 2024
0dbf13f
Another Windows fix
varunagrawal Oct 8, 2024
c60d342
include <iostream> in HybridConditional.cpp
varunagrawal Oct 8, 2024
c08d6bd
maybe it's <string>
varunagrawal Oct 8, 2024
6da1b01
include <string> in HybridBayesTree.cpp
varunagrawal Oct 8, 2024
03c467f
add more includes to try and debug this
varunagrawal Oct 8, 2024
8e29d57
remove includes
varunagrawal Oct 9, 2024
4980797
improved type aliasing
varunagrawal Oct 9, 2024
cd3e0f3
include sstream in HybridGaussianProductFactor
varunagrawal Oct 9, 2024
4ae5596
add back PotentiallyPrunedComponentError as an inline function
varunagrawal Oct 9, 2024
26c9dcb
formatting
varunagrawal Oct 9, 2024
4df266a
replace sstream with string
varunagrawal Oct 9, 2024
436524a
use cout instead of stringstream
varunagrawal Oct 9, 2024
59f97d6
Merge pull request #1865 from borglab/feature/no_hiding-2
dellaert Oct 9, 2024
19fdb43
Pure Google style in clang-format
dellaert Oct 9, 2024
34bb1d0
Shift error values before exponentiating
dellaert Oct 9, 2024
caddc73
fix printing tests
varunagrawal Oct 9, 2024
205eb18
add serialization to HybridGaussianProductFactor
varunagrawal Oct 9, 2024
4f74735
checking if print is the problem
varunagrawal Oct 9, 2024
8a650b6
undo print and remove extra includes
varunagrawal Oct 9, 2024
e87d1fb
comment out serialization
varunagrawal Oct 9, 2024
99a39e6
undo test change
varunagrawal Oct 9, 2024
752e10f
delete constructor from string to fix issue
varunagrawal Oct 9, 2024
1ac8a6e
fix vector type
varunagrawal Oct 9, 2024
f4bf280
implement constructor
varunagrawal Oct 9, 2024
55dc3f5
experiment with removing the constructor
varunagrawal Oct 9, 2024
4d707e7
Revert "experiment with removing the constructor"
varunagrawal Oct 9, 2024
93ec276
implement dummy >> operator
varunagrawal Oct 9, 2024
95f053f
Merge pull request #1866 from borglab/more-fixes
dellaert Oct 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -1,8 +1 @@
BasedOnStyle: Google

BinPackArguments: false
BinPackParameters: false
ColumnLimit: 100
DerivePointerAlignment: false
IncludeBlocks: Preserve
PointerAlignment: Left
3 changes: 1 addition & 2 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
#include <gtsam/linear/GaussianJunctionTree.h>

#include <memory>

#include "gtsam/hybrid/HybridConditional.h"

namespace gtsam {

// Instantiate base class
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* @file HybridConditional.cpp
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/

#include <gtsam/hybrid/HybridConditional.h>
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* @file HybridConditional.h
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/

#pragma once
Expand Down
3 changes: 0 additions & 3 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ namespace gtsam {

class HybridValues;

/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;

KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
Expand Down
79 changes: 31 additions & 48 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@

namespace gtsam {
/* *******************************************************************************/
/**
* @brief Helper struct for constructing HybridGaussianConditional objects
*
* This struct contains the following fields:
* - nrFrontals: Optional size_t for number of frontal variables
* - pairs: FactorValuePairs for storing conditionals with their negLogConstant
* - conditionals: Conditionals for storing conditionals. TODO(frank): kill!
* - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
* constructor
*/
struct HybridGaussianConditional::Helper {
std::optional<size_t> nrFrontals;
FactorValuePairs pairs;
Expand Down Expand Up @@ -68,16 +78,12 @@ struct HybridGaussianConditional::Helper {
explicit Helper(const Conditionals &conditionals)
: conditionals(conditionals),
minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GC::shared_ptr &c) -> GaussianFactorValuePair {
double value = 0.0;
if (c) {
if (!nrFrontals.has_value()) {
nrFrontals = c->nrFrontals();
}
value = c->negLogConstant();
minNegLogConstant = std::min(minNegLogConstant, value);
}
return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
if (!nrFrontals) nrFrontals = gc->nrFrontals();
double value = gc->negLogConstant();
minNegLogConstant = std::min(minNegLogConstant, value);
return {gc, value};
};
pairs = FactorValuePairs(conditionals, func);
if (!nrFrontals.has_value()) {
Expand All @@ -91,7 +97,14 @@ struct HybridGaussianConditional::Helper {
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, const Helper &helper)
: BaseFactor(discreteParents, helper.pairs),
: BaseFactor(discreteParents,
FactorValuePairs(helper.pairs,
[&](const GaussianFactorValuePair &
pair) { // subtract minNegLogConstant
return GaussianFactorValuePair{
pair.first,
pair.second - helper.minNegLogConstant};
})),
BaseConditional(*helper.nrFrontals),
conditionals_(helper.conditionals),
negLogConstant_(helper.minNegLogConstant) {}
Expand Down Expand Up @@ -135,29 +148,6 @@ HybridGaussianConditional::conditionals() const {
return conditionals_;
}

/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
const {
auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
// First check if conditional has not been pruned
if (gc) {
const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_;
// If there is a difference in the covariances, we need to account for
// that since the error is dependent on the mode.
if (Cgm_Kgcm > 0.0) {
// We add a constant factor which will be used when computing
// the probability of the discrete variables.
Vector c(1);
c << std::sqrt(2.0 * Cgm_Kgcm);
auto constantFactor = std::make_shared<JacobianFactor>(c);
return GaussianFactorGraph{gc, constantFactor};
}
}
return GaussianFactorGraph{gc};
};
return {conditionals_, wrap};
}

/* *******************************************************************************/
size_t HybridGaussianConditional::nrComponents() const {
size_t total = 0;
Expand Down Expand Up @@ -192,19 +182,18 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,

// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(
e->conditionals_, [tol](const auto &f1, const auto &f2) {
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
});
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return (!f1 && !f2) ||
(f1 && f2 && f1->equals(*f2, tol));
});
}

/* *******************************************************************************/
void HybridGaussianConditional::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << (s.empty() ? "" : s + "\n");
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid ";
BaseConditional::print("", formatter);
std::cout << " Discrete Keys = ";
for (auto &dk : discreteKeys()) {
Expand Down Expand Up @@ -270,13 +259,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
-> GaussianFactorValuePair {
const auto likelihood_m = conditional->likelihood(given);
const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
if (Cgm_Kgcm == 0.0) {
return {likelihood_m, 0.0};
} else {
// Add a constant to the likelihood in case the noise models
// are not all equal.
return {likelihood_m, Cgm_Kgcm};
}
return {likelihood_m, Cgm_Kgcm};
});
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
likelihoods);
Expand Down
3 changes: 0 additions & 3 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,6 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Helper &helper);

/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

Expand Down
Loading
Loading