Skip to content

Commit

Permalink
Allow nested dispatch (#3442)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu authored Oct 28, 2024
1 parent 22cca45 commit eb22c2b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
return super().__torch_function__(impl, types, args, kwargs)


_dispatcher = _Dispatcher()
_stack_level = 0

@contextlib.contextmanager
def _dispatch(torch_func: Callable, custom_impl: Callable):
# pylint: disable=global-statement
global _stack_level
orig_level = _stack_level

try:
orig = _dispatch_table[torch_func]
except KeyError as e:
Expand All @@ -441,17 +434,10 @@ def _dispatch(torch_func: Callable, custom_impl: Callable):
try:
_dispatch_table[torch_func] = custom_impl

if _stack_level == 0:
_dispatcher.__enter__()
_stack_level += 1

yield
with _Dispatcher():
yield
finally:
_dispatch_table[torch_func] = orig
_stack_level = orig_level

if _stack_level == 0:
_dispatcher.__exit__(None, None, None)


class _DispatchMeta(QuantizationMixinMeta):
Expand Down
26 changes: 24 additions & 2 deletions TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,10 +637,13 @@ def test_remove_quantizers(self, input):


def test_dispatch_sanity():
"""
Given: custom_add(x, y) := x + y + 1
"""
custom_add = lambda *args, **kwargs: torch.add(*args, **kwargs) + 1

"""
When: Dispatch torch.add with custom_add(x, y) := x + y + 1
When: Dispatch custom_add in place of torch.add(x, y)
Then: Output of torch.add(x, y) should be equal to x + y + 1
"""
zeros = torch.zeros(10)
Expand All @@ -653,7 +656,7 @@ def test_dispatch_sanity():
assert torch.all(out == 1)

"""
When: Dispatch torch.add with custom_add(x, y) := x + y + 1
When: Dispatch custom_add in place of torch.add
Then: Output of the other functions should not be affected
"""
with _dispatch(torch.add, custom_add):
Expand All @@ -678,6 +681,25 @@ def test_dispatch_sanity():
with pytest.raises(RuntimeError):
with _dispatch(func, dummy_impl): pass

"""
When: Dispatch custom_addmm in place of torch.addmm in which
custom_add will be dispatched in place of torch.add in a nested fashion
Then: Output of torch.addmm(x, y, z) should be equal to x + (y @ z) + 1
"""
x = torch.randn(10, 10)
y = torch.randn(10, 10)
z = torch.randn(10, 10)

def custom_addmm(x, y, z):
with _dispatch(torch.add, custom_add):
return torch.add(x, torch.matmul(y, z))

with _dispatch(torch.addmm, custom_addmm):
out = torch.addmm(x, y, z)

expected = x + (y @ z) + 1
assert torch.all(out == expected)


def _create_legacy_fake_quantized_module(module):
qmodule = _legacy_impl.FakeQuantizationMixin.from_module(module)
Expand Down

0 comments on commit eb22c2b

Please sign in to comment.