From 450870266df82a0c2ecf0d86c9fa738172763a5c Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Thu, 31 Oct 2024 23:55:05 -0700
Subject: [PATCH] enable on xpu

---
 .../weights/test_weight_qbits_tensor_dispatch.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/test/tensor/weights/test_weight_qbits_tensor_dispatch.py b/test/tensor/weights/test_weight_qbits_tensor_dispatch.py
index a0050d24..0c5b0d2a 100644
--- a/test/tensor/weights/test_weight_qbits_tensor_dispatch.py
+++ b/test/tensor/weights/test_weight_qbits_tensor_dispatch.py
@@ -73,15 +73,20 @@ def test_weight_qbits_tensor_linear(dtype, batch_size, tokens, in_features, out_
     check_weight_qtensor_linear(qbt, batch_size, tokens, use_bias)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is too slow on non-CUDA devices")
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize("tokens", [16, 32, 48, 64])
 @pytest.mark.parametrize("in_features", [1024, 4096, 16384])
 @pytest.mark.parametrize("out_features", [1024, 2048, 4096])
 @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
-def test_weight_qbits_tensor_linear_cuda(dtype, batch_size, tokens, in_features, out_features, use_bias):
-    device = torch.device("cuda")
+def test_weight_qbits_tensor_linear_gpu(dtype, batch_size, tokens, in_features, out_features, use_bias):
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
+    else:
+        pytest.skip(reason="Test is too slow on non-GPU devices")
+
     weight_qtype = qint4
     group_size = 128
     # Create a QBitsTensor
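
Reviewer note: if more tests adopt this CUDA-then-XPU fallback, the device selection could be lifted into a shared pytest fixture instead of being repeated in each test body. A minimal sketch, assuming a hypothetical gpu_device fixture in the suite's conftest.py (the fixture name, its location, and the hasattr guard are assumptions, not part of this patch):

import pytest
import torch


@pytest.fixture
def gpu_device():
    # Prefer CUDA, then fall back to Intel XPU; otherwise skip,
    # matching the patch's "too slow on non-GPU devices" rationale.
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.xpu is only present on PyTorch builds with XPU support,
    # so guard the attribute access before calling is_available().
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    pytest.skip(reason="Test is too slow on non-GPU devices")

A test would then take gpu_device as a parameter rather than branching on device availability itself, and the skip would be applied uniformly across all GPU-only tests.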