bump version to 0.0.15.post1 & refactor cutlass code & build for torch 2.1.2
chengzeyi committed Dec 18, 2023
1 parent 1f46389 commit 6885350
Showing 6 changed files with 46 additions and 79 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/wheels.yml
@@ -31,8 +31,8 @@ jobs:
- "3.10"
- "3.11"
torch_version:
- "2.1.0"
- "2.1.1"
- "2.1.2"
cuda_short_version:
- "118"
- "121"
2 changes: 1 addition & 1 deletion setup.py
@@ -131,7 +131,7 @@ def get_extensions():
if cuda_version >= 1102:
extra_compile_args["nvcc"] += [
"--threads",
"4",
"2",
"--ptxas-options=-v",
]
if platform.system() == "Windows":
2 changes: 1 addition & 1 deletion src/sfast/__init__.py
@@ -32,4 +32,4 @@ def new_lru_cache(*args, **kwargs):

# This line will be programatically read/write by setup.py.
# Leave them at the bottom of this file and don't touch them.
__version__ = "0.0.15"
__version__ = "0.0.15.post1"
17 changes: 9 additions & 8 deletions src/sfast/csrc/operators/cutlass/cutlass_dual_linear_kernel.cu
@@ -214,14 +214,6 @@ torch::Tensor cutlass_dual_gemm(
ElementComputeEpilogue(
bias0.has_value() ? 1.0 : 0.0)}, // <- tuple of alpha and beta
epilogue2_params};
// Allocate workspace memory
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace =
torch::empty({static_cast<int64_t>(workspace_size)},
torch::dtype(torch::kUInt8).device(input.device()));

torch::DeviceGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

cutlass::Status status;
Gemm gemm_op;
@@ -232,11 +224,20 @@
"This problem size is not supported by this Gemm implementation: ",
cutlass::cutlassGetStatusString(status));

// Allocate workspace memory
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace =
torch::empty({static_cast<int64_t>(workspace_size)},
torch::dtype(torch::kUInt8).device(input.device()));

status = gemm_op.initialize(arguments, workspace.data_ptr<uint8_t>());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize cutlass gemm: ",
cutlass::cutlassGetStatusString(status));

torch::DeviceGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

status = gemm_op(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to execute cutlass gemm: ",
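Note: the two hunks above only reorder the host-side setup — workspace allocation and the device-guard/stream setup now happen after can_implement() has accepted the problem size, so rejected shapes no longer allocate a workspace tensor. A condensed sketch of the resulting call order, mirroring the calls visible in this hunk (the template parameter and the argument names are illustrative, not the file's actual signature):

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/cutlass.h>

// Sketch only: `Gemm` stands for the concrete cutlass::gemm::device::GemmUniversal
// typedef used in this file, and `arguments` for the Arguments struct built above.
template <typename Gemm>
void run_gemm_in_new_order(const typename Gemm::Arguments &arguments,
                           const torch::Tensor &input) {
  Gemm gemm_op;

  // 1) Reject unsupported problem sizes before doing any allocation.
  cutlass::Status status = gemm_op.can_implement(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "This problem size is not supported by this Gemm implementation: ",
              cutlass::cutlassGetStatusString(status));

  // 2) Allocate the workspace only once the problem is known to be supported.
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  auto workspace =
      torch::empty({static_cast<int64_t>(workspace_size)},
                   torch::dtype(torch::kUInt8).device(input.device()));

  status = gemm_op.initialize(arguments, workspace.data_ptr<uint8_t>());
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "Failed to initialize cutlass gemm: ",
              cutlass::cutlassGetStatusString(status));

  // 3) Pin the device and pick up the current CUDA stream right before launch.
  torch::DeviceGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  status = gemm_op(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "Failed to execute cutlass gemm: ",
              cutlass::cutlassGetStatusString(status));
}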
100 changes: 33 additions & 67 deletions src/sfast/csrc/operators/cutlass/cutlass_qlinear_dynamic_kernel.cu
@@ -28,48 +28,36 @@ namespace sm80_space {
using SmArch = cutlass::arch::Sm80;
constexpr int NumStages = 4;

template <typename scalar_t, typename acc_t> struct GemmWrapper {
template <typename scalar_t, typename acc_t> struct GemmConfig {
using ElementA = scalar_t;
using ElementB = int8_t;
using ElementOutput = scalar_t;
using ElementAccumulator = acc_t;
using ElementComputeEpilogue = acc_t;

using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA, LayoutInputA, ElementB, LayoutInputB, ElementOutput,
LayoutOutput, ElementAccumulator, MMAOp, SmArch,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, NumStages,
128 / cutlass::sizeof_bits<ElementA>::value,
128 / cutlass::sizeof_bits<ElementB>::value,
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone>;
using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
};
} // namespace sm80_space

using GemmNoBias = cutlass::gemm::device::GemmUniversal<
ElementA, LayoutInputA, ElementB, LayoutInputB, ElementOutput,
LayoutOutput, ElementAccumulator, MMAOp, SmArch,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, NumStages,
128 / cutlass::sizeof_bits<ElementA>::value,
128 / cutlass::sizeof_bits<ElementB>::value,
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone>;
using namespace sm80_space;

template <typename config> struct GemmWrapper {
using ElementA = typename config::ElementA;
using ElementB = typename config::ElementB;
using ElementOutput = typename config::ElementOutput;
using ElementAccumulator = typename config::ElementAccumulator;
using ElementComputeEpilogue = typename config::ElementComputeEpilogue;

using GemmSmall = cutlass::gemm::device::GemmUniversal<
using ThreadBlockShape = typename config::ThreadBlockShape;
using WarpShape = typename config::WarpShape;
using InstructionShape = typename config::InstructionShape;

using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA, LayoutInputA, ElementB, LayoutInputB, ElementOutput,
LayoutOutput, ElementAccumulator, MMAOp, SmArch,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
ThreadBlockShape, WarpShape, InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue,
@@ -80,11 +68,10 @@ template <typename scalar_t, typename acc_t> struct GemmWrapper {
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone>;

using GemmNoBiasSmall = cutlass::gemm::device::GemmUniversal<
using GemmNoBias = cutlass::gemm::device::GemmUniversal<
ElementA, LayoutInputA, ElementB, LayoutInputB, ElementOutput,
LayoutOutput, ElementAccumulator, MMAOp, SmArch,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
ThreadBlockShape, WarpShape, InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue,
@@ -95,9 +82,6 @@ template <typename scalar_t, typename acc_t> struct GemmWrapper {
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone>;
};
} // namespace sm80_space

using namespace sm80_space;

void get_input_layout(const torch::Tensor &input, const torch::Tensor &weight,
int &B, int &M, int &K, int &N,
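The hunks above replace the old per-dtype GemmWrapper (which hard-coded regular and "small" GemmUniversal variants) with a GemmConfig holding the element types and tile shapes, plus a single GemmWrapper templated on that config. A minimal sketch of how the two templates compose, assuming the definitions above; the concrete types are illustrative (half precision with float accumulation is an assumption here):

// Sketch only: GemmConfig supplies element types and threadblock/warp/instruction
// tile shapes; GemmWrapper turns one config into the biased and bias-free kernels.
using Config  = GemmConfig<cutlass::half_t, float>;   // scalar_t, acc_t
using Wrapper = GemmWrapper<Config>;

using HalfGemm       = Wrapper::Gemm;        // used when a bias tensor is present
using HalfGemmNoBias = Wrapper::GemmNoBias;  // used on the bias-free path

// A different tile configuration now only needs another config type, e.g. a
// hypothetical config shadowing the shape aliases (shapes taken from the
// removed "small" kernels above):
//
// struct SmallConfig : GemmConfig<cutlass::half_t, float> {
//   using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 32>;
//   using WarpShape        = cutlass::gemm::GemmShape<32, 32, 32>;
//   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
// };
// using GemmSmall = GemmWrapper<SmallConfig>::Gemm;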
@@ -185,14 +169,6 @@ cutlass_gemm(const torch::Tensor &input, const torch::Tensor &weight,
weight_ref.stride(0),
bias_ref.stride(0),
output_ref.stride(0)};
// Allocate workspace memory
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace =
torch::empty({static_cast<int64_t>(workspace_size)},
torch::dtype(torch::kUInt8).device(input.device()));

torch::DeviceGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

cutlass::Status status;
Gemm gemm_op;
@@ -203,11 +179,20 @@
"This problem size is not supported by this Gemm implementation: ",
cutlass::cutlassGetStatusString(status));

// Allocate workspace memory
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace =
torch::empty({static_cast<int64_t>(workspace_size)},
torch::dtype(torch::kUInt8).device(input.device()));

status = gemm_op.initialize(arguments, workspace.data_ptr<uint8_t>());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize cutlass gemm: ",
cutlass::cutlassGetStatusString(status));

torch::DeviceGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

status = gemm_op(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to execute cutlass gemm: ",
@@ -235,54 +220,35 @@ template <> struct acc_type<cutlass::bfloat16_t> { using type = float; };
template <typename at_type> struct CutlassGemmLauncher {
using scalar_t = typename cutlass_type<at_type>::type;
using acc_t = typename acc_type<scalar_t>::type;
using GemmWrapper_ = GemmWrapper<scalar_t, acc_t>;
using GemmWrapper_ = GemmWrapper<GemmConfig<scalar_t, acc_t>>;
using Gemm = typename GemmWrapper_::Gemm;
using GemmNoBias = typename GemmWrapper_::GemmNoBias;
using GemmSmall = typename GemmWrapper_::GemmSmall;
using GemmNoBiasSmall = typename GemmWrapper_::GemmNoBiasSmall;

static torch::Tensor launch(const torch::Tensor &input,
const torch::Tensor &weight,
const c10::optional<torch::Tensor> &bias,
float dq_scale) {
auto N = weight.size(0);
auto K = weight.size(1);
auto M = input.numel() / K;

bool use_small_kernel = M <= Gemm::ThreadblockShape::kM ||
N <= Gemm::ThreadblockShape::kN ||
K <= Gemm::ThreadblockShape::kK;
// auto M = input.numel() / K;

if (K % Gemm::kAlignmentA != 0 || K % Gemm::kAlignmentB != 0 ||
N % Gemm::kAlignmentC != 0) {
if (K % GemmSmall::kAlignmentA != 0 || K % GemmSmall::kAlignmentB != 0 ||
N % GemmSmall::kAlignmentC != 0) {
auto weight_ = input.scalar_type() == at::kFloat
? weight.dequantize()
: weight.int_repr()
.to(input.scalar_type())
.mul_(weight.q_scale());
return cublas_lowp_linear(input, weight_, bias);
} else {
use_small_kernel = true;
}
}
auto input_ = input.contiguous();
auto weight_ = weight.contiguous();
if (bias.has_value()) {
c10::optional<torch::Tensor> bias_;
bias_.emplace(bias.value().contiguous());
if (use_small_kernel) {
return cutlass_gemm<GemmSmall>(input_, weight_, bias_, dq_scale);
} else {
return cutlass_gemm<Gemm>(input_, weight_, bias_, dq_scale);
}
} else {
if (use_small_kernel) {
return cutlass_gemm<GemmNoBiasSmall>(input_, weight_, bias, dq_scale);
} else {
return cutlass_gemm<GemmNoBias>(input_, weight_, bias, dq_scale);
}
}
}
};
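With only one tile configuration left, the CutlassGemmLauncher hunk above reduces launch() to an alignment check plus a bias/no-bias split. Reconstructed from the added and surviving lines above (parts of the hunk are collapsed, so treat this as a sketch rather than the verbatim result):

// Sketch of the post-refactor launch path. Shapes whose K/N don't meet the
// kernel's alignment requirements fall back to a dequantized cuBLAS linear.
static torch::Tensor launch(const torch::Tensor &input,
                            const torch::Tensor &weight,
                            const c10::optional<torch::Tensor> &bias,
                            float dq_scale) {
  auto N = weight.size(0);
  auto K = weight.size(1);
  // auto M = input.numel() / K;

  if (K % Gemm::kAlignmentA != 0 || K % Gemm::kAlignmentB != 0 ||
      N % Gemm::kAlignmentC != 0) {
    // Dequantize the weight and defer to the cuBLAS low-precision path.
    auto weight_ = input.scalar_type() == at::kFloat
                       ? weight.dequantize()
                       : weight.int_repr()
                             .to(input.scalar_type())
                             .mul_(weight.q_scale());
    return cublas_lowp_linear(input, weight_, bias);
  }

  auto input_ = input.contiguous();
  auto weight_ = weight.contiguous();
  if (bias.has_value()) {
    c10::optional<torch::Tensor> bias_;
    bias_.emplace(bias.value().contiguous());
    return cutlass_gemm<Gemm>(input_, weight_, bias_, dq_scale);
  }
  return cutlass_gemm<GemmNoBias>(input_, weight_, bias, dq_scale);
}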
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
0.0.15
0.0.15.post1
