diff --git a/test/conftest.py b/test/conftest.py
index d8bb234c..5e9b4367 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -21,6 +21,8 @@
     devices += ["cuda"]
 elif torch.backends.mps.is_available():
     devices += ["mps"]
+elif torch.xpu.is_available():
+    devices += ["xpu"]
 
 
 @pytest.fixture(scope="module", params=devices)
diff --git a/test/helpers.py b/test/helpers.py
index 1693d6a9..3e6635f3 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -103,6 +103,9 @@ def get_device_memory(device):
     elif device.type == "mps":
         torch.mps.empty_cache()
         return torch.mps.current_allocated_memory()
+    elif device.type == "xpu":
+        torch.xpu.empty_cache()
+        return torch.xpu.memory_allocated()
     return None
 
 
diff --git a/test/tensor/weights/test_weight_qbits_tensor_dispatch.py b/test/tensor/weights/test_weight_qbits_tensor_dispatch.py
index a0050d24..0c5b0d2a 100644
--- a/test/tensor/weights/test_weight_qbits_tensor_dispatch.py
+++ b/test/tensor/weights/test_weight_qbits_tensor_dispatch.py
@@ -73,15 +73,20 @@ def test_weight_qbits_tensor_linear(dtype, batch_size, tokens, in_features, out_
     check_weight_qtensor_linear(qbt, batch_size, tokens, use_bias)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is too slow on non-CUDA devices")
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize("tokens", [16, 32, 48, 64])
 @pytest.mark.parametrize("in_features", [1024, 4096, 16384])
 @pytest.mark.parametrize("out_features", [1024, 2048, 4096])
 @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
-def test_weight_qbits_tensor_linear_cuda(dtype, batch_size, tokens, in_features, out_features, use_bias):
-    device = torch.device("cuda")
+def test_weight_qbits_tensor_linear_gpu(dtype, batch_size, tokens, in_features, out_features, use_bias):
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
+    else:
+        pytest.skip(reason="Test is too slow on non-GPU devices")
+
     weight_qtype = qint4
     group_size = 128
     # Create a QBitsTensor
diff --git a/test/tensor/weights/weight_helpers.py b/test/tensor/weights/weight_helpers.py
index 761cbea3..762836e7 100644
--- a/test/tensor/weights/weight_helpers.py
+++ b/test/tensor/weights/weight_helpers.py
@@ -31,7 +31,7 @@ def check_weight_qtensor_linear(qweight, batch_size, tokens, use_bias, rel_max_e
     max_err = (out - qout).abs().max()
     rel_max_err = max_err / mean_val
     # These values were evaluated empirically without any optimized kernels.
-    rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2}[device.type]
+    rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2, "xpu": 2e-2}[device.type]
     assert (
         rel_max_err < rtol
     ), f"Maximum error {max_err:.2f} is too high for input of mean value {mean_val:.2f} ({rel_max_err*100:.2f} %)"
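
Note: the hunks above all extend the same device-selection and memory-probe pattern (CUDA, then MPS, then XPU). The following standalone sketch, which is not part of the patch, shows that pattern in isolation; it assumes a PyTorch build recent enough to ship the torch.xpu module, and the pick_device helper name is illustrative only.

import torch


def pick_device() -> torch.device:
    # Same precedence as test/conftest.py: CUDA, then MPS, then XPU, else CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


def get_device_memory(device: torch.device):
    # Mirrors the helper touched in test/helpers.py: flush the caching
    # allocator, then report the bytes currently allocated on the device
    # (None when the device has no memory-tracking API, e.g. CPU).
    if device.type == "cuda":
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated()
    elif device.type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()
    elif device.type == "xpu":
        torch.xpu.empty_cache()
        return torch.xpu.memory_allocated()
    return None


if __name__ == "__main__":
    device = pick_device()
    print(device, get_device_memory(device))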