Refactor cross_entropy using log_softmax and nll_loss references (
rdspring1 authored Jun 5, 2024
1 parent a7017ce commit 4a7097d
Showing 5 changed files with 300 additions and 301 deletions.
3 changes: 3 additions & 0 deletions thunder/core/transforms.py
@@ -1091,7 +1091,10 @@ def _div_prim_grad(a: Number | TensorProxy, b: Number | TensorProxy, /) -> Numbe

# Comparison operators -- these create no grad associations
register_grad(pids.EQ, prims.eq)
register_grad(pids.NE, prims.ne)
register_grad(pids.GE, prims.ge)
register_grad(pids.GT, prims.gt)
register_grad(pids.LE, prims.le)
register_grad(pids.LT, prims.lt)
register_grad(pids.NE, prims.ne)
register_grad(pids.GT, prims.gt)
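
For readers unfamiliar with why comparison primitives need no real gradient rule: their outputs are boolean, so nothing differentiable flows through them, and registering the primitive itself as its own "grad" is sufficient. A small PyTorch illustration, independent of Thunder's register_grad machinery:

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

# Comparison results are boolean and detached from the autograd graph.
print((a < b).dtype)           # torch.bool
print((a < b).requires_grad)   # False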
63 changes: 2 additions & 61 deletions thunder/executors/torchex.py
@@ -1311,7 +1311,6 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
conv2d = _register_torch_operation("conv2d", module=torch.nn.functional)
conv3d = _register_torch_operation("conv3d", module=torch.nn.functional)
mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional)
cross_entropy = _register_torch_operation("cross_entropy", module=torch.nn.functional)
dropout = _register_torch_operation("dropout", module=torch.nn.functional)
embedding = _register_torch_operation("embedding", module=torch.nn.functional)
embedding_backward = _register_torch_operation("torch.ops.aten.embedding_backward", like=ltorch.embedding_backward)
@@ -1446,59 +1445,6 @@ def _convolution_transform(
return convolution(a, weight, bias, stride, padding, dilation, bool(transposed), output_padding, groups)


def _cross_entropy_backward_impl(
g: torch.Tensor,
a: torch.Tensor,
target: torch.Tensor,
weight: torch.Tensor,
reduction: str,
ignore_index: int,
label_smoothing: int,
) -> torch.Tensor:
# forward - given input and target
# a = log_softmax(input, dim)
# return cross_entropy(a, target, weight, reduction, ignore_index, label_smoothing)

# backward - given grad_cross_entropy and saved_tensors
# grad_a = torch.ops.aten.nll_loss_backward(grad, a, target, weight, reduction, ignore_index, total_weight)
# return torch.ops.aten._log_softmax_backward_data(grad_a, a, dim, a.scalar_type())

if reduction == "none":
reduction_idx = 0
elif reduction == "mean":
reduction_idx = 1
elif reduction == "sum":
reduction_idx = 2
else:
reduction_idx = -1

utils.check(
reduction_idx > -1 and reduction_idx < 3,
lambda: f"{reduction} is not a valid value for reduction parameter.",
)

# TODO Add support nll_loss_nd, weight tensor, and label_smoothing options.
# See issue "Add support for remaining cross_entropy_loss arguments."
utils.check(a.ndim <= 2 and target.ndim <= 1, lambda: f"multi-dimension cross-entropy is not supported.")

utils.check(weight is None, lambda: f"weight tensor argument is not supported.")

utils.check(label_smoothing == 0.0, lambda: f"label smoothing values not equal to zero are not supported.")

dim = 0 if a.dim() == 1 else 1
a = torch.log_softmax(a, dim, a.dtype)

if weight is not None:
total_weight = torch.sum(weight)
elif reduction == "none":
total_weight = torch.tensor(0.0, dtype=a.dtype, device=a.device)
elif reduction == "sum" or reduction == "mean":
total_weight = torch.sum(torch.ne(target, ignore_index)).to(dtype=a.dtype, device=a.device)

g_a = torch.ops.aten.nll_loss_backward(g, a, target, weight, reduction_idx, ignore_index, total_weight)
return torch.ops.aten._log_softmax_backward_data(g_a, a, dim, a.dtype)
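
The removed backward above hand-coded the composition that this commit now expresses at the trace level: cross_entropy is written in terms of log_softmax and nll_loss, so their existing gradient rules cover it. A minimal sketch of that equivalence in plain PyTorch, assuming class-index targets and no weight or label smoothing (shapes are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(4, 10, requires_grad=True)  # 4 samples, 10 classes
t = torch.randint(0, 10, (4,))

# cross_entropy(x, t) == nll_loss(log_softmax(x, dim=1), t) for the default settings.
composed = F.nll_loss(F.log_softmax(x, dim=1), t)
fused = F.cross_entropy(x, t)
torch.testing.assert_close(composed, fused)

# The backward then chains nll_loss_backward into _log_softmax_backward_data,
# which is exactly what _cross_entropy_backward_impl spelled out by hand.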


# NOTE PyTorch's nn.functional.interpolate only supports 3D, 4D, and 5D interpolation
def _interpolate_checker(
a: TensorLike,
@@ -1563,11 +1509,11 @@ def _nll_loss_backward_impl(
) -> torch.Tensor:
reduction: int = _reduction_str_to_num_map[reduction]

# NOTE PyTorch expects total_weight to be a float64 tensor
if total_weight is None:
# NOTE aten.nll_loss_backward expects total_weight to be a float64 tensor
total_weight = torch.tensor(0.0, dtype=torch.float64, device=a.device)
else:
total_weight = total_weight.to(torch.float64)
total_weight = total_weight.to(a.dtype)

if a.ndim <= 2:
return torch.ops.aten.nll_loss_backward(g, a, target, weight, reduction, ignore_index, total_weight)
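
As a sanity check on the total_weight handling changed above: when no weight tensor is given, aten's nll_loss_forward reports total_weight as the count of non-ignored targets, which matches the manual computation cast to the input dtype. A small sketch (shapes, ignore_index, and the mean reduction are illustrative assumptions):

import torch

a = torch.log_softmax(torch.randn(4, 10, dtype=torch.float64), dim=1)
target = torch.randint(0, 10, (4,))
ignore_index = -100

# reduction=1 is "mean"; with weight=None, total_weight counts the non-ignored targets.
out, total_weight = torch.ops.aten.nll_loss_forward(a, target, None, 1, ignore_index)
manual = torch.ne(target, ignore_index).sum().to(a.dtype)
torch.testing.assert_close(total_weight, manual)

# The same tensor is what the backward consumes.
grad = torch.ops.aten.nll_loss_backward(torch.ones_like(out), a, target, None, 1, ignore_index, total_weight)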
@@ -1652,11 +1598,6 @@ def _pad_prim_impl(
_register_implementation(ltorch.conv2d, conv2d, checker=_always_executable)
_register_implementation(ltorch.conv3d, conv3d, checker=_always_executable)
_register_implementation(ltorch.mse_loss, mse_loss, checker=_always_executable)
_register_implementation(ltorch.cross_entropy, cross_entropy, checker=_always_executable)
cross_entropy_backward = ex.register_operator(
"torch_cross_entropy_backward_impl", meta=ltorch.cross_entropy_backward, fn=_cross_entropy_backward_impl
)
_register_implementation(ltorch.cross_entropy_backward, cross_entropy_backward, checker=_always_executable)
_register_implementation(ltorch.dropout, dropout, checker=_always_executable)
_register_implementation(ltorch.embedding, embedding, checker=_always_executable)
_register_implementation(ltorch.embedding_backward, embedding_backward, checker=_always_executable)
138 changes: 124 additions & 14 deletions thunder/tests/opinfos.py
@@ -7728,46 +7728,47 @@ def cross_entropy_reference_generator(op, device, dtype, requires_grad, **kwargs
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True)
),
weight=make(C) if weight_flag else None,
weight=make(C, requires_grad=False) if weight_flag else None,
ignore_index=ignore_index,
reduction=reduction_str,
label_smoothing=label_smoothing,
)


# TODO Enable cross entropy bwd weight support
# TODO Enable test cases after adding support nll_loss_nd, weight tensor, and label_smoothing options.
# TODO see issue "Add support for remaining cross_entropy_loss arguments"
def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

# input_shape, target_shape
shapes = (
((2, 16), (2,)),
((7, 18), (7,)),
# ((7, 18), (7, 18)),
# ((3, 4, 2, 3), (3, 4, 2, 3)),
# ((3, 4, 2, 3), (3, 2, 3)),
((7, 18), (7, 18)),
((3, 4, 2, 3), (3, 4, 2, 3)),
((3, 4, 2, 3), (3, 2, 3)),
((5,), ()),
# ((3, 4, 0), (3, 0)),
# ((3, 4, 0), (3, 4, 0)),
((3, 4, 0), (3, 0)),
((3, 4, 0), (3, 4, 0)),
)

weight_options = (False,)
weight_options = (False, True)
reduction_options = ("none", "mean", "sum")
label_smoothing_options = (0.0, 0.5)
ignore_index_options = (-1, 3)

for shape, weight_flag, reduction_str, label_smoothing, ignore_index in itertools.product(
shapes, weight_options, reduction_options, label_smoothing_options, ignore_index_options
):
# NOTE According to pytorch/pytorch#64572, nll_loss should return NaN when reduction = "mean"
# and the whole target is equal to ignore_index. However, if the inputs are cuda tensors, PyTorch returns 0.
# Skip this case because we are consistent across devices.
if torch.device(device).type == "cuda" and reduction_str == "mean" and ignore_index > 0:
continue

input_shape, target_shape = shape
probability_target = input_shape == target_shape
# ignore_index can't be supplied with probability target
if probability_target and ignore_index >= 0:
continue
if not probability_target and label_smoothing > 0.0:
continue
C = input_shape[1] if len(input_shape) >= 2 else input_shape[0]
yield SampleInput(
make(shape[0]),
@@ -7776,21 +7777,125 @@ def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs):
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True)
),
weight=make(C) if weight_flag else None,
weight=make(C, requires_grad=False) if weight_flag else None,
ignore_index=ignore_index,
reduction=reduction_str,
label_smoothing=label_smoothing,
)
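
The CUDA case skipped in the generator above comes from the corner case noted there: with reduction="mean" and every target equal to ignore_index, both the summed loss and the total weight are zero, so the mean is 0/0. A quick CPU illustration (shapes are illustrative; per pytorch/pytorch#64572, CUDA historically returns 0 here instead of NaN):

import torch
import torch.nn.functional as F

logits = torch.randn(3, 4)
target = torch.full((3,), 3, dtype=torch.long)  # every target is the ignored class

loss = F.cross_entropy(logits, target, ignore_index=3, reduction="mean")
print(loss)  # tensor(nan) on CPU: zero summed loss divided by zero total weight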


def cross_entropy_error_generator(op, device, dtype=torch.float32, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

input_shape = (7, 18)
target_shape = (7,)
C = input_shape[1] if len(input_shape) >= 2 else input_shape[0]
valid_input = make(input_shape)
valid_target = make(target_shape, low=0, high=C, dtype=torch.long, requires_grad=False)

# unexpected reduction string argument
yield (
SampleInput(valid_input, valid_target, reduction="foo"),
ValueError,
'Expected reduction string to be "none", "sum", or "mean", but it is (.*?).',
)

# target tensor is not integer dtype
float_target = make(target_shape, low=0, high=C, dtype=torch.float, requires_grad=False)
yield (
SampleInput(valid_input, float_target),
RuntimeError,
"Expected target to be a tensor with an integer dtype, but it has dtype (.*?).",
)

# input tensor has 0 dimensions
scalar_input = make(scalar_shape := ())
yield (
SampleInput(scalar_input, valid_target),
RuntimeError,
f"Expected the input tensor to have more than 1 dimension, but it has {scalar_input.ndim} dimensions.",
)

# weight tensor has more than 1 dimension
multiple_dim_weight = make((C, 10), requires_grad=False)
yield (
SampleInput(valid_input, valid_target, weight=multiple_dim_weight),
RuntimeError,
f"Expected a 1D tensor with {C} elements for weight argument, \
but found a tensor with {multiple_dim_weight.ndim} dimensions and {multiple_dim_weight.shape[0]} elements.",
)

# weight tensor numel != C
incorrect_numel_weight = make((C + 10,), requires_grad=False)
yield (
SampleInput(valid_input, valid_target, weight=incorrect_numel_weight),
RuntimeError,
f"Expected a 1D tensor with {C} elements for weight argument, \
but found a tensor with {incorrect_numel_weight.ndim} dimensions and {incorrect_numel_weight.shape[0]} elements.",
)

# label_smoothing out-of-bounds
out_of_bounds_label_smoothing = 1.5
yield (
SampleInput(valid_input, valid_target, label_smoothing=out_of_bounds_label_smoothing),
RuntimeError,
r"Expected label_smoothing to be in \[0, 1\] range but got 1.5.",
)

# target tensor is not integer dtype
float_target = make(target_shape, low=0, high=C, dtype=torch.float, requires_grad=False)
yield (
SampleInput(valid_input, float_target),
RuntimeError,
"Expected target to be a tensor with an integer dtype, but it has dtype (.*?).",
)

# input ndims != (target ndims + 1)
extra_dim_input = make(input_shape + (10,))
yield (
SampleInput(extra_dim_input, valid_target),
RuntimeError,
"Expected the input tensor to have (.*?) dimensions, but it has (.*?) dimensions.",
)

# target shape is input shape except channels dimension
incorrect_batch_target = make((10,), low=0, high=C, dtype=torch.long, requires_grad=False)
yield (
SampleInput(valid_input, incorrect_batch_target),
RuntimeError,
"Expected the target tensor to have the same shape as the input tensor except for the channels dimension \
(.*?), but it has shape (.*?).",
)

integer_prob_target = make(input_shape, low=0, high=C, dtype=torch.long, requires_grad=False)
yield (
SampleInput(valid_input, integer_prob_target),
RuntimeError,
"Expected the target to have float dtype when target contains class probabilities \
but it is (.*?).",
)

valid_prob_target = make(input_shape, low=0.0, high=1.0, dtype=torch.float, requires_grad=False)
yield (
SampleInput(valid_input, valid_prob_target, ignore_index=5),
RuntimeError,
"ignore_index argument is not supported when target contains class probabilities.",
)


cross_entropy_opinfo = OpInfo(
ltorch.cross_entropy,
supports_grad=True,
sample_input_generator=cross_entropy_sample_generator,
reference_input_generator=cross_entropy_reference_generator,
error_input_generator=cross_entropy_error_generator,
torch_reference=torch.nn.functional.cross_entropy,
dtypes=(datatypes.floating,),
test_directives=(
# take_along_axis is disabled with nvfuser, which this operator relies on.
DecorateInfo(
pytest.mark.skip,
executors=("nvfuser",),
),
# TODO Investigate why CPU torch executor tests fail in CI (but not locally)
DecorateInfo(
pytest.mark.skip,
@@ -7805,6 +7910,11 @@ def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs):
datatypes.bfloat16,
),
),
# TODO FIXME -- These tests are hitting an odd issue where real torch tensors are being passed to nll_loss
DecorateInfo(
pytest.mark.skip,
"test_vjp_correctness",
),
),
)
nn_ops.append(cross_entropy_opinfo)
23 changes: 22 additions & 1 deletion thunder/tests/test_grad.py
@@ -49,7 +49,6 @@
"amax",
"amin",
"cat",
"cross_entropy",
"softmax",
"to",
"linear",
@@ -609,6 +608,28 @@ def test_vjp_correctness_nll_loss_manual(op, device, dtype, executor, comp):
comp(grad_out[0], expected_grad[0])


@ops((get_opinfo("cross_entropy"),), supported_dtypes=(dtypes.float64,))
def test_vjp_correctness_cross_entropy_manual(op, device, dtype, executor, comp):
for sample in op.sample_inputs(device, dtype, requires_grad=True, no_rhs_numbers=True):
# Traced backwards function does not follow PyTorch cross_entropy behavior with zero element tensors
if sample.args[0].numel() == 0:
continue

# Compute vjp result using PyTorch
out = op.torch_reference(*sample.args, **sample.kwargs)
v = make_tensor_like(out)
expected_grad = torch.autograd.grad(out, sample.args[0], v)

# Compute vjp result using Thunder
flat_op, flat_args, spec = flatten_func(op.op, sample.args, sample.kwargs)
actual_out, grad_out = executor.make_callable_legacy(vjp(flat_op), disable_torch_autograd_support=True)(
flat_args, (v,)
)

comp(actual_out, out)
comp(grad_out[0], expected_grad[0])
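
The zero-element skip in the new test above reflects a genuine PyTorch corner case rather than a Thunder-specific one: with an empty batch, the default "mean" reduction divides zero loss by zero weight. A small sketch of the reference behavior the comment alludes to (shapes are illustrative):

import torch
import torch.nn.functional as F

empty_logits = torch.randn(0, 5)
empty_target = torch.randint(0, 5, (0,))

print(F.cross_entropy(empty_logits, empty_target, reduction="mean"))  # tensor(nan)
print(F.cross_entropy(empty_logits, empty_target, reduction="sum"))   # tensor(0.)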


@ops((get_opinfo("einsum"),), supported_dtypes=(dtypes.float64,))
def test_vjp_correctness_einsum_manual(op, device, dtype, executor, comp):
from thunder.tests.framework import nvFuserTestExecutor
(Diff for the remaining changed file not shown.)
