Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generalize inverse function #8

Merged
merged 5 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
Language: Cpp
BasedOnStyle: Google
ColumnLimit: 100
DerivePointerAlignment: false
PointerAlignment: Left
...
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ Refer to [`transforms_graph.h`](include/transforms_graph/transforms_graph.h) for
## Required types
The `TransformsGraph` requires two template parameters:
1. `Transform` type. This is the type that stores the transforms themselves (i.e., poses). This type should have the following defined:
- A default constructor that sets the transform to the identity transform
- A valid Multiplication operator `*` operator (i.e., `T1 = T2 * T2`)
- A valid `Transform inverse() const` method (i.e., `T1.inverse() * T1` should return an identity transform)
- An output stream operator `<<` (i.e., `std::cout << T`)
- A default constructor that sets the transform to the identity transform.
- A valid Multiplication operator `*` operator (i.e., `T1 = T2 * T2`).
- A valid `Transform inverse() const` method (i.e., `T1.inverse() * T1` should return an identity transform). Alternatively, a function `Transform inverse(Transform)` can be injected/passed to the `TransformsGraph` upon construction.
- An output stream operator `<<` (i.e., `std::cout << T`).
2. `Frame` type (default set to `char`). This is the type that keeps track of the frames. The type should have the following defined:
- Greater-than comparison operator `>` (i.e., `frame_i > frame_j`)
- An output stream operator `<<` (i.e., `std::cout << T`)
- Greater-than comparison operator `>` (i.e., `frame_i > frame_j`).
- An output stream operator `<<` (i.e., `std::cout << T`).
3. `Inv` function object (defaults to `Transform::inverse`). The function object is passed in the constructor.

