From 4a7097d6346ec522877d751818150966ccf6a3b9 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Wed, 5 Jun 2024 09:34:00 -0700 Subject: [PATCH] Refactor `cross_entropy` using `log_softmax` and `nll_loss` references (#260) --- thunder/core/transforms.py | 3 + thunder/executors/torchex.py | 63 +----- thunder/tests/opinfos.py | 138 +++++++++++-- thunder/tests/test_grad.py | 23 ++- thunder/torch/__init__.py | 374 ++++++++++++++--------------------- 5 files changed, 300 insertions(+), 301 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index cd738fb774..0de25f5b72 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1091,7 +1091,10 @@ def _div_prim_grad(a: Number | TensorProxy, b: Number | TensorProxy, /) -> Numbe # Comparison operators -- these create no grad associations register_grad(pids.EQ, prims.eq) +register_grad(pids.NE, prims.ne) register_grad(pids.GE, prims.ge) +register_grad(pids.GT, prims.gt) +register_grad(pids.LE, prims.le) register_grad(pids.LT, prims.lt) register_grad(pids.NE, prims.ne) register_grad(pids.GT, prims.gt) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index e7f3f4c2b5..d28fe39361 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -1311,7 +1311,6 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim: conv2d = _register_torch_operation("conv2d", module=torch.nn.functional) conv3d = _register_torch_operation("conv3d", module=torch.nn.functional) mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional) -cross_entropy = _register_torch_operation("cross_entropy", module=torch.nn.functional) dropout = _register_torch_operation("dropout", module=torch.nn.functional) embedding = _register_torch_operation("embedding", module=torch.nn.functional) embedding_backward = _register_torch_operation("torch.ops.aten.embedding_backward", like=ltorch.embedding_backward) @@ -1446,59 +1445,6 @@ def _convolution_transform( return convolution(a, weight, bias, stride, padding, dilation, bool(transposed), output_padding, groups) -def _cross_entropy_backward_impl( - g: torch.Tensor, - a: torch.Tensor, - target: torch.Tensor, - weight: torch.Tensor, - reduction: str, - ignore_index: int, - label_smoothing: int, -) -> torch.Tensor: - # forward - given input and target - # a = log_softmax(input, dim) - # return cross_entropy(a, target, weight, reduction, ignore_index, label_smoothing) - - # backward - given grad_cross_entropy and saved_tensors - # grad_a = torch.ops.aten.nll_loss_backward(grad, a, target, weight, reduction, ignore_index, total_weight) - # return torch.ops.aten._log_softmax_backward_data(grad_a, a, dim, a.scalar_type()) - - if reduction == "none": - reduction_idx = 0 - elif reduction == "mean": - reduction_idx = 1 - elif reduction == "sum": - reduction_idx = 2 - else: - reduction_idx = -1 - - utils.check( - reduction_idx > -1 and reduction_idx < 3, - lambda: f"{reduction} is not a valid value for reduction parameter.", - ) - - # TODO Add support nll_loss_nd, weight tensor, and label_smoothing options. - # See issue "Add support for remaining cross_entropy_loss arguments." - utils.check(a.ndim <= 2 and target.ndim <= 1, lambda: f"multi-dimension cross-entropy is not supported.") - - utils.check(weight is None, lambda: f"weight tensor argument is not supported.") - - utils.check(label_smoothing == 0.0, lambda: f"label smoothing values not equal to zero are not supported.") - - dim = 0 if a.dim() == 1 else 1 - a = torch.log_softmax(a, dim, a.dtype) - - if weight is not None: - total_weight = torch.sum(weight) - elif reduction == "none": - total_weight = torch.tensor(0.0, dtype=a.dtype, device=a.device) - elif reduction == "sum" or reduction == "mean": - total_weight = torch.sum(torch.ne(target, ignore_index)).to(dtype=a.dtype, device=a.device) - - g_a = torch.ops.aten.nll_loss_backward(g, a, target, weight, reduction_idx, ignore_index, total_weight) - return torch.ops.aten._log_softmax_backward_data(g_a, a, dim, a.dtype) - - # NOTE PyTorch's nn.functional.interpolate only supports 3D, 4D, and 5D interpolation def _interpolate_checker( a: TensorLike, @@ -1563,11 +1509,11 @@ def _nll_loss_backward_impl( ) -> torch.Tensor: reduction: int = _reduction_str_to_num_map[reduction] - # NOTE PyTorch expects total_weight to be a float64 tensor if total_weight is None: + # NOTE aten.nll_loss_backward expects total_weight to be a float64 tensor total_weight = torch.tensor(0.0, dtype=torch.float64, device=a.device) else: - total_weight = total_weight.to(torch.float64) + total_weight = total_weight.to(a.dtype) if a.ndim <= 2: return torch.ops.aten.nll_loss_backward(g, a, target, weight, reduction, ignore_index, total_weight) @@ -1652,11 +1598,6 @@ def _pad_prim_impl( _register_implementation(ltorch.conv2d, conv2d, checker=_always_executable) _register_implementation(ltorch.conv3d, conv3d, checker=_always_executable) _register_implementation(ltorch.mse_loss, mse_loss, checker=_always_executable) -_register_implementation(ltorch.cross_entropy, cross_entropy, checker=_always_executable) -cross_entropy_backward = ex.register_operator( - "torch_cross_entropy_backward_impl", meta=ltorch.cross_entropy_backward, fn=_cross_entropy_backward_impl -) -_register_implementation(ltorch.cross_entropy_backward, cross_entropy_backward, checker=_always_executable) _register_implementation(ltorch.dropout, dropout, checker=_always_executable) _register_implementation(ltorch.embedding, embedding, checker=_always_executable) _register_implementation(ltorch.embedding_backward, embedding_backward, checker=_always_executable) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index fa0aad8a32..f625f041ca 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -7728,16 +7728,13 @@ def cross_entropy_reference_generator(op, device, dtype, requires_grad, **kwargs if not probability_target else make(shape[1], low=0.0, high=1.0, requires_grad=True) ), - weight=make(C) if weight_flag else None, + weight=make(C, requires_grad=False) if weight_flag else None, ignore_index=ignore_index, reduction=reduction_str, label_smoothing=label_smoothing, ) -# TODO Enable cross entropy bwd weight support -# TODO Enable test cases after adding support nll_loss_nd, weight tensor, and label_smoothing options. -# TODO see issue "Add support for remaining cross_entropy_loss arguments" def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -7745,15 +7742,15 @@ def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs): shapes = ( ((2, 16), (2,)), ((7, 18), (7,)), - # ((7, 18), (7, 18)), - # ((3, 4, 2, 3), (3, 4, 2, 3)), - # ((3, 4, 2, 3), (3, 2, 3)), + ((7, 18), (7, 18)), + ((3, 4, 2, 3), (3, 4, 2, 3)), + ((3, 4, 2, 3), (3, 2, 3)), ((5,), ()), - # ((3, 4, 0), (3, 0)), - # ((3, 4, 0), (3, 4, 0)), + ((3, 4, 0), (3, 0)), + ((3, 4, 0), (3, 4, 0)), ) - weight_options = (False,) + weight_options = (False, True) reduction_options = ("none", "mean", "sum") label_smoothing_options = (0.0, 0.5) ignore_index_options = (-1, 3) @@ -7761,13 +7758,17 @@ def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs): for shape, weight_flag, reduction_str, label_smoothing, ignore_index in itertools.product( shapes, weight_options, reduction_options, label_smoothing_options, ignore_index_options ): + # NOTE According to pytorch/pytorch#64572, nll_loss should return NaN when reduction = "mean" + # and the whole target is equal to ignore_index. However, if the inputs are cuda tensors, PyTorch returns 0. + # Skip this case because we are consistent across devices. + if torch.device(device).type == "cuda" and reduction_str == "mean" and ignore_index > 0: + continue + input_shape, target_shape = shape probability_target = input_shape == target_shape # ignore_index can't be supplied with probablity target if probability_target and ignore_index >= 0: continue - if not probability_target and label_smoothing > 0.0: - continue C = input_shape[1] if len(input_shape) >= 2 else input_shape[0] yield SampleInput( make(shape[0]), @@ -7776,21 +7777,125 @@ def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs): if not probability_target else make(shape[1], low=0.0, high=1.0, requires_grad=True) ), - weight=make(C) if weight_flag else None, + weight=make(C, requires_grad=False) if weight_flag else None, ignore_index=ignore_index, reduction=reduction_str, label_smoothing=label_smoothing, ) +def cross_entropy_error_generator(op, device, dtype=torch.float32, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + input_shape = (7, 18) + target_shape = (7,) + C = input_shape[1] if len(input_shape) >= 2 else input_shape[0] + valid_input = make(input_shape) + valid_target = make(target_shape, low=0, high=C, dtype=torch.long, requires_grad=False) + + # unexpected reduction string argument + yield ( + SampleInput(valid_input, valid_target, reduction="foo"), + ValueError, + 'Expected reduction string to be "none", "sum", or "mean", but it is (.*?).', + ) + + # target tensor is not integer dtype + float_target = make(target_shape, low=0, high=C, dtype=torch.float, requires_grad=False) + yield ( + SampleInput(valid_input, float_target), + RuntimeError, + "Expected target to be a tensor with an integer dtype, but it has dtype (.*?).", + ) + + # input tensor has 0 dimensions + scalar_input = make(scalar_shape := ()) + yield ( + SampleInput(scalar_input, valid_target), + RuntimeError, + f"Expected the input tensor to have more than 1 dimension, but it has {scalar_input.ndim} dimensions.", + ) + + # weight tensor has more than 1 dimension + multiple_dim_weight = make((C, 10), requires_grad=False) + yield ( + SampleInput(valid_input, valid_target, weight=multiple_dim_weight), + RuntimeError, + f"Expected a 1D tensor with {C} elements for weight argument, \ + but found a tensor with {multiple_dim_weight.ndim} dimensions and {multiple_dim_weight.shape[0]} elements.", + ) + + # weight tensor numel != C + incorrect_numel_weight = make((C + 10,), requires_grad=False) + yield ( + SampleInput(valid_input, valid_target, weight=incorrect_numel_weight), + RuntimeError, + f"Expected a 1D tensor with {C} elements for weight argument, \ + but found a tensor with {incorrect_numel_weight.ndim} dimensions and {incorrect_numel_weight.shape[0]} elements.", + ) + + # label_smoothing out-of-bounds + out_of_bounds_label_smoothing = 1.5 + yield ( + SampleInput(valid_input, valid_target, label_smoothing=out_of_bounds_label_smoothing), + RuntimeError, + r"Expected label_smoothing to be in \[0, 1\] range but got 1.5.", + ) + + # target tensor is not integer dtype + float_target = make(target_shape, low=0, high=C, dtype=torch.float, requires_grad=False) + yield ( + SampleInput(valid_input, float_target), + RuntimeError, + "Expected target to be a tensor with an integer dtype, but it has dtype (.*?).", + ) + + # input ndims != (target ndims + 1) + extra_dim_input = make(input_shape + (10,)) + yield ( + SampleInput(extra_dim_input, valid_target), + RuntimeError, + "Expected the input tensor to have (.*?) dimensions, but it has (.*?) dimensions.", + ) + + # target shape is input shape except channels dimension + incorrect_batch_target = make((10,), low=0, high=C, dtype=torch.long, requires_grad=False) + yield ( + SampleInput(valid_input, incorrect_batch_target), + RuntimeError, + "Expected the target tensor to have the same shape as the input tensor except for the channels dimension \ + (.*?), but it has shape (.*?).", + ) + + integer_prob_target = make(input_shape, low=0, high=C, dtype=torch.long, requires_grad=False) + yield ( + SampleInput(valid_input, integer_prob_target), + RuntimeError, + "Expected the target to have float dtype when target contains class probabilities \ + but it is (.*?).", + ) + + valid_prob_target = make(input_shape, low=0.0, high=1.0, dtype=torch.float, requires_grad=False) + yield ( + SampleInput(valid_input, valid_prob_target, ignore_index=5), + RuntimeError, + "ignore_index argument is not supported when target contains class probabilities.", + ) + + cross_entropy_opinfo = OpInfo( ltorch.cross_entropy, - supports_grad=True, sample_input_generator=cross_entropy_sample_generator, reference_input_generator=cross_entropy_reference_generator, + error_input_generator=cross_entropy_error_generator, torch_reference=torch.nn.functional.cross_entropy, dtypes=(datatypes.floating,), test_directives=( + # take_along_axis is disabled with nvfuser, which this operator relies on. + DecorateInfo( + pytest.mark.skip, + executors=("nvfuser",), + ), # TODO Investigate why CPU torch executor tests fail in CI (but not locally) DecorateInfo( pytest.mark.skip, @@ -7805,6 +7910,11 @@ def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs): datatypes.bfloat16, ), ), + # TODO FIXME -- These tests are hitting an odd issue where real torch tensors are being passed to nll_loss + DecorateInfo( + pytest.mark.skip, + "test_vjp_correctness", + ), ), ) nn_ops.append(cross_entropy_opinfo) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 2a43195f74..7616c1054b 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -49,7 +49,6 @@ "amax", "amin", "cat", - "cross_entropy", "softmax", "to", "linear", @@ -609,6 +608,28 @@ def test_vjp_correctness_nll_loss_manual(op, device, dtype, executor, comp): comp(grad_out[0], expected_grad[0]) +@ops((get_opinfo("cross_entropy"),), supported_dtypes=(dtypes.float64,)) +def test_vjp_correctness_cross_entropy_manual(op, device, dtype, executor, comp): + for sample in op.sample_inputs(device, dtype, requires_grad=True, no_rhs_numbers=True): + # Traced backwards function does not follow PyTorch cross_entropy behavior with zero element tensors + if sample.args[0].numel() == 0: + continue + + # Compute vjp result using PyTorch + out = op.torch_reference(*sample.args, **sample.kwargs) + v = make_tensor_like(out) + expected_grad = torch.autograd.grad(out, sample.args[0], v) + + # Compute vjp result using Thunder + flat_op, flat_args, spec = flatten_func(op.op, sample.args, sample.kwargs) + actual_out, grad_out = executor.make_callable_legacy(vjp(flat_op), disable_torch_autograd_support=True)( + flat_args, (v,) + ) + + comp(actual_out, out) + comp(grad_out[0], expected_grad[0]) + + @ops((get_opinfo("einsum"),), supported_dtypes=(dtypes.float64,)) def test_vjp_correctness_einsum_manual(op, device, dtype, executor, comp): from thunder.tests.framework import nvFuserTestExecutor diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 67a6cb6d5c..9469f58d47 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -3559,20 +3559,6 @@ def _dropout_helper(a, p): return result -# TODO Add annotations, make not a prim -# The backward decomposition of cross_entropy cannot be efficiently fused, so we have this cross_entropy_backward -# primitive. Executors can override the primitive using internal implementations. -# See issue "Cross_entropy is decomposed for backward but the decomposition is -# not fusible currently" -@torchsymbol("cross_entropy_backward", id="cross_entropy_backward", is_prim=True) -def cross_entropy_backward(g, a, /, target, weight, reduction, ignore_index, label_smoothing): - return TensorProxy(like=g, shape=a.shape) - - -# TODO (mruberry) -- I think this implementation gets the dtype of the output incorrect -# TODO Revise this to consistently call other torch operations where possible -# TODO -- Maybe cut this up into _cross_entropy_mean, _cross_entropy_sum, ... -# TODO Add type annotations @torchsymbol(torch.nn.functional.cross_entropy) def cross_entropy( a: TensorLike, @@ -3590,253 +3576,191 @@ def cross_entropy( lambda: f"Deprecated size_average={size_average} and reduce={reduce} is not supported!", ) + _cross_entropy_input_checks(a, target, weight, ignore_index, reduction, label_smoothing) + + # channels dimension is either the first one if no batch dim present (i.e. a.shape[0]), + # or right next to it (i.e. a.shape[1]). + channels_dim = 1 if a.ndim >= 2 else 0 + + # NOTE This short-circuit is subject to change and is placed ahead of other input checks to match PyTorch behavior. + # The expected behavior when the target and input have zero elements: + # reduction = 'none' --- tensor([], shape) or tensor(0.) + # reduction = 'sum' --- tensor(0.) + # reduction = 'mean' --- tensor(nan) + # Mean reduction on empty tensors produces NaN. + if a.numel() == 0: + if reduction == "none": + output_shape = list(a.shape) + output_shape.pop(channels_dim) + return full(output_shape, 0.0, device=a.device, dtype=a.dtype) + elif reduction == "sum": + return full(result_shape := [], fill_value := 0.0, device=a.device, dtype=a.dtype) + elif reduction == "mean": + return full(result_shape := [], fill_value := float("nan"), device=a.device, dtype=a.dtype) + + if a.shape == target.shape: + return _cross_entropy_loss_probability_target(a, target, weight, ignore_index, reduction, label_smoothing) + elif label_smoothing != 0.0: + return _cross_entropy_loss_label_smoothing(a, target, weight, ignore_index, reduction, label_smoothing) + else: + log_softmax_input = log_softmax(a, dim=channels_dim) + return nll_loss(log_softmax_input, target, weight, ignore_index, reduction) + + +def _cross_entropy_input_checks( + a: TensorLike, + /, + target: TensorLike, + weight: None | TensorLike, + ignore_index: int, + reduction: str, + label_smoothing: float, +): utils.check( - a.ndim != 0, - lambda: f"Cross entropy expects its input to have one or more dimensions, but it had zero dimensions", + reduction in ("none", "sum", "mean"), + lambda: f'Expected reduction string to be "none", "sum", or "mean", but it is {reduction}.', + exception_type=ValueError, ) - # NOTE label_smoothing < 0 will just be ignored. utils.check( - label_smoothing <= 1.0, - lambda: f"Cross entropy's {label_smoothing=} must be less than or equal to 1.0", + a.ndim >= 1, + lambda: f"Expected the input tensor to have more than 1 dimension, but it has {a.ndim} dimensions.", ) - # extract shape information - C_dim = 1 if a.ndim >= 2 else 0 - N = a.shape[0] if a.ndim >= 2 else 1 - C = a.shape[C_dim] - feature_size = int(a.numel() / N / C) + utils.check( + label_smoothing >= 0.0 and label_smoothing <= 1.0, + lambda: f"Expected label_smoothing to be in [0, 1] range but got {label_smoothing}.", + ) - # Short-circuits if a is empty - if a.numel() == 0: - if reduction == "none": - output_shape = list(a.shape) - output_shape.pop(C_dim) - return zeros(output_shape, device=a.device, dtype=a.dtype) - elif reduction == "mean": - fill_value = float("nan") - elif reduction == "sum": - fill_value = 0.0 - else: - raise ValueError(f"Reduction argument {reduction} to cross_entropy is not supported") + # channels dimension is either the first one if no batch dim present (i.e. a.shape[0]), + # or right next to it (i.e. a.shape[1]). + channels_dim = 1 if a.ndim >= 2 else 0 + num_channels = a.shape[channels_dim] - return full([], fill_value, device=a.device, dtype=a.dtype) + utils.check( + weight is None or (weight.ndim == 1 and weight.shape[0] == num_channels), + lambda: f"Expected a 1D tensor with {num_channels} elements for weight argument, \ + but found a tensor with {weight.ndim} dimensions and {weight.shape[0]} elements.", + ) - if weight is not None: + if a.shape != target.shape: utils.check( - weight.ndim == 1 and weight.numel() == C, - lambda: f"Expected {weight.shape=} to have one dimension and {C} elements", + utils.is_integer_dtype(target.dtype), + lambda: f"Expected target to be a tensor with an integer dtype, but it has dtype {target.dtype}.", ) - bcast_weight = reshape(weight, [C] + [1 for i in range(2, a.ndim)]) - log_softmax_a = log_softmax(a, C_dim) - out = -log_softmax_a + utils.check( + a.ndim == target.ndim + 1, + lambda: f"Expected the input tensor to have {(target.ndim + 1)=} dimensions, but it has {a.ndim} dimensions.", + ) - if a.shape == target.shape: + # target should match input in dims which do not correspond to the channels dim, i.e. + # (input.shape[:channels_dim] + input.shape[channels_dim + 1:]) == target.shape <=> True + expected_target_shape = a.shape[:channels_dim] + a.shape[channels_dim + 1 :] + + utils.check( + expected_target_shape == target.shape, + lambda: f"Expected the target tensor to have the same shape as the input tensor except for the channels dimension \ + {expected_target_shape}, but it has shape {target.shape}.", + ) + else: + # target represents class probabilities and is the range [0.0, 1.0] utils.check( utils.is_float_dtype(target.dtype), - lambda: f"expect float dtype for probability target, but got: {target.dtype}!", + lambda: f"Expected the target to have float dtype when target contains class probabilities \ + but it is {target.dtype}.", ) utils.check( ignore_index < 0, - lambda: f"ignore_index is not supported for probability target, set ignore_index < 0!", + lambda: "ignore_index argument is not supported when target contains class probabilities.", ) - if label_smoothing > 0.0: - target = target * (1 - label_smoothing) + label_smoothing / C - out = out * target - - if weight is not None: - out = out * bcast_weight - - if target.ndim == 1: - out = _reduction( - out, - prims.sum, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - else: - out = _reduction( - out, - prims.sum, - dims=C_dim, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) +def _cross_entropy_loss_probability_target( + a: TensorLike, + /, + target: TensorLike, + weight: None | TensorLike, + ignore_index: int, + reduction: str, + label_smoothing: float, +) -> TensorLike: + # channels dimension is either the first one if no batch dim present (i.e. a.shape[0]), + # or right next to it (i.e. a.shape[1]). + channels_dim = 1 if a.ndim >= 2 else 0 + num_channels = a.shape[channels_dim] - if reduction == "none": - return out - # TODO: duplicate this in probability target! - elif reduction == "sum": - # NOTE: do we need to promote dtype?! - return _reduction( - out, - prims.sum, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - elif reduction == "mean": - reduced_sum = _reduction( - out, - prims.sum, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - # NOTE: does it work with dynamic size?! - return reduced_sum / N * feature_size - else: - raise ValueError(f"reduction argument: {reduction} to cross_entropy is not supported") - else: - utils.check( - utils.is_integer_dtype(target.dtype), - lambda: f"expect integer dtype for class indices target, but got: {target.dtype}!", - ) - no_C_shape = list(a.shape) - no_C_shape.pop(C_dim) - utils.check( - a.ndim == target.ndim + 1 and no_C_shape == list(target.shape), - lambda: f"Inconsistent shape input: {a.shape} / target: {a.shape} to cross_entropy!", - ) + if label_smoothing > 0.0: + target = (target * (1 - label_smoothing)) + (label_smoothing / num_channels) - # nll_loss - if weight is not None: - out = out * bcast_weight - - smooth_loss_no_sum = out - # TODO: swap reshape with unsqueeze when nvfuser support is added - # bcast_target = clang.unsqueeze(target, [C_dim]) - bcast_target_shape = list(a.shape) - bcast_target_shape[C_dim] = 1 - bcast_target = reshape(target, bcast_target_shape) - - out = clang.take_along_axis(out, bcast_target, C_dim) - - if label_smoothing > 0: - # smooth_loss shape [N, SPATIAL...] - smooth_loss = _reduction( - smooth_loss_no_sum, - prims.sum, - dims=[C_dim], - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - # NOTE: [handling of 'ignore_index'] - # Semantically, I think we are doing the right thing here where we mask out the ignore_index entries on output from clang.take_along_axis. Because targets is expected to be within [0, C) - # However, in Torch/ATen implementation, 'ignore_index' can be outside of the range, so is targets. So it could even prevent an out-of-bound error from NLLLoss. Which diverges from the behavior here. - # Note that we can mimic that behavior by mask targets before take_along_axis, but that's going to add more operations here, which means more overhead. Let's not do that until we see real examples exploiting the behavior. - # Alternatively, we can revisit the choice of numpy.take_along_axis. - # jax.numpy.take_along_axis gives a 'mode' arg custom out-of-bound behavior. But that might be slightly tricky to handle for codegen. - if ignore_index >= 0: - # mask shape [N, 1, SPATIAL...] - mask = bcast_target == ignore_index - out = where(mask, 0, out) - if label_smoothing > 0: - # TODO: switch to squeeze - smooth_loss = where(reshape(mask, list(smooth_loss.shape)), 0, smooth_loss) + out = log_softmax(a, dim=channels_dim) * target - if reduction == "none": - # TODO: swap reshape with squeeze when nvfuser support is added - # return clang.squeeze(out, [C_dim]) - out = reshape(out, target.shape) - if label_smoothing > 0: - ret = smooth_loss - # TODO: duplicate this in probability target! - elif reduction == "sum": - # NOTE: do we need to promote dtype?! - out = _reduction( - out, - prims.sum, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - if label_smoothing > 0: - ret = _reduction( - smooth_loss, - prims.sum, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - elif reduction == "mean": - # NOTE: do we need to promote dtype?! - reduced_sum = _reduction( - out, - prims.sum, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - if label_smoothing > 0: - ret = _reduction( - smooth_loss, - prims.sum, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - if weight is not None: - # NOTE: this seems unreasonably complicated. Am I missing something obvious?! - input_shape = list(a.shape) - expanded_weight = clang.expand(bcast_weight, input_shape) - # DEBUG!!! this gives segfaults - selected_weight = clang.take_along_axis(expanded_weight, bcast_target, C_dim) - - if ignore_index >= 0: - selected_weight = where(mask, 0, selected_weight) - - bcast_weight_sum = _reduction( - selected_weight, - prims.sum, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - out = reduced_sum / bcast_weight_sum - if label_smoothing > 0: - ret = ret / bcast_weight_sum - elif ignore_index >= 0: - mask_sum = _reduction( - mask, - prims.sum, - dtype=to_dtype(torch.float), - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, - ) - # NOTE: does the call numel here work with dynamic shape?! - out = reduced_sum / (target.numel() - mask_sum) - if label_smoothing > 0: - ret = ret / (target.numel() - mask_sum) - elif target.ndim == 0: - # NOTE: this is pytorch implementation details. - # overwrite output to 0 when target hits ignore_index AND label_smoothing is missing. - # https://github.com/pytorch/pytorch/pull/64572 - out = where((target == ignore_index), 0, out) - else: - out = reduced_sum / target.numel() - if label_smoothing > 0: - ret = ret / target.numel() - else: - raise ValueError(f"Reduction argument: {reduction} to cross_entropy is not supported") + if weight is not None: + bcast_weight = reshape(weight, [num_channels] + [1 for _ in range(2, a.ndim)]) + out = out * bcast_weight - # TODO FIXME This is probably incorrect -- but somewhere above the dtype of out can disagree with PyTorch - out = out.to(a.dtype) + out = -out - if label_smoothing > 0: - return out * (1 - label_smoothing) + (ret * (label_smoothing / C)) - else: - return out + if reduction == "none": + return sum(out, dim=channels_dim) + elif reduction == "sum": + return sum(out) + elif reduction == "mean": + return sum(out) / (a.numel() // num_channels) -# TODO The function cross_entropy_backward shouldn't be registered as a primitive operation (above), but as -# a composite operation -def _cross_entropy_grad( +def _cross_entropy_loss_label_smoothing( a: TensorLike, /, target: TensorLike, - weight: None | TensorLike = None, - size_average: None | Any = None, - ignore_index: int = -100, - reduce: None | Any = None, - reduction: str = "mean", - label_smoothing: float = 0.0, + weight: None | TensorLike, + ignore_index: int, + reduction: str, + label_smoothing: int, ) -> TensorLike: - fwd: TensorLike = cross_entropy(a, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing) + # channels dimension is either the first one if no batch dim present (i.e. a.shape[0]), + # or right next to it (i.e. a.shape[1]). + channels_dim = 1 if a.ndim >= 2 else 0 + num_channels = a.shape[channels_dim] - g: TensorLike = get_grad(fwd) - a_grad: TensorLike = cross_entropy_backward(g, a, target, weight, reduction, ignore_index, label_smoothing) - put_grad(a, a_grad) + log_softmax_value = log_softmax(a, dim=channels_dim) - return fwd + if weight is not None: + bcast_weight = reshape(weight, [num_channels] + [1 for _ in range(2, len(a.shape))]) + out = -(log_softmax_value * bcast_weight) + else: + out = -log_softmax_value + + smooth_loss = sum(out, dim=channels_dim) + + # Make target broadcastable with output, which has same shape as input tensor. + selected_target_mask = target != ignore_index + smooth_loss = where(selected_target_mask, smooth_loss, 0) + + if reduction == "none": + ret = smooth_loss + elif reduction == "sum": + ret = sum(smooth_loss) + elif reduction == "mean": + reduced_sum = sum(out) + if weight is not None: + # Gather the weights for each target class. + # Mask the ignored target classes. + # Sum together all target weights. + # Make target broadcastable with output, which has same shape as input tensor. + expanded_weight = expand(bcast_weight, a.shape) + bcast_target = unsqueeze(target, channels_dim) + selected_weight = take_along_dim(expanded_weight, bcast_target, channels_dim) + selected_weight = where(selected_target_mask, squeeze(selected_weight), 0) + ret = reduced_sum / sum(selected_weight) + else: + # The weight tensor is none, so the total weight is the number of valid target elements not equal to + # ignore_index argument + ret = reduced_sum / sum(selected_target_mask) + nll_loss_value = nll_loss(log_softmax_value, target, weight, ignore_index, reduction) -register_grad(cross_entropy, _cross_entropy_grad) + return (nll_loss_value * (1.0 - label_smoothing)) + (ret * (label_smoothing / num_channels)) # TODO Is this a method?