-
Notifications
You must be signed in to change notification settings - Fork 39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add DistributedTree Nearest query with callback #737
base: master
Are you sure you want to change the base?
Changes from all commits
8223d33
a6cf7e6
5cf9a5a
6ce6cd2
a552d22
8587a9d
3928d3f
a637f27
ff91ac3
6aee1b5
1e6ba19
68bc0d5
eda8f6b
8379b23
9857b23
484db22
526d581
c4ad321
c666cfb
82505da
86b426a
e0d0d79
0ac9cc4
38bf593
5a26209
980404e
44ee75d
b5c622d
1c4beaf
8d40931
b72b0c8
9402639
d26ce07
3f66b82
7ac6480
2d19d6e
e2a899e
74f4a30
7452785
85ccd16
e6f43d5
591b89e
0baa3fd
3ffb9a7
901cd91
0601e4c
ec544e0
071f409
0d11bf1
dac4018
775e55b
fc59eca
498d407
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Build the distributed k-nearest-neighbors callback example and link it against ArborX. | ||
add_executable(ArborX_DistributedTree_KNNCallback.exe distributed_knn_callback.cpp) | ||
target_link_libraries(ArborX_DistributedTree_KNNCallback.exe ArborX::ArborX) | ||
# Register the example as a test that is started through the MPIEXEC_* launcher variables. | ||
add_test(NAME ArborX_DistributedTree_KNNCallback_Example COMMAND ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} ${MPIEXEC_MAX_NUMPROCS} ${MPIEXEC_PREFLAGS} ./ArborX_DistributedTree_KNNCallback.exe ${MPIEXEC_POSTFLAGS}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/**************************************************************************** | ||
* Copyright (c) 2017-2022 by the ArborX authors * | ||
* All rights reserved. * | ||
* * | ||
* This file is part of the ArborX library. ArborX is * | ||
* distributed under a BSD 3-clause license. For the licensing terms see * | ||
* the LICENSE file in the top-level directory. * | ||
* * | ||
* SPDX-License-Identifier: BSD-3-Clause * | ||
****************************************************************************/ | ||
|
||
#include <ArborX.hpp> | ||
|
||
#include <Kokkos_Core.hpp> | ||
|
||
#include <cstdarg> | ||
#include <cstdio> | ||
#include <iostream> | ||
#include <random> | ||
#include <vector> | ||
|
||
#include <mpi.h> | ||
|
||
using ExecutionSpace = Kokkos::DefaultExecutionSpace; | ||
using MemorySpace = ExecutionSpace::memory_space; | ||
|
||
namespace Example | ||
{ | ||
// Predicate container for the example: a set of query points together with
// the number of neighbors requested per point and the MPI rank issuing the
// queries.
template <typename Points>
struct Nearest
{
  // Query points; element i is the origin of the i-th nearest query.
  Points points;
  // Number of nearest neighbors requested for every query point.
  int k;
  // Rank of the process that owns these queries.
  int mpi_rank;
};

// Deduction guide so that `Nearest{points, k, rank}` can be written without
// spelling out the type of the point container.
template <typename Points>
Nearest(const Points &, int, int) -> Nearest<Points>;
|
||
// Plain (index, rank) pair. The example uses it both as the data attached to
// a query (query index, rank issuing the query) and as the value emitted by
// the callback (primitive index, rank on which the callback ran).
// It is brace-initialized at its use sites, so it has to stay an aggregate.
struct IndexAndRank
{
  // Position of the query or primitive within its owning container.
  int index;
  // MPI rank associated with `index`.
  int rank;
};
|
||
template <typename DeviceType> | ||
struct PrintAndInsert | ||
{ | ||
Kokkos::View<ArborX::Point *, DeviceType> points; | ||
int mpi_rank; | ||
|
||
PrintAndInsert(Kokkos::View<ArborX::Point *, DeviceType> const &points_, | ||
int mpi_rank_) | ||
: points(points_) | ||
, mpi_rank(mpi_rank_) | ||
{} | ||
|
||
template <typename Predicate, typename OutputFunctor> | ||
KOKKOS_FUNCTION void operator()([[maybe_unused]] Predicate const &predicate, | ||
int primitive_index, | ||
OutputFunctor const &out) const | ||
{ | ||
#ifndef KOKKOS_ENABLE_SYCL | ||
auto data = ArborX::getData(predicate); | ||
auto const &point = points(primitive_index); | ||
printf("Match for query %d from MPI rank %d on MPI rank %d for " | ||
"point %f,%f,%f with index %d\n", | ||
data.index, data.rank, mpi_rank, point[0], point[1], point[2], | ||
primitive_index); | ||
#endif | ||
|
||
out({primitive_index, mpi_rank}); | ||
} | ||
}; | ||
|
||
} // namespace Example | ||
|
||
template <class Points> | ||
struct ArborX::AccessTraits<Example::Nearest<Points>, ArborX::PredicatesTag> | ||
{ | ||
static KOKKOS_FUNCTION std::size_t size(Example::Nearest<Points> const &x) | ||
{ | ||
return x.points.extent(0); | ||
} | ||
static KOKKOS_FUNCTION auto get(Example::Nearest<Points> const &x, int i) | ||
{ | ||
return attach(ArborX::nearest(x.points(i), x.k), | ||
Example::IndexAndRank{i, x.mpi_rank}); | ||
} | ||
using memory_space = MemorySpace; | ||
}; | ||
|
||
int main(int argc, char *argv[]) | ||
{ | ||
MPI_Init(&argc, &argv); | ||
Kokkos::initialize(argc, argv); | ||
{ | ||
MPI_Comm comm = MPI_COMM_WORLD; | ||
int comm_rank; | ||
MPI_Comm_rank(comm, &comm_rank); | ||
int comm_size; | ||
MPI_Comm_size(comm, &comm_size); | ||
ArborX::Point lower_left_corner = {static_cast<float>(comm_rank), | ||
static_cast<float>(comm_rank), | ||
static_cast<float>(comm_rank)}; | ||
ArborX::Point center = {static_cast<float>(comm_rank) + .5f, | ||
static_cast<float>(comm_rank) + .5f, | ||
static_cast<float>(comm_rank) + .5f}; | ||
std::vector points = {lower_left_corner, center}; | ||
auto points_device = Kokkos::create_mirror_view_and_copy( | ||
MemorySpace{}, | ||
Kokkos::View<ArborX::Point *, Kokkos::HostSpace, | ||
Kokkos::MemoryUnmanaged>(points.data(), points.size())); | ||
|
||
ExecutionSpace exec; | ||
ArborX::DistributedTree<MemorySpace> tree(comm, exec, points_device); | ||
|
||
Kokkos::View<Example::IndexAndRank *, MemorySpace> values("values", 0); | ||
Kokkos::View<int *, MemorySpace> offsets("offsets", 0); | ||
tree.query(exec, Example::Nearest{points_device, 3, comm_rank}, | ||
Example::PrintAndInsert<MemorySpace>(points_device, comm_rank), | ||
values, offsets); | ||
|
||
auto host_values = | ||
Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, values); | ||
auto host_offsets = | ||
Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, offsets); | ||
for (unsigned int i = 0; i + 1 < host_offsets.size(); ++i) | ||
{ | ||
std::cout << "Results for query " << i << " on MPI rank " << comm_rank | ||
<< '\n'; | ||
for (int j = host_offsets(i); j < host_offsets(i + 1); ++j) | ||
std::cout << "point " << host_values(j).index << ", rank " | ||
<< host_values(j).rank << std::endl; | ||
} | ||
} | ||
Kokkos::finalize(); | ||
MPI_Finalize(); | ||
return 0; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -160,6 +160,15 @@ struct DistributedTreeImpl | |
Indices &indices, Offset &offset, Ranks &ranks, | ||
Distances *distances_ptr = nullptr); | ||
|
||
// Nearest-query overload that takes a user callback. It first runs the | ||
// callback-free nearest query (which fills `offset`), then executes | ||
// `callback` on the rank owning each matched primitive and gathers the | ||
// emitted values into `out` on the rank owning the queries. | ||
// Only participates in overload resolution when OutputView and OffsetView | ||
// are Kokkos views. | ||
template <typename DistributedTree, typename ExecutionSpace, | ||
typename Predicates, typename OutputView, typename OffsetView, | ||
typename Callback> | ||
static std::enable_if_t<Kokkos::is_view<OutputView>{} && | ||
Kokkos::is_view<OffsetView>{}> | ||
queryDispatch(NearestPredicateTag, DistributedTree const &tree, | ||
ExecutionSpace const &space, Predicates const &queries, | ||
Callback const &callback, OutputView &out, OffsetView &offset); | ||
|
||
template <typename DistributedTree, typename ExecutionSpace, | ||
typename Predicates, typename IndicesAndRanks, typename Offset> | ||
static std::enable_if_t<Kokkos::is_view<IndicesAndRanks>{} && | ||
|
@@ -168,7 +177,6 @@ struct DistributedTreeImpl | |
ExecutionSpace const &space, Predicates const &queries, | ||
IndicesAndRanks &values, Offset &offset) | ||
{ | ||
// FIXME avoid zipping when distributed nearest callbacks become available | ||
Kokkos::View<int *, ExecutionSpace> indices( | ||
"ArborX::DistributedTree::query::nearest::indices", 0); | ||
Kokkos::View<int *, ExecutionSpace> ranks( | ||
|
@@ -309,8 +317,8 @@ void DistributedTreeImpl<DeviceType>::deviseStrategy( | |
|
||
// Accumulate total leave count in the local trees until it reaches k which | ||
// is the number of neighbors queried for. Stop if local trees get | ||
// empty because it means that they are no more leaves and there is no point | ||
// on forwarding queries to leafless trees. | ||
// empty because that means that there are no more leaves and there is no | ||
// point in forwarding queries to leafless trees. | ||
using Access = AccessTraits<Predicates, PredicatesTag>; | ||
auto const n_queries = Access::size(queries); | ||
Kokkos::View<int *, DeviceType> new_offset( | ||
|
@@ -627,6 +635,149 @@ DistributedTreeImpl<DeviceType>::queryDispatch( | |
Kokkos::Profiling::popRegion(); | ||
} | ||
|
||
// Record shipped between ranks by the nearest query with callback: one entry
// per (query, matched primitive) pair.
template <typename Query>
struct QueriesWithIndices
{
  // The query itself, needed to invoke the callback on the remote rank.
  Query query;
  // Index of the query on the rank that issued it.
  int query_id;
  // Index of the matched primitive on the rank that owns it.
  int primitive_index;
};
|
||
// Nearest query with a user callback. | ||
// tree:     distributed tree; only its communicator is read directly here. | ||
// space:    execution space instance used for all kernels and copies. | ||
// queries:  predicates, accessed through AccessTraits<Predicates, PredicatesTag>. | ||
// callback: invoked as callback(query, primitive_index, inserter) on the rank | ||
//           that received the (query, primitive) pair. | ||
// out:      resized here; receives the values emitted by the callback, | ||
//           permuted according to the sorted originating query ids. | ||
// offset:   filled by the callback-free nearest query; results of query q | ||
//           are addressed as [offset(q), offset(q + 1)). | ||
template <typename DeviceType> | ||
template <typename DistributedTree, typename ExecutionSpace, | ||
typename Predicates, typename OutputView, typename OffsetView, | ||
typename Callback> | ||
std::enable_if_t<Kokkos::is_view<OutputView>{} && Kokkos::is_view<OffsetView>{}> | ||
DistributedTreeImpl<DeviceType>::queryDispatch( | ||
NearestPredicateTag, DistributedTree const &tree, | ||
ExecutionSpace const &space, Predicates const &queries, | ||
Callback const &callback, OutputView &out, OffsetView &offset) | ||
{ | ||
Kokkos::Profiling::pushRegion( | ||
"ArborX::DistributedTree::query::nearest_callback"); | ||
Kokkos::View<int *, ExecutionSpace> indices( | ||
"ArborX::DistributedTree::query::nearest::indices", 0); | ||
Kokkos::View<int *, ExecutionSpace> ranks( | ||
"ArborX::DistributedTree::query::nearest::ranks", 0); | ||
|
||
// Distributed nearest callbacks strategy: | ||
// - Find the ranks and indices for the nearest queries using a regular query | ||
// without a callback. | ||
// - Scatter (predicate, primitive) pairs to the corresponding matching ranks. | ||
// - Execute the callback on the process owning the primitives. | ||
// - Send the result back to the process owning the predicates. | ||
|
||
// Find the ranks and indices for the nearest queries using the overload not | ||
// taking a callback. | ||
queryDispatchImpl(NearestPredicateTag{}, tree, space, queries, indices, | ||
offset, ranks); | ||
Kokkos::Profiling::popRegion(); | ||
|
||
Kokkos::Profiling::pushRegion( | ||
"ArborX::DistributedTree::query::nearest::execute_callback"); | ||
|
||
// Send the predicate-primitive pairs to the process where the match was | ||
// found. | ||
auto comm = tree.getComm(); | ||
int comm_rank; | ||
MPI_Comm_rank(comm, &comm_rank); | ||
|
||
using Access = AccessTraits<Predicates, PredicatesTag>; | ||
using Query = typename AccessTraitsHelper<Access>::type; | ||
|
||
// One export entry per match: (query, local query id, matched primitive index). | ||
Kokkos::View<QueriesWithIndices<Query> *, typename DeviceType::memory_space> | ||
exported_queries_with_indices( | ||
Kokkos::view_alloc( | ||
space, Kokkos::WithoutInitializing, | ||
"ArborX::DistributedTree::query::exported_queries_with_indices"), | ||
ranks.size()); | ||
Kokkos::parallel_for( | ||
"ArborX::DistributedTree::query::zip_queries_and_primitives", | ||
Kokkos::RangePolicy<ExecutionSpace>(space, 0, Access::size(queries)), | ||
KOKKOS_LAMBDA(int q) { | ||
using index_type = typename OffsetView::value_type; | ||
for (index_type i = offset(q); i < offset(q + 1); ++i) | ||
exported_queries_with_indices(i) = {Access::get(queries, q), q, | ||
indices(i)}; | ||
}); | ||
|
||
Distributor<DeviceType> distributor(comm); | ||
auto const n_imports = distributor.createFromSends(space, ranks); | ||
|
||
Kokkos::View<QueriesWithIndices<Query> *, typename DeviceType::memory_space> | ||
imported_queries_with_indices( | ||
Kokkos::view_alloc( | ||
space, Kokkos::WithoutInitializing, | ||
"ArborX::DistributedTree::query::imported_queries_with_indices"), | ||
n_imports); | ||
|
||
sendAcrossNetwork(space, distributor, exported_queries_with_indices, | ||
imported_queries_with_indices); | ||
|
||
// Execute the callback on the process owning the primitives. | ||
OutputView remote_out( | ||
Kokkos::view_alloc(space, Kokkos::WithoutInitializing, | ||
"ArborX::DistributedTree::query::remote_out"), | ||
n_imports); | ||
// `indices` is reused from here on: it is resized to one entry per imported | ||
// pair and filled with -1, which marks "callback has not emitted a value yet"; | ||
// the inserter below overwrites it with the originating query id. | ||
KokkosExt::reallocWithoutInitializing(space, indices, n_imports); | ||
Kokkos::deep_copy(space, indices, -1); | ||
Kokkos::parallel_for( | ||
"ArborX::DistributedTree::query::execute_callbacks", | ||
Kokkos::RangePolicy<ExecutionSpace>(space, 0, | ||
imported_queries_with_indices.size()), | ||
KOKKOS_LAMBDA(int i) { | ||
callback(imported_queries_with_indices(i).query, | ||
imported_queries_with_indices(i).primitive_index, | ||
[&](typename OutputView::value_type const &value) { | ||
// NOTE(review): the check below exists only when NDEBUG is not defined. | ||
// With NDEBUG a second insertion overwrites the first one, and a callback | ||
// that never inserts leaves remote_out(i) uninitialized and indices(i) at -1. | ||
#ifndef NDEBUG | ||
// FIXME We only allow calling the callback once per match. | ||
if (indices(i) != -1) | ||
Kokkos::abort("Inserting more than one result per " | ||
"callback is not implemented!"); | ||
#endif | ||
remote_out(i) = value; | ||
indices(i) = imported_queries_with_indices(i).query_id; | ||
}); | ||
}); | ||
|
||
// Send the result back to the process owning the predicates. | ||
// The reverse distributor is built from the sources and source offsets of | ||
// the forward one, copied from host into DeviceType views first. | ||
Distributor<DeviceType> back_distributor(comm); | ||
auto const &dest = distributor.getSources(); | ||
auto const &off = distributor.getSourceOffsets(); | ||
|
||
Kokkos::View<int const *, Kokkos::HostSpace> host_destinations(dest.data(), | ||
dest.size()); | ||
Kokkos::View<int const *, Kokkos::HostSpace> host_offsets(off.data(), | ||
off.size()); | ||
// NOTE(review): `memory_space` is not referenced anywhere else in this function. | ||
typename DeviceType::memory_space memory_space; | ||
Kokkos::View<int *, DeviceType> destinations( | ||
Kokkos::view_alloc(space, Kokkos::WithoutInitializing, | ||
"ArborX::DistributedTree::query::destinations"), | ||
dest.size()); | ||
Kokkos::deep_copy(space, destinations, host_destinations); | ||
Kokkos::View<int *, DeviceType> offsets( | ||
Kokkos::view_alloc(space, Kokkos::WithoutInitializing, | ||
"ArborX::DistributedTree::query::offsets"), | ||
off.size()); | ||
Kokkos::deep_copy(space, offsets, host_offsets); | ||
auto const n_imports_back = | ||
back_distributor.createFromSends(space, destinations, offsets); | ||
KokkosExt::reallocWithoutInitializing(space, out, n_imports_back); | ||
Kokkos::View<int *, DeviceType> query_ids( | ||
Kokkos::view_alloc(space, Kokkos::WithoutInitializing, | ||
"ArborX::DistributedTree::query::nearest::query_ids"), | ||
n_imports_back); | ||
|
||
Comment on lines
+752
to
+770
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be simplified by using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't that what I'm doing already? |
||
// FIXME does combining communication here help? | ||
sendAcrossNetwork(space, back_distributor, remote_out, out); | ||
sendAcrossNetwork(space, back_distributor, indices, query_ids); | ||
|
||
// Reorder the returned values by originating query id (presumably so that | ||
// `out` lines up with `offset` again -- TODO confirm sortObjects semantics). | ||
auto const permutation = ArborX::Details::sortObjects(space, query_ids); | ||
ArborX::Details::applyPermutation(space, permutation, out); | ||
|
||
Kokkos::Profiling::popRegion(); | ||
} | ||
|
||
template <typename DeviceType> | ||
template <typename ExecutionSpace, typename View, typename... OtherViews> | ||
void DistributedTreeImpl<DeviceType>::sortResults(ExecutionSpace const &space, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is deprecated, need to use the one that returns pairs of (index, rank).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, the `Impl` version, `queryDispatchImpl`, is not deprecated (and I would rather use this one internally).