From eb22c2bcf32a68eda09d14fa2c88a6e8804e9108 Mon Sep 17 00:00:00 2001
From: Kyunggeun Lee
Date: Mon, 28 Oct 2024 13:14:48 -0700
Subject: [PATCH] Allow nested dispatch (#3442)

Signed-off-by: Kyunggeun Lee
---
 .../python/aimet_torch/v2/nn/true_quant.py | 18 ++----------------
 .../test/python/v2/nn/test_true_quant.py   | 26 ++++++++++++++++++++++++--
 2 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py
index 77dbf3d1353..22e8459eee8 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py
@@ -424,15 +424,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         return super().__torch_function__(impl, types, args, kwargs)
 
 
-_dispatcher = _Dispatcher()
-_stack_level = 0
-
 @contextlib.contextmanager
 def _dispatch(torch_func: Callable, custom_impl: Callable):
-    # pylint: disable=global-statement
-    global _stack_level
-    orig_level = _stack_level
-
     try:
         orig = _dispatch_table[torch_func]
     except KeyError as e:
@@ -441,17 +434,10 @@ def _dispatch(torch_func: Callable, custom_impl: Callable):
 
     try:
         _dispatch_table[torch_func] = custom_impl
-        if _stack_level == 0:
-            _dispatcher.__enter__()
-        _stack_level += 1
-
-        yield
+        with _Dispatcher():
+            yield
     finally:
         _dispatch_table[torch_func] = orig
-        _stack_level = orig_level
-
-        if _stack_level == 0:
-            _dispatcher.__exit__(None, None, None)
 
 
 class _DispatchMeta(QuantizationMixinMeta):
diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
index ee1818adee1..3b0ddf6aef6 100644
--- a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
+++ b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
@@ -637,10 +637,13 @@ def test_remove_quantizers(self, input):
 
 
 def test_dispatch_sanity():
+    """
+    Given: custom_add(x, y) := x + y + 1
+    """
     custom_add = lambda *args, **kwargs: torch.add(*args, **kwargs) + 1
 
     """
-    When: Dispatch torch.add with custom_add(x, y) := x + y + 1
+    When: Dispatch custom_add in place of torch.add
     Then: Output of torch.add(x, y) should be equal to x + y + 1
     """
     zeros = torch.zeros(10)
@@ -653,7 +656,7 @@ def test_dispatch_sanity():
     assert torch.all(out == 1)
 
     """
-    When: Dispatch torch.add with custom_add(x, y) := x + y + 1
+    When: Dispatch custom_add in place of torch.add
     Then: Output of the other functions should not be affected
     """
     with _dispatch(torch.add, custom_add):
@@ -678,6 +681,25 @@ def test_dispatch_sanity():
         with pytest.raises(RuntimeError):
             with _dispatch(func, dummy_impl):
                 pass
+    """
+    When: Dispatch custom_addmm in place of torch.addmm, where custom_addmm
+          itself dispatches custom_add in place of torch.add in a nested fashion
+    Then: Output of torch.addmm(x, y, z) should be equal to x + (y @ z) + 1
+    """
+    x = torch.randn(10, 10)
+    y = torch.randn(10, 10)
+    z = torch.randn(10, 10)
+
+    def custom_addmm(x, y, z):
+        with _dispatch(torch.add, custom_add):
+            return torch.add(x, torch.matmul(y, z))
+
+    with _dispatch(torch.addmm, custom_addmm):
+        out = torch.addmm(x, y, z)
+
+    expected = x + (y @ z) + 1
+    assert torch.all(out == expected)
+
 
 def _create_legacy_fake_quantized_module(module):
     qmodule = _legacy_impl.FakeQuantizationMixin.from_module(module)
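
Note for reviewers: the previous implementation kept one global _Dispatcher
instance armed via a _stack_level reference count, entering it only at depth 0.
A TorchFunctionMode-style dispatcher is popped off the mode stack while its own
__torch_function__ handler runs, so a nested _dispatch issued from inside a
dispatched implementation would never re-arm interception, and the inner custom
implementation was silently skipped. Entering a fresh _Dispatcher() per call
lets the contexts stack and unwind naturally.

The sketch below reproduces the pattern in isolation so the nesting behavior
can be run standalone. It is illustrative only, not AIMET's actual code: it
assumes _Dispatcher is a torch.overrides.TorchFunctionMode subclass (consistent
with the __torch_function__ override in the first hunk), and _DISPATCH_TABLE is
a hypothetical stand-in for the module's _dispatch_table.

import contextlib
from typing import Callable

import torch
from torch.overrides import TorchFunctionMode

# Hypothetical stand-in for true_quant.py's _dispatch_table: maps each
# dispatchable torch function to its current override (None = use default).
_DISPATCH_TABLE = {torch.add: None, torch.addmm: None}


class _Dispatcher(TorchFunctionMode):
    # Assumed shape of the dispatcher: intercept every torch function call
    # and reroute it to the override registered in _DISPATCH_TABLE, if any.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        impl = _DISPATCH_TABLE.get(func) or func
        # This mode is popped from the mode stack while this handler runs,
        # so calling impl here cannot recurse back into the same dispatcher.
        return impl(*args, **kwargs)


@contextlib.contextmanager
def _dispatch(torch_func: Callable, custom_impl: Callable):
    try:
        orig = _DISPATCH_TABLE[torch_func]
    except KeyError as e:
        raise RuntimeError(f"{torch_func} cannot be dispatched") from e
    try:
        _DISPATCH_TABLE[torch_func] = custom_impl
        # The crux of the fix: enter a fresh mode per call instead of
        # reference-counting one global instance. Modes stack, and each
        # __exit__ pops only its own frame, so nested calls unwind cleanly.
        with _Dispatcher():
            yield
    finally:
        _DISPATCH_TABLE[torch_func] = orig


# Nested usage, mirroring the new test case:
def custom_add(*args, **kwargs):
    return torch.add(*args, **kwargs) + 1

def custom_addmm(x, y, z):
    with _dispatch(torch.add, custom_add):
        return torch.add(x, torch.matmul(y, z))

x, y, z = (torch.randn(10, 10) for _ in range(3))
with _dispatch(torch.addmm, custom_addmm):
    out = torch.addmm(x, y, z)
assert torch.all(out == x + (y @ z) + 1)

Under these assumptions the trace works out as follows: custom_addmm runs
inside the outer mode's handler (outer already popped), pushes a fresh inner
mode, and torch.add is then intercepted exactly once by the inner mode, which
is what the old reference-counted design could not do.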