Skip to content

Commit

Permalink
Add SparseHalo creation function, match Halo constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
streeve committed Aug 8, 2023
1 parent 3bf1cdf commit 9801705
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 27 deletions.
2 changes: 1 addition & 1 deletion cajita/src/Cajita_Halo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ struct ArrayPackMemorySpace

//---------------------------------------------------------------------------//
/*!
\brief Array creation function.
\brief Halo creation function.
\param pattern The pattern to build the halo from.
\param width Must be less than or equal to the width of the array halo.
\param arrays The arrays over which to build the halo.
Expand Down
70 changes: 49 additions & 21 deletions cajita/src/Cajita_SparseHalo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace Experimental
communicator and halo size. The arrays must also reside in the same memory
space. These requirements are checked at construction.
*/
template <class MemorySpace, class DataMembers, class EntityType,
template <class MemorySpace, class DataTypes, class EntityType,
std::size_t NumSpaceDim, unsigned long long cellBitsPerTileDim,
typename Value = int, typename Key = uint64_t>
class SparseHalo
Expand Down Expand Up @@ -71,7 +71,7 @@ class SparseHalo
};

//! data members in AoSoA structure
using aosoa_member_types = DataMembers;
using aosoa_member_types = DataTypes;
//! AoSoA tuple type
using tuple_type = Cabana::Tuple<aosoa_member_types>;

Expand Down Expand Up @@ -114,13 +114,11 @@ class SparseHalo
\brief constructor
\tparam LocalGridType local grid type
\param pattern The halo pattern to use for halo communication
\param local_grid_ptr pointer to sparse local grid
\param comm MPI communicator
\param sparse_array Sparse array to communicate
*/
template <class LocalGridType>
template <class SparseArrayType>
SparseHalo( halo_pattern_type pattern,
const std::shared_ptr<LocalGridType>& local_grid_ptr,
MPI_Comm comm )
const std::shared_ptr<SparseArrayType>& sparse_array )
: _pattern( pattern )
{
// Function to get the local id of the neighbor.
Expand Down Expand Up @@ -151,17 +149,20 @@ class SparseHalo
std::accumulate( soa_byte_array.begin(), soa_byte_array.end(), 0 ),
static_cast<int>( sizeof( tuple_type ) ) );

// Get the local grid the array uses.
auto local_grid = sparse_array->layout().localGrid();

// linear MPI rank ID of the current working rank
_self_rank =
local_grid_ptr->neighborRank( std::array<int, 3>( { 0, 0, 0 } ) );
local_grid->neighborRank( std::array<int, 3>( { 0, 0, 0 } ) );

