Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Record cardinalities in DecisionTree #1868

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,8 @@ namespace gtsam {
a->push_back(l1);
a->push_back(l2);
root_ = Choice::Unique(a);

cardinalities_map_[label] = 2;
}

/****************************************************************************/
Expand All @@ -508,6 +510,9 @@ namespace gtsam {
const Y& y2) {
if (labelC.second != 2) throw std::invalid_argument(
"DecisionTree: binary constructor called with non-binary label");

cardinalities_map_[labelC.first] = labelC.second;

auto a = std::make_shared<Choice>(labelC.first, 2);
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
a->push_back(l1);
Expand All @@ -521,6 +526,11 @@ namespace gtsam {
const std::vector<Y>& ys) {
// call recursive Create
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());

// Fill in cardinalities
for (auto&& [label, nrChoices] : labelCs) {
cardinalities_map_[label] = nrChoices;
}
}

/****************************************************************************/
Expand All @@ -535,6 +545,11 @@ namespace gtsam {

// now call recursive Create
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());

// Fill in cardinalities
for (auto&& [label, nrChoices] : labelCs) {
cardinalities_map_[label] = nrChoices;
}
}

/****************************************************************************/
Expand All @@ -550,13 +565,15 @@ namespace gtsam {
const DecisionTree& f0, const DecisionTree& f1) {
const std::vector<DecisionTree> functions{f0, f1};
root_ = compose(functions.begin(), functions.end(), label);

cardinalities_map_[label] = 2;
}

/****************************************************************************/
template <typename L, typename Y>
template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
Func Y_of_X) {
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X)
: cardinalities_map_(other.allCardinalities()) {
// Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
Expand All @@ -569,6 +586,13 @@ namespace gtsam {
const std::map<M, L>& map, Func Y_of_X) {
auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);

// Fill in cardinalities
std::map<M, size_t> otherCardinalities = other.allCardinalities();
for (auto&& it = otherCardinalities.begin(); it != otherCardinalities.end();
it++) {
cardinalities_map_[L_of_M(it->first)] = it->second;
}
}

/****************************************************************************/
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ namespace gtsam {
return a == b;
}

/// Map of Keys and their cardinalities.
std::map<L, size_t> cardinalities_map_;

public:
using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>;
Expand Down Expand Up @@ -274,6 +277,12 @@ namespace gtsam {
/** evaluate */
const Y& operator()(const Assignment<L>& x) const;

/// Get the cardinalities for all the labels.
std::map<L, size_t> allCardinalities() const { return cardinalities_map_; }

/// Get cardinality for a specific label.
size_t nrChoices(L j) const { return cardinalities_map_.at(j); }

/**
* @brief Visit all leaves in depth-first fashion.
*
Expand Down Expand Up @@ -413,6 +422,7 @@ namespace gtsam {
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_NVP(root_);
ar& BOOST_SERIALIZATION_NVP(cardinalities_map_);
}
#endif
}; // DecisionTree
Expand Down
Loading