From ff991e61dca9132281877661d27dca77904d6866 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Thu, 24 Oct 2024 17:41:48 +0200
Subject: [PATCH] fix: use reshape instead of view

Most of the time we want the resulting Tensor to be contiguous even when the
source Tensor is not, so it is better to use reshape, which (unlike view)
handles that case by falling back to a copy.
---
 optimum/quanto/library/extensions/cuda/__init__.py | 4 ++--
 optimum/quanto/library/qbytes_mm.py                | 8 ++++----
 optimum/quanto/nn/qlayernorm.py                    | 3 ++-
 optimum/quanto/nn/qmodule.py                       | 5 +++--
 optimum/quanto/tensor/weights/awq/packed.py        | 4 ++--
 optimum/quanto/tensor/weights/marlin/fp8/qbits.py  | 2 +-
 optimum/quanto/tensor/weights/packing.py           | 2 +-
 optimum/quanto/tensor/weights/qbytes.py            | 4 ++--
 optimum/quanto/tensor/weights/tinygemm/qbits.py    | 4 ++--
 9 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/optimum/quanto/library/extensions/cuda/__init__.py b/optimum/quanto/library/extensions/cuda/__init__.py
index bd164060..7ba29365 100644
--- a/optimum/quanto/library/extensions/cuda/__init__.py
+++ b/optimum/quanto/library/extensions/cuda/__init__.py
@@ -188,9 +188,9 @@ def gemm_f16i4_marlin(
         device=input.device,
     )
     ext.lib.marlin_gemm_f16i4(
-        input.view((-1, input.shape[-1])),
+        input.reshape((-1, input.shape[-1])),
         other,
-        output.view((-1, output.shape[-1])),
+        output.reshape((-1, output.shape[-1])),
         scales,
         zeropoint,
         workspace,
diff --git a/optimum/quanto/library/qbytes_mm.py b/optimum/quanto/library/qbytes_mm.py
index 0950cb65..5127b588 100644
--- a/optimum/quanto/library/qbytes_mm.py
+++ b/optimum/quanto/library/qbytes_mm.py
@@ -42,8 +42,8 @@ def qbytes_int_mm(activations: torch.Tensor, weights: torch.Tensor, output_scale
         out_data = torch._int_mm(activations, weights)
     else:
         output_shape = activations.shape[:-1] + (out_features,)
-        out_data = torch._int_mm(activations.view(-1, in_features), weights)
-        out_data = out_data.view(output_shape)
+        out_data = torch._int_mm(activations.reshape(-1, in_features), weights)
+        out_data = out_data.reshape(output_shape)
     # We must evaluate the output as float32 because the multiplication
     # of the int32 data by the scales might overflow
     fp32_output = out_data.to(torch.float32) * output_scales.t()
@@ -59,8 +59,8 @@ def qbytes_int8pack_mm(activations: torch.Tensor, weights: torch.Tensor, output_
     in_features = activations.shape[-1]
     out_features = weights.shape[0]
     output_shape = activations.shape[:-1] + (out_features,)
-    out_data = torch._weight_int8pack_mm(activations.view(-1, in_features), weights, output_scales)
-    return out_data.view(output_shape)
+    out_data = torch._weight_int8pack_mm(activations.reshape(-1, in_features), weights, output_scales)
+    return out_data.reshape(output_shape)
 
 
 @torch.library.impl("quanto::qbytes_mm", "default")
diff --git a/optimum/quanto/nn/qlayernorm.py b/optimum/quanto/nn/qlayernorm.py
index 95d7e911..c15db6ce 100644
--- a/optimum/quanto/nn/qlayernorm.py
+++ b/optimum/quanto/nn/qlayernorm.py
@@ -36,12 +36,13 @@ def qcreate(
     ):
         if activations is None:
             return None
+        dtype = None if module.weight is None else module.weight.dtype
         return cls(
             module.normalized_shape,
             module.eps,
             module.elementwise_affine,
             module.bias is not None,
-            dtype=module.weight.dtype,
+            dtype=dtype,
             device=device,
             weights=None,  # We never quantize QLayerNorm weights
             activations=activations,
diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py
index f80e5df5..152f9d17 100644
--- a/optimum/quanto/nn/qmodule.py
+++ b/optimum/quanto/nn/qmodule.py
@@ -136,8 +136,9 @@ def __init__(
         if optimizer is None and self.weight_qtype is not None:
             optimizer = AbsmaxOptimizer() if self.weight_qtype.bits == 8 else MaxOptimizer()
         self.optimizer = optimizer
-        self.register_buffer("input_scale", torch.ones((), dtype=self.weight.dtype, device=device))
-        self.register_buffer("output_scale", torch.ones((), dtype=self.weight.dtype, device=device))
+        scale_dtype = torch.float32 if self.weight is None else self.weight.dtype
+        self.register_buffer("input_scale", torch.ones((), dtype=scale_dtype, device=device))
+        self.register_buffer("output_scale", torch.ones((), dtype=scale_dtype, device=device))
 
     def disable_output_quantization(self):
         if "output" in self._quantize_hooks:
diff --git a/optimum/quanto/tensor/weights/awq/packed.py b/optimum/quanto/tensor/weights/awq/packed.py
index 8232d5c1..bc1af636 100644
--- a/optimum/quanto/tensor/weights/awq/packed.py
+++ b/optimum/quanto/tensor/weights/awq/packed.py
@@ -68,9 +68,9 @@ def reverse_awq_order(t: torch.Tensor):
         dtype=torch.int32,
         device=t.device,
     )
-    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
+    reverse_order_tensor = reverse_order_tensor.reshape(-1, 32 // bits)
     reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
-    reverse_order_tensor = reverse_order_tensor.view(-1)
+    reverse_order_tensor = reverse_order_tensor.reshape(-1)
 
     t = t[:, reverse_order_tensor]
 
diff --git a/optimum/quanto/tensor/weights/marlin/fp8/qbits.py b/optimum/quanto/tensor/weights/marlin/fp8/qbits.py
index 0cf24243..3eb3a653 100644
--- a/optimum/quanto/tensor/weights/marlin/fp8/qbits.py
+++ b/optimum/quanto/tensor/weights/marlin/fp8/qbits.py
@@ -32,7 +32,7 @@ def forward(ctx, input, other, bias=None):
 
         input_shape = input.shape
         if input.ndim > 2:
-            input = input.view(-1, input_shape[-1])
+            input = input.reshape(-1, input_shape[-1])
 
         output = torch.ops.quanto.gemm_f16f8_marlin(
             input,
diff --git a/optimum/quanto/tensor/weights/packing.py b/optimum/quanto/tensor/weights/packing.py
index fc59015a..15be7e06 100644
--- a/optimum/quanto/tensor/weights/packing.py
+++ b/optimum/quanto/tensor/weights/packing.py
@@ -34,7 +34,7 @@ def unpack_int32_to_uint8(packed: torch.Tensor, bits: int):
     unpacked = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]).to(
         torch.int8  # smallest dtype available
     )
-    unpacked = unpacked.view(unpacked.shape[0], -1)
+    unpacked = unpacked.reshape(unpacked.shape[0], -1)
 
     # Convert to unsigned
     unpacked = torch.bitwise_and(unpacked, (2**bits) - 1)
diff --git a/optimum/quanto/tensor/weights/qbytes.py b/optimum/quanto/tensor/weights/qbytes.py
index 68d0f65b..d2dd50a7 100644
--- a/optimum/quanto/tensor/weights/qbytes.py
+++ b/optimum/quanto/tensor/weights/qbytes.py
@@ -74,8 +74,8 @@ def forward(ctx, input, other, bias=None):
         in_features = input.shape[-1]
         out_features = other.shape[0]
         output_shape = input.shape[:-1] + (out_features,)
-        output = torch.ops.quanto.qbytes_mm(input.view(-1, in_features), other._data, other._scale)
-        output = output.view(output_shape)
+        output = torch.ops.quanto.qbytes_mm(input.reshape(-1, in_features), other._data, other._scale)
+        output = output.reshape(output_shape)
         if bias is not None:
             output = output + bias
         return output
diff --git a/optimum/quanto/tensor/weights/tinygemm/qbits.py b/optimum/quanto/tensor/weights/tinygemm/qbits.py
index 7b11cc8c..4bedec99 100644
--- a/optimum/quanto/tensor/weights/tinygemm/qbits.py
+++ b/optimum/quanto/tensor/weights/tinygemm/qbits.py
@@ -49,9 +49,9 @@ def forward(ctx, input, other, bias):
         out_features = other.shape[0]
         output_shape = input.shape[:-1] + (out_features,)
         output = torch._weight_int4pack_mm(
-            input.view(-1, in_features), other._data._data, other._group_size, other._scale_shift
+            input.reshape(-1, in_features), other._data._data, other._group_size, other._scale_shift
         )
-        output = output.view(output_shape)
+        output = output.reshape(output_shape)
         if bias is not None:
             output = output + bias
         return output
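
For context (not part of the patch above): a minimal sketch of the failure mode that motivates
replacing view with reshape. The shapes and variable names below are made up for illustration;
the pattern mirrors the activation flattening done in the matmul helpers touched by this patch.

    import torch

    batch, seq, in_features = 2, 3, 4
    # A transpose makes the tensor non-contiguous without moving any data.
    x = torch.randn(batch, in_features, seq).transpose(1, 2)  # shape (2, 3, 4)
    assert not x.is_contiguous()

    try:
        x.view(-1, in_features)  # what the old code did
    except RuntimeError as e:
        print(f"view failed: {e}")

    # reshape succeeds by falling back to a copy when a strided view is impossible.
    flat = x.reshape(-1, in_features)
    print(flat.shape, flat.is_contiguous())  # torch.Size([6, 4]) True

When the source Tensor is already contiguous, reshape behaves exactly like view and returns a
view with no copy, so the change is a no-op on the fast path.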