// set the linear neighbor rank ID
// set up correspondence between sending and receiving buffers
auto neighbors = _pattern.getNeighbors();
for ( const auto& n : neighbors )
{
// neighbor rank linear ID
int rank = local_grid_ptr->neighborRank( n );
int rank = local_grid->neighborRank( n );

// if neighbor is valid
if ( rank >= 0 )
Expand All @@ -176,10 +177,10 @@ class SparseHalo
_receive_tags.push_back( neighbor_id( flip_id( n ) ) );

// build communication data for owned entries
buildCommData( Own(), local_grid_ptr, n, _owned_buffers,
buildCommData( Own(), local_grid, n, _owned_buffers,
_owned_tile_steering, _owned_tile_spaces );
// build communication data for ghosted entries
buildCommData( Ghost(), local_grid_ptr, n, _ghosted_buffers,
buildCommData( Ghost(), local_grid, n, _ghosted_buffers,
_ghosted_tile_steering, _ghosted_tile_spaces );

auto& own_index_space = _owned_tile_spaces.back();
Expand Down Expand Up @@ -215,25 +216,24 @@ class SparseHalo
\tparam DecompositionTag decomposition tag type
\tparam LocalGridType sparse local grid type
\param decomposition_tag tag to indicate if it's owned or ghosted halo
\param local_grid_ptr sparse local grid shared pointer
\param local_grid sparse local grid shared pointer
\param nid neighbor local id (ijk in pattern)
\param buffers buffer to be used to store communicated data
\param steering steering to be used to guide communications
\param spaces sparse tile index spaces
*/
template <class DecompositionTag, class LocalGridType>
void buildCommData( DecompositionTag decomposition_tag,
const std::shared_ptr<LocalGridType>& local_grid_ptr,
const std::shared_ptr<LocalGridType>& local_grid,
const std::array<int, num_space_dim>& nid,
std::vector<buffer_view>& buffers,
std::vector<steering_view>& steering,
std::vector<tile_index_space>& spaces )
{
// get the halo sparse tile index space sharsed with the neighbor
spaces.push_back(
local_grid_ptr
->template sharedTileIndexSpace<cell_bits_per_tile_dim>(
decomposition_tag, entity_type(), nid ) );
local_grid->template sharedTileIndexSpace<cell_bits_per_tile_dim>(
decomposition_tag, entity_type(), nid ) );
auto& index_space = spaces.back();

// allocate the buffer to store shared data with given neighbor
Expand All @@ -251,10 +251,10 @@ class SparseHalo
/*!
\brief update tile index space according to current partition
\tparam LocalGridType sparse local grid type
\param local_grid_ptr sparse local grid pointer
\param local_grid sparse local grid pointer
*/
template <class LocalGridType>
void updateTileSpace( const std::shared_ptr<LocalGridType>& local_grid_ptr )
void updateTileSpace( const std::shared_ptr<LocalGridType>& local_grid )
{
// clear index space array first
_owned_tile_spaces.clear();
Expand All @@ -267,19 +267,19 @@ class SparseHalo
// get neighbor relative id
auto& n = _valid_neighbor_ids[i];
// get neighbor linear MPI rank ID
int rank = local_grid_ptr->neighborRank( n );
int rank = local_grid->neighborRank( n );
// check if neighbor rank is valid
// the neighbor id should always be valid (as all should be
// well-prepared during construction/initialization)
if ( rank == _neighbor_ranks[i] )
{
// get shared tile index spcae from local grid
_owned_tile_spaces.push_back(
local_grid_ptr
local_grid
->template sharedTileIndexSpace<cell_bits_per_tile_dim>(
Own(), entity_type(), n ) );
_ghosted_tile_spaces.push_back(
local_grid_ptr
local_grid
->template sharedTileIndexSpace<cell_bits_per_tile_dim>(
Ghost(), entity_type(), n ) );

Expand Down Expand Up @@ -1283,6 +1283,34 @@ class SparseHalo
// SoA total bytes count
std::size_t _soa_total_bytes;
};

//---------------------------------------------------------------------------//
// Sparse halo creation.
//---------------------------------------------------------------------------//
/*!
\brief SparseHalo creation function.
\param pattern The pattern to build the sparse halo from.
\param array The sparse array over which to build the halo.
*/
template <class DeviceType, unsigned long long cellBitsPerTileDim,
class DataTypes, class EntityType, class MeshType,
class SparseMapType, class Pattern, typename Value = int,
typename Key = uint64_t>
auto createSparseHalo(
const Pattern& pattern,
const std::shared_ptr<
SparseArray<DataTypes, DeviceType, EntityType, MeshType, SparseMapType>>
array )
{
using array_type =
SparseArray<DataTypes, DeviceType, EntityType, MeshType, SparseMapType>;
using memory_space = typename array_type::memory_space;
static constexpr std::size_t num_space_dim = array_type::num_space_dim;
return std::make_shared<
SparseHalo<memory_space, DataTypes, EntityType, num_space_dim,
cellBitsPerTileDim, Value, Key>>( pattern, array );
}

}; // namespace Experimental
}; // end namespace Cajita

Expand Down
10 changes: 5 additions & 5 deletions cajita/unit_test/tstSparseHalo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,8 @@ void haloScatterAndGatherTest( ReduceOp reduce_op, EntityType entity )
auto sparse_array = createSparseArray<TEST_DEVICE>(
std::string( "test_sparse_grid" ), *sparse_layout );

SparseHalo<TEST_MEMSPACE, DataTypes, EntityType, 3, cell_bits_per_tile_dim>
halo( NodeHaloPattern<3>(), local_grid, MPI_COMM_WORLD );
auto halo = createSparseHalo<TEST_DEVICE, cell_bits_per_tile_dim>(
NodeHaloPattern<3>(), sparse_array );

// sample valid halos on rank 0 and broadcast to other ranks
// Kokkos::View<T* [3], TEST_MEMSPACE> tile_view;
Expand Down Expand Up @@ -517,7 +517,7 @@ void haloScatterAndGatherTest( ReduceOp reduce_op, EntityType entity )
} );

sparse_array->resize( sparse_map.sizeCell() );
halo.template register_halo<TEST_EXECSPACE>( sparse_map );
halo->template register_halo<TEST_EXECSPACE>( sparse_map );
MPI_Barrier( MPI_COMM_WORLD );
}

Expand Down Expand Up @@ -605,10 +605,10 @@ void haloScatterAndGatherTest( ReduceOp reduce_op, EntityType entity )
// halo scatter and gather
/// false means the heighbors' halo counting information is not
/// collected
halo.scatter( TEST_EXECSPACE(), reduce_op, *sparse_array, false );
halo->scatter( TEST_EXECSPACE(), reduce_op, *sparse_array, false );
/// halo counting info already collected in the previous scatter, thus true
/// and no need to recount again
halo.gather( TEST_EXECSPACE(), *sparse_array, true );
halo->gather( TEST_EXECSPACE(), *sparse_array, true );
MPI_Barrier( MPI_COMM_WORLD );

// check results
Expand Down

0 comments on commit 9801705

Please sign in to comment.