Skip to content

Commit

Permalink
Refactor to use a struct to hold factor index and axis refs.
Browse files Browse the repository at this point in the history
This simplifies the logic a bit, and aligns better with the algorithm which picks the factor-axes pair with the largest count at each iteration. Also prepares a different implementation that keeps a mapvector of factor-axes pairs and treverses the vector in the non-increasing order of their counts.

PiperOrigin-RevId: 696112797
  • Loading branch information
Google-ML-Automation authored and copybara-github committed Nov 13, 2024
1 parent f7a0c94 commit c7c3f39
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,16 @@ bool axisRefsOverlap(ArrayRef<AxisRefAttr> first,
return false;
}

struct FactorAxesPair {
int64_t factorIndex = -1;
ArrayRef<AxisRefAttr> axisRefs;

FactorAxesPair(int64_t factorIndex, ArrayRef<AxisRefAttr> axisRefs)
: factorIndex(factorIndex), axisRefs(axisRefs) {}

FactorAxesPair() = default;
};

// Broadly the algorithm is, at each iteration, to pick a {factor,axis} pair
// with the largest count from a list that is initialized with all the
// pairs with non-zero count, assign the picked axis to the picked factor, and
Expand All @@ -180,8 +190,7 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
SmallVector<DenseMap<ArrayRef<AxisRefAttr>, int64_t>> factorAxesCounts(
numFactors);
int64_t maxCount = 0;
int64_t bestFactorIndex;
ArrayRef<AxisRefAttr> bestAxisRefs;
FactorAxesPair bestFactorAxes;
for (const TensorFactorShardings& tensorFactorSharding :
llvm::concat<const TensorFactorShardings>(projection.getOperands(),
projection.getResults())) {
Expand All @@ -194,8 +203,7 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
int64_t axesCount = ++factorAxesCounts[factorIndex][axisRefs];
if (axesCount > maxCount) {
maxCount = axesCount;
bestFactorIndex = factorIndex;
bestAxisRefs = axisRefs;
bestFactorAxes = FactorAxesPair(factorIndex, axisRefs);
}
}
}
Expand All @@ -209,22 +217,22 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
BitVector unseenFactors(numFactors, true);
// TODO(enver): Optimize to mark unseen only the factors with an axis.
while (maxCount > 0) {
factorAxisRefs[bestFactorIndex] = llvm::to_vector(bestAxisRefs);
unseenFactors.reset(bestFactorIndex);
factorAxisRefs[bestFactorAxes.factorIndex] =
llvm::to_vector(bestFactorAxes.axisRefs);
unseenFactors.reset(bestFactorAxes.factorIndex);
// TODO(enver): Tie-breaking currently depends on the order of iteration.
// Consider some heuristic for breaking ties.
// Invalidate axes that overlaps with the picked one across all unseen
// factors. During the iteration, also find the new best.
maxCount = 0;
int64_t nextBestFactorIndex;
ArrayRef<AxisRefAttr> nextBestAxisRefs;
FactorAxesPair nextBestFactorAxes;
for (int factorIndex : unseenFactors.set_bits()) {
auto& axesCounts = factorAxesCounts[factorIndex];
for (const auto& [axisRefs, count] : axesCounts) {
// TODO(enver): Relax the overlap check. We need to erase in case of an
// overlap only if the factor indices appear together in any of the
// operands or results.
if (axisRefsOverlap(bestAxisRefs, axisRefs)) {
if (axisRefsOverlap(bestFactorAxes.axisRefs, axisRefs)) {
// TODO(enver): Optimize to flip unseen if all the axes of the factor
// have zero count.
// Clear the count of overlapping axis, effectively erasing.
Expand All @@ -235,13 +243,11 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
}
if (count > maxCount) {
maxCount = count;
nextBestFactorIndex = factorIndex;
nextBestAxisRefs = axisRefs;
nextBestFactorAxes = FactorAxesPair(factorIndex, axisRefs);
}
}
}
bestFactorIndex = nextBestFactorIndex;
bestAxisRefs = nextBestAxisRefs;
bestFactorAxes = nextBestFactorAxes;
}
return factorAxisRefs;
}
Expand Down

0 comments on commit c7c3f39

Please sign in to comment.