Skip to content

Commit

Permalink
added expand_as (Lightning-AI#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
k223kim authored May 6, 2024
1 parent 56ca7d4 commit b72b2a0
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
49 changes: 49 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3121,6 +3121,55 @@ def expand_error_generator(op, device, *, dtype=torch.float32, **kwargs):
shape_ops.append(expand_opinfo)


def expand_as_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

# Input shape, output shape
cases = (
((), ()), # Scalar identity
((), (3, 4, 5)), # Broadcast scalar tensor, adding dims
((0,), (0,)), # Zero dim tensor identity
((1, 0), (1, 0)), # Nonleading zero dim
((1, 0), (0, 0)), # Empty input (one broadcast, one zero)
((1, 1), (0, 0)), # Non-empty fully broadcast input
((1, 3), (1, 1, 3)), # Add dim
((1, 1), (1, 2)), # Broadcast trailing dim
((1, 1), (2, 1)), # Broadcast leading dim
)

for ishape, oshape in cases:
yield SampleInput(make(ishape), make(oshape))


def expand_as_error_generator(op, device, *, dtype=torch.float32, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype)

# Input shape, output shape, exception type, error message match or None for universal match
cases = [
((0,), (1,), RuntimeError, "attempting to expand a dimension of length 0"),
((1,), (), RuntimeError, "expand: the requested shape has too few dimensions!"),
((0,), (2,), RuntimeError, "attempting to expand a dimension of length 0"),
((2, 2), (2, 4), RuntimeError, "attempting to expand a dimension of length 2"),
]

for ishape, oshape, exc_type, err_msg_match in cases:
yield SampleInput(make(ishape), make(oshape)), exc_type, err_msg_match


expand_as_opinfo = OpInfo(
ltorch.expand_as,
sample_input_generator=expand_as_sample_generator,
error_input_generator=expand_as_error_generator,
torch_reference=torch.Tensor.expand_as,
test_directives=(
# vjp and jvp not yet implemented
DecorateInfo(pytest.mark.xfail, "test_vjp_correctness"),
DecorateInfo(pytest.mark.xfail, "test_jvp_correctness"),
),
)
shape_ops.append(expand_as_opinfo)


def flatten_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype)

Expand Down
5 changes: 5 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,11 @@ def expand(a: TensorLike, /, *shape: int) -> TensorLike:
return clang.expand(a, *shape)


@torchsymbol(torch.Tensor.expand_as, is_method=True)
def expand_as(a: TensorLike, b: TensorLike, /) -> TensorLike:
return expand(a, b.size())


@torchsymbol(torch.flatten, is_method=True)
def flatten(a: TensorLike, /, start_dim: int = 0, end_dim: int = -1) -> TensorLike:
return clang.flatten(a, start_dim, end_dim)
Expand Down

0 comments on commit b72b2a0

Please sign in to comment.