fix: use reshape instead of view
Most of the time, we want the resulting Tensor to be contiguous, so it
is better to use reshape to enforce it.
dacorvo committed Oct 24, 2024
1 parent 852bb9c commit ff991e6
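
For illustration (not part of the commit): view only reinterprets strides and raises a RuntimeError when the source tensor's memory layout is incompatible with the requested shape, while reshape falls back to a copy, so it always succeeds. A minimal repro:

import torch

x = torch.arange(6).reshape(2, 3).t()  # the transpose makes the tensor non-contiguous
# x.view(-1)  # would raise RuntimeError: view size is not compatible with input tensor's size and stride
y = x.reshape(-1)  # reshape copies when a view is impossible, so this always works
assert y.is_contiguous()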
Showing 9 changed files with 19 additions and 17 deletions.
4 changes: 2 additions & 2 deletions optimum/quanto/library/extensions/cuda/__init__.py
@@ -188,9 +188,9 @@ def gemm_f16i4_marlin(
         device=input.device,
     )
     ext.lib.marlin_gemm_f16i4(
-        input.view((-1, input.shape[-1])),
+        input.reshape((-1, input.shape[-1])),
         other,
-        output.view((-1, output.shape[-1])),
+        output.reshape((-1, output.shape[-1])),
         scales,
         zeropoint,
         workspace,
8 changes: 4 additions & 4 deletions optimum/quanto/library/qbytes_mm.py
@@ -42,8 +42,8 @@ def qbytes_int_mm(activations: torch.Tensor, weights: torch.Tensor, output_scale
         out_data = torch._int_mm(activations, weights)
     else:
         output_shape = activations.shape[:-1] + (out_features,)
-        out_data = torch._int_mm(activations.view(-1, in_features), weights)
-        out_data = out_data.view(output_shape)
+        out_data = torch._int_mm(activations.reshape(-1, in_features), weights)
+        out_data = out_data.reshape(output_shape)
     # We must evaluate the output as float32 because the multiplication
     # of the int32 data by the scales might overflow
     fp32_output = out_data.to(torch.float32) * output_scales.t()
@@ -59,8 +59,8 @@ def qbytes_int8pack_mm(activations: torch.Tensor, weights: torch.Tensor, output_
     in_features = activations.shape[-1]
     out_features = weights.shape[0]
     output_shape = activations.shape[:-1] + (out_features,)
-    out_data = torch._weight_int8pack_mm(activations.view(-1, in_features), weights, output_scales)
-    return out_data.view(output_shape)
+    out_data = torch._weight_int8pack_mm(activations.reshape(-1, in_features), weights, output_scales)
+    return out_data.reshape(output_shape)


 @torch.library.impl("quanto::qbytes_mm", "default")
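As an aside on the float32 comment in the qbytes_int_mm hunk above: torch._int_mm accumulates in int32, and those accumulators can exceed the float16 range (max 65504), so casting to float32 before applying the scales avoids infinities. A small illustration (not from the repo):

import torch

acc = torch.tensor([[1 << 20]], dtype=torch.int32)  # a plausible int32 accumulator value
scale = torch.tensor([[2.0]], dtype=torch.float16)
bad = acc.to(torch.float16) * scale                      # 2**20 overflows float16 -> inf
good = acc.to(torch.float32) * scale.to(torch.float32)   # stays finite
print(bad.item(), good.item())  # inf 2097152.0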
3 changes: 2 additions & 1 deletion optimum/quanto/nn/qlayernorm.py
@@ -36,12 +36,13 @@ def qcreate(
     ):
         if activations is None:
             return None
+        dtype = None if module.weight is None else module.weight.dtype
         return cls(
             module.normalized_shape,
             module.eps,
             module.elementwise_affine,
             module.bias is not None,
-            dtype=module.weight.dtype,
+            dtype=dtype,
             device=device,
             weights=None,  # We never quantize QLayerNorm weights
             activations=activations,
5 changes: 3 additions & 2 deletions optimum/quanto/nn/qmodule.py
@@ -136,8 +136,9 @@ def __init__(
         if optimizer is None and self.weight_qtype is not None:
             optimizer = AbsmaxOptimizer() if self.weight_qtype.bits == 8 else MaxOptimizer()
         self.optimizer = optimizer
-        self.register_buffer("input_scale", torch.ones((), dtype=self.weight.dtype, device=device))
-        self.register_buffer("output_scale", torch.ones((), dtype=self.weight.dtype, device=device))
+        scale_dtype = torch.float32 if self.weight is None else self.weight.dtype
+        self.register_buffer("input_scale", torch.ones((), dtype=scale_dtype, device=device))
+        self.register_buffer("output_scale", torch.ones((), dtype=scale_dtype, device=device))

     def disable_output_quantization(self):
         if "output" in self._quantize_hooks:
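Both this hunk and the qlayernorm.py one above guard against modules that have no weight (e.g. LayerNorm with elementwise_affine=False), where module.weight is None and accessing .dtype would raise an AttributeError. A minimal sketch of the pattern, with hypothetical names:

import torch
import torch.nn as nn

weight = nn.LayerNorm(8, elementwise_affine=False).weight  # None: no affine parameters
scale_dtype = torch.float32 if weight is None else weight.dtype  # safe fallback
input_scale = torch.ones((), dtype=scale_dtype)  # scalar scale buffer, as in the qmodule hunk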
4 changes: 2 additions & 2 deletions optimum/quanto/tensor/weights/awq/packed.py
@@ -68,9 +68,9 @@ def reverse_awq_order(t: torch.Tensor):
         dtype=torch.int32,
         device=t.device,
     )
-    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
+    reverse_order_tensor = reverse_order_tensor.reshape(-1, 32 // bits)
     reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
-    reverse_order_tensor = reverse_order_tensor.view(-1)
+    reverse_order_tensor = reverse_order_tensor.reshape(-1)

     t = t[:, reverse_order_tensor]

2 changes: 1 addition & 1 deletion optimum/quanto/tensor/weights/marlin/fp8/qbits.py
@@ -32,7 +32,7 @@ def forward(ctx, input, other, bias=None):
         input_shape = input.shape

         if input.ndim > 2:
-            input = input.view(-1, input_shape[-1])
+            input = input.reshape(-1, input_shape[-1])

         output = torch.ops.quanto.gemm_f16f8_marlin(
             input,
2 changes: 1 addition & 1 deletion optimum/quanto/tensor/weights/packing.py
@@ -34,7 +34,7 @@ def unpack_int32_to_uint8(packed: torch.Tensor, bits: int):
     unpacked = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]).to(
         torch.int8  # smallest dtype available
     )
-    unpacked = unpacked.view(unpacked.shape[0], -1)
+    unpacked = unpacked.reshape(unpacked.shape[0], -1)

     # Convert to unsigned
     unpacked = torch.bitwise_and(unpacked, (2**bits) - 1)
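For context, a self-contained sketch of the unpacking pattern this hunk touches (the shifts setup is reconstructed, not shown in the diff): each int32 packs 32 // bits small values, lowest bits first, which are recovered by shifting and masking.

import torch

def unpack_int32_to_uint8_demo(packed: torch.Tensor, bits: int = 4) -> torch.Tensor:
    shifts = torch.arange(0, 32, bits, device=packed.device)  # one shift per packed value
    # Broadcast the shifts over a trailing axis, then flatten it back into the columns
    unpacked = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]).to(torch.int8)
    unpacked = unpacked.reshape(unpacked.shape[0], -1)
    return torch.bitwise_and(unpacked, (2**bits) - 1).to(torch.uint8)  # mask off sign/extra bits

# 0x76543210 packs the nibbles 0..7, lowest bits first
print(unpack_int32_to_uint8_demo(torch.tensor([[0x76543210]], dtype=torch.int32)))
# tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.uint8)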
4 changes: 2 additions & 2 deletions optimum/quanto/tensor/weights/qbytes.py
@@ -74,8 +74,8 @@ def forward(ctx, input, other, bias=None):
         in_features = input.shape[-1]
         out_features = other.shape[0]
         output_shape = input.shape[:-1] + (out_features,)
-        output = torch.ops.quanto.qbytes_mm(input.view(-1, in_features), other._data, other._scale)
-        output = output.view(output_shape)
+        output = torch.ops.quanto.qbytes_mm(input.reshape(-1, in_features), other._data, other._scale)
+        output = output.reshape(output_shape)
         if bias is not None:
             output = output + bias
         return output
4 changes: 2 additions & 2 deletions optimum/quanto/tensor/weights/tinygemm/qbits.py
@@ -49,9 +49,9 @@ def forward(ctx, input, other, bias):
         out_features = other.shape[0]
         output_shape = input.shape[:-1] + (out_features,)
         output = torch._weight_int4pack_mm(
-            input.view(-1, in_features), other._data._data, other._group_size, other._scale_shift
+            input.reshape(-1, in_features), other._data._data, other._group_size, other._scale_shift
         )
-        output = output.view(output_shape)
+        output = output.reshape(output_shape)
         if bias is not None:
             output = output + bias
         return output
