From b72b2a0720ad82ff2269a5df3bd0484f04f7ea66 Mon Sep 17 00:00:00 2001
From: Kaeun Kim <51257208+k223kim@users.noreply.github.com>
Date: Tue, 7 May 2024 04:36:59 +0900
Subject: [PATCH] added `expand_as` (#350)

---
 thunder/tests/opinfos.py  | 49 +++++++++++++++++++++++++++++++++++++++
 thunder/torch/__init__.py |  5 ++++
 2 files changed, 54 insertions(+)

diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 6f7de4dc76..f275139ca6 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -3121,6 +3121,55 @@ def expand_error_generator(op, device, *, dtype=torch.float32, **kwargs):
 shape_ops.append(expand_opinfo)
 
 
+def expand_as_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Input shape, output shape
+    cases = (
+        ((), ()),  # Scalar identity
+        ((), (3, 4, 5)),  # Broadcast scalar tensor, adding dims
+        ((0,), (0,)),  # Zero dim tensor identity
+        ((1, 0), (1, 0)),  # Nonleading zero dim
+        ((1, 0), (0, 0)),  # Empty input (one broadcast, one zero)
+        ((1, 1), (0, 0)),  # Non-empty fully broadcast input
+        ((1, 3), (1, 1, 3)),  # Add dim
+        ((1, 1), (1, 2)),  # Broadcast trailing dim
+        ((1, 1), (2, 1)),  # Broadcast leading dim
+    )
+
+    for ishape, oshape in cases:
+        yield SampleInput(make(ishape), make(oshape))
+
+
+def expand_as_error_generator(op, device, *, dtype=torch.float32, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype)
+
+    # Input shape, output shape, exception type, error message match or None for universal match
+    cases = [
+        ((0,), (1,), RuntimeError, "attempting to expand a dimension of length 0"),
+        ((1,), (), RuntimeError, "expand: the requested shape has too few dimensions!"),
+        ((0,), (2,), RuntimeError, "attempting to expand a dimension of length 0"),
+        ((2, 2), (2, 4), RuntimeError, "attempting to expand a dimension of length 2"),
+    ]
+
+    for ishape, oshape, exc_type, err_msg_match in cases:
+        yield SampleInput(make(ishape), make(oshape)), exc_type, err_msg_match
+
+
+expand_as_opinfo = OpInfo(
+    ltorch.expand_as,
+    sample_input_generator=expand_as_sample_generator,
+    error_input_generator=expand_as_error_generator,
+    torch_reference=torch.Tensor.expand_as,
+    test_directives=(
+        # vjp and jvp not yet implemented
+        DecorateInfo(pytest.mark.xfail, "test_vjp_correctness"),
+        DecorateInfo(pytest.mark.xfail, "test_jvp_correctness"),
+    ),
+)
+shape_ops.append(expand_as_opinfo)
+
+
 def flatten_sample_generator(op, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype)
 
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index e0f7254ad6..0fdb668dbf 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -685,6 +685,11 @@ def expand(a: TensorLike, /, *shape: int) -> TensorLike:
     return clang.expand(a, *shape)
 
 
+@torchsymbol(torch.Tensor.expand_as, is_method=True)
+def expand_as(a: TensorLike, b: TensorLike, /) -> TensorLike:
+    return expand(a, b.size())
+
+
 @torchsymbol(torch.flatten, is_method=True)
 def flatten(a: TensorLike, /, start_dim: int = 0, end_dim: int = -1) -> TensorLike:
     return clang.flatten(a, start_dim, end_dim)
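
Note (illustration only, not part of the patch): a minimal sketch of the broadcast
semantics the new sample and error generators exercise, written against plain PyTorch,
which the opinfo already uses as ground truth via torch_reference=torch.Tensor.expand_as.
The shapes below are chosen for illustration and are not taken from the patch; thunder's
error wording (matched by expand_as_error_generator) differs from PyTorch's.

    import torch

    a = torch.arange(2.0).reshape(2, 1)  # shape (2, 1)
    b = torch.zeros(2, 3)                # shape (2, 3)

    # Size-1 dims of `a` are broadcast to match `b`; no data is copied.
    c = a.expand_as(b)
    print(c.shape)  # torch.Size([2, 3])

    # A scalar (0-dim) tensor can be expanded to any shape, adding leading dims.
    print(torch.tensor(1.0).expand_as(torch.zeros(3, 4, 5)).shape)  # torch.Size([3, 4, 5])

    # Non-singleton dims that do not match raise, as in the (2, 2) -> (2, 4) error case.
    try:
        torch.zeros(2, 2).expand_as(torch.zeros(2, 4))
    except RuntimeError as e:
        print("RuntimeError:", e)  # PyTorch's wording, not thunder's "attempting to expand ..."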