
Commit

Use fp32 for int16 simulation
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored and quic-akhobare committed Oct 31, 2023
1 parent 9f1da3a commit f961a96
Showing 2 changed files with 20 additions and 6 deletions.
@@ -209,10 +209,14 @@ def calculate_forward_pass(tensor: torch.Tensor,
f"Got {tensor.dtype} input, {encoding_min.dtype} encoding_min, "
f"and {encoding_max.dtype} encoding_max")

if tensor_quantizer.bitwidth >= 32:
raise RuntimeError(f'Invalid bitwidth: {tensor_quantizer.bitwidth}')

orig_dtype = tensor.dtype
if tensor.dtype == torch.float16 and tensor_quantizer.bitwidth >= 16:
raise RuntimeError("torch.float16 does not provide sufficient precision for simulating "
"16-bit or higher integers. Please consider using torch.float32 arithmetic "
"or sub-16-bit quantization.")
tensor = tensor.float()
encoding_min = encoding_min.float()
encoding_max = encoding_max.float()

use_symmetric_encodings = tensor_quantizer.use_symmetric_encodings
is_unsigned_symmetric = tensor_quantizer.is_unsigned_symmetric
@@ -241,7 +245,7 @@ def calculate_forward_pass(tensor: torch.Tensor,
encoding_min, encoding_max,
delta, offset, mask_tensor, num_steps,
use_symmetric_encodings, is_unsigned_symmetric)
return x_dequant, intermediate_result
return x_dequant.to(orig_dtype), intermediate_result


# pylint:disable=too-many-locals
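Context for the change above: instead of rejecting float16 inputs once the bitwidth reaches 16, calculate_forward_pass now records the original dtype, upcasts the tensor and encodings to float32 for the quantize-dequantize arithmetic, and casts the result back on return. float16 carries only a 10-bit mantissa, so it cannot represent every level of a 16-bit grid exactly, while float32 can. Below is a minimal sketch of the same upcast-and-restore pattern, assuming a plain asymmetric fake-quant; the helper is illustrative only, not AIMET's implementation.

import torch

def fake_quantize_fp32(x: torch.Tensor,
                       enc_min: torch.Tensor,
                       enc_max: torch.Tensor,
                       bitwidth: int) -> torch.Tensor:
    """Quantize-dequantize in float32 and return the caller's original dtype.

    Illustrative sketch of the upcast-and-restore pattern only; this is a plain
    asymmetric fake-quant, not AIMET's calculate_forward_pass.
    """
    orig_dtype = x.dtype
    # Run the grid arithmetic in float32 so a 16-bit grid (65535 steps) stays exact.
    x32 = x.float()
    enc_min32, enc_max32 = enc_min.float(), enc_max.float()

    num_steps = 2 ** bitwidth - 1
    delta = (enc_max32 - enc_min32) / num_steps
    offset = torch.round(enc_min32 / delta)

    x_int = torch.clamp(torch.round(x32 / delta) - offset, 0, num_steps)
    x_dequant = (x_int + offset) * delta
    return x_dequant.to(orig_dtype)

# float16 activations with a 16-bit grid: the arithmetic runs in float32,
# and the output comes back in float16.
x = torch.randn(4, 8, dtype=torch.float16)
out = fake_quantize_fp32(x, torch.tensor(-1.0), torch.tensor(1.0), bitwidth=16)
assert out.dtype == torch.float16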
14 changes: 12 additions & 2 deletions TrainingExtensions/torch/test/python/test_qc_quantize_op.py
@@ -957,8 +957,9 @@ def test_learned_grid_wrapper_pickle_upickle(self):
with pytest.raises(RuntimeError):
loaded_quant_wrapper.output_quantizers[0].encoding = enc_new

@pytest.mark.parametrize('act_bw', [8, 16])
@pytest.mark.cuda
def test_learned_grid_preserves_fp16(self):
def test_learned_grid_preserves_fp16(self, act_bw):
"""
Test if the forward of LearnedGridQuantWrapper preserves the dtype between
input and output
@@ -979,7 +980,7 @@ def test_learned_grid_preserves_fp16(self):

linear_wrapper = LearnedGridQuantWrapper(linear,
weight_bw=4,
activation_bw=8,
activation_bw=act_bw,
round_mode='round_nearest',
quant_scheme=QuantScheme.training_range_learning_with_tf_init,
device='cuda:0')
@@ -1012,6 +1013,15 @@ def test_learned_grid_preserves_fp16(self):
assert linear_wrapper.output0_encoding_min.grad.dtype == torch.float16
assert linear_wrapper.output0_encoding_max.grad.dtype == torch.float16

for val in (torch.nan, torch.inf, -torch.inf):
assert torch.all(linear_wrapper.weight.grad != val)
assert torch.all(linear_wrapper.weight_encoding_min.grad != val)
assert torch.all(linear_wrapper.weight_encoding_max.grad != val)
assert torch.all(linear_wrapper.input0_encoding_min.grad != val)
assert torch.all(linear_wrapper.input0_encoding_max.grad != val)
assert torch.all(linear_wrapper.output0_encoding_min.grad != val)
assert torch.all(linear_wrapper.output0_encoding_max.grad != val)

@pytest.mark.parametrize("wrapper",
[StaticGridQuantWrapper(elementwise_ops.Addmm(),
8, 8, 'nearest',
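One note on the checks added in the test above: an elementwise comparison such as grad != torch.nan is always true, because NaN compares unequal to every value including itself, so it cannot flag NaN entries; the inf and -inf comparisons do catch infinities. A torch.isfinite-based assertion covers all three cases in a single call. A small sketch follows; the helper is hypothetical and not part of this commit, and the tensor names in the usage comment are taken from the diff.

import torch

def assert_all_finite(*tensors: torch.Tensor) -> None:
    """Assert that every tensor contains neither NaN nor +/-inf.

    Hypothetical helper, not part of the commit above. torch.isfinite is False
    for NaN and for both infinities, so one check covers all three values.
    """
    for t in tensors:
        assert torch.all(torch.isfinite(t)), "found NaN or inf in gradient"

# Example usage after the backward pass in the test above:
# assert_all_finite(linear_wrapper.weight.grad,
#                   linear_wrapper.weight_encoding_min.grad,
#                   linear_wrapper.weight_encoding_max.grad,
#                   linear_wrapper.input0_encoding_min.grad,
#                   linear_wrapper.input0_encoding_max.grad,
#                   linear_wrapper.output0_encoding_min.grad,
#                   linear_wrapper.output0_encoding_max.grad)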