The classes from [Sophus](https://github.com/strasdat/Sophus) (e.g., `Sophus::SE2d`) and [Eigen](https://eigen.tuxfamily.org/dox/group__TutorialGeometry.html) (e.g., `Eigen::Affine2d`) already satisfy the `Transform` requirements, except for the output stream operator `<<` requirement.

Expand Down
2 changes: 1 addition & 1 deletion examples/eigen_pose_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Pose {

Pose inverse() const {
Pose p;
p.pose_ = std::move(pose_.inverse());
p.pose_ = pose_.inverse();
return p;
}
Eigen::Affine2d Affine() const { return pose_; }
Expand Down
4 changes: 2 additions & 2 deletions examples/minimal_graph_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class Displacement {
Displacement() { d_ = 0; }
Displacement(double d) : d_(d) {}
double x() const { return d_; }
Displacement inverse() const { return Displacement(-d_); }

private:
double d_;
Expand All @@ -43,7 +42,8 @@ int main(int argc, char* argv[]) {
using Transform = Displacement;

// Construct a graph that consists of two unconnected subgraphs
tg::TransformsGraph<Transform, Frame> transforms;
auto displacement_inverse = [](const Transform& t) -> Transform { return -t.x(); };
tg::TransformsGraph<Transform, Frame> transforms(100, displacement_inverse);
transforms.AddTransform('a', 'b', 1);
transforms.AddTransform('a', 'c', 2);
transforms.AddTransform('b', 'd', 3);
Expand Down
6 changes: 4 additions & 2 deletions include/transforms_graph/graph_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ namespace tg {
*/
template <typename Node, typename Graph = std::unordered_map<Node, std::unordered_set<Node>>>
std::vector<Node> DFS(const Graph& graph, Node start, Node end) {
std::vector<Node> path;
std::unordered_set<Node> visited;
std::unordered_map<Node, Node> parent;
std::stack<Node> stack;
Expand All @@ -43,6 +42,8 @@ std::vector<Node> DFS(const Graph& graph, Node start, Node end) {
found_solution = true;
break;
}

// Mark as visited
if (visited.count(current)) continue;
visited.insert(current);

Expand All @@ -55,7 +56,8 @@ std::vector<Node> DFS(const Graph& graph, Node start, Node end) {

if (!found_solution) return {};

// Found the end
// Get path from start -> end
std::vector<Node> path;
Node node = end;
while (node != start) {
path.push_back(node);
Expand Down
28 changes: 22 additions & 6 deletions include/transforms_graph/transforms_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ namespace tg {
* @brief Transform graph class
* @tparam Transform Transform/pose type. Should have `*`, `<<`, and `inverse()` defined
* @tparam Frame Frame type (e.g. char, int, etc.). Should have `<<` defined.
* @tparam Inv Function object, where the instance takes an argument of type Transform and returns
* an inverse Transform.
*/
template <typename Transform, typename Frame = char>
template <typename Transform, typename Frame = char,
typename Inv = std::function<Transform(Transform)>>
class TransformsGraph {
public:
/** Id/key used to store transforms in the Transforms map */
Expand All @@ -32,6 +35,9 @@ class TransformsGraph {
/** Raw transforms */
using Transforms = std::unordered_map<TransformId, Transform>;

/** Inverse function object */
using TransformInverse = Inv;

/** Adjacency matrix of adjacent frames */
using AdjacentFrames = std::unordered_map<Frame, std::unordered_set<Frame>>;

Expand All @@ -45,8 +51,13 @@ class TransformsGraph {
* @details The maximum number of frames cannot be changed after construction
*
* @param[in] max_frames Maximum number of frames allowed in the graph
* @param[in] transform_inverse Function object to invert a transform. Should take a
* transform as an argument and return its inverse.
*/
TransformsGraph(int max_frames = 100) : max_frames_(max_frames) {}
TransformsGraph(int max_frames = 100,
TransformInverse transform_inverse = std::bind(&Transform::inverse,
std::placeholders::_1))
: max_frames_(max_frames), transform_inverse_(transform_inverse) {}

/**
* @brief Get maximum number allowed in the transform graph
Expand Down Expand Up @@ -130,7 +141,7 @@ class TransformsGraph {
const auto transform_id = ComputeTransformId(parent, child);
auto transform = raw_transforms_.at(transform_id);

return ShouldInvertFrames(parent, child) ? transform.inverse() : transform;
return ShouldInvertFrames(parent, child) ? transform_inverse_(transform) : transform;
}

/**
Expand Down Expand Up @@ -160,7 +171,7 @@ class TransformsGraph {
const auto transform_id = ComputeTransformId(prev, frame);
auto T_prev_curr = raw_transforms_.at(transform_id);
if (ShouldInvertFrames(prev, frame)) {
T_prev_curr = T_prev_curr.inverse();
T_prev_curr = transform_inverse_(T_prev_curr);
}

T_parent_child = T_parent_child * T_prev_curr;
Expand Down Expand Up @@ -354,7 +365,8 @@ class TransformsGraph {
throw std::runtime_error("Transform does not exist in the graph");
}
const auto transform_id = ComputeTransformId(parent, child);
raw_transforms_[transform_id] = ShouldInvertFrames(parent, child) ? pose.inverse() : pose;
raw_transforms_[transform_id] =
ShouldInvertFrames(parent, child) ? transform_inverse_(pose) : pose;
}

/**
Expand Down Expand Up @@ -457,7 +469,8 @@ class TransformsGraph {
if (!HasFrame(child)) AddFrame(child);

const auto transform_id = ComputeTransformId(parent, child);
raw_transforms_[transform_id] = ShouldInvertFrames(parent, child) ? pose.inverse() : pose;
raw_transforms_[transform_id] =
ShouldInvertFrames(parent, child) ? transform_inverse_(pose) : pose;
}

/**
Expand Down Expand Up @@ -508,6 +521,9 @@ class TransformsGraph {
/** Maximum number of frames expected to be in the graph */
int max_frames_ = 100;

/** Function object to invert transform */
TransformInverse transform_inverse_;

/** Acyclic graph where the vertices are the frames and the edges are transforms between the two
* frames */
AdjacentFrames adjacent_frames_;
Expand Down