From a4dcd899d7e6ae91927bd696f8b1efbad7e6bfb4 Mon Sep 17 00:00:00 2001
From: Riccardo Felluga <11768013+riccardofelluga@users.noreply.github.com>
Date: Mon, 27 May 2024 16:22:52 +0300
Subject: [PATCH] Add support for torch FP8 dtypes (#445)

---
 thunder/__init__.py                |  8 ++++
 thunder/core/baseutils.py          |  4 ++
 thunder/core/dtypes.py             | 68 ++++++++++++++++++++++++------
 thunder/tests/framework.py         |  6 ++-
 thunder/tests/make_tensor.py       | 19 +++++++--
 thunder/tests/test_autocast.py     |  6 +--
 thunder/tests/test_grad.py         |  2 +-
 thunder/tests/test_inplace_copy.py |  4 +-
 thunder/tests/test_ops.py          |  2 +-
 9 files changed, 95 insertions(+), 24 deletions(-)

diff --git a/thunder/__init__.py b/thunder/__init__.py
index e545564b1e..51c8c98778 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -92,6 +92,10 @@
     "int32",
     "int64",
     "bfloat16",
+    "float8_e5m2",
+    "float8_e5m2fnuz",
+    "float8_e4m3fn",
+    "float8_e4m3fnuz",
     "float16",
     "float32",
     "float64",
@@ -129,6 +133,10 @@ def __version__():
 int32 = dtypes.int32
 int64 = dtypes.int64
 bfloat16 = dtypes.bfloat16
+float8_e5m2 = dtypes.float8_e5m2
+float8_e5m2fnuz = dtypes.float8_e5m2fnuz
+float8_e4m3fn = dtypes.float8_e4m3fn
+float8_e4m3fnuz = dtypes.float8_e4m3fnuz
 float16 = dtypes.float16
 float32 = dtypes.float32
 float64 = dtypes.float64
diff --git a/thunder/core/baseutils.py b/thunder/core/baseutils.py
index 3a49794b89..d2f3bb4d96 100644
--- a/thunder/core/baseutils.py
+++ b/thunder/core/baseutils.py
@@ -304,6 +304,10 @@ def indent(level):
     torch.int32: "torch.int32",
     torch.int64: "torch.int64",
     torch.bfloat16: "torch.bfloat16",
+    torch.float8_e4m3fn: "torch.float8_e4m3fn",
+    torch.float8_e4m3fnuz: "torch.float8_e4m3fnuz",
+    torch.float8_e5m2: "torch.float8_e5m2",
+    torch.float8_e5m2fnuz: "torch.float8_e5m2fnuz",
     torch.float16: "torch.float16",
     torch.float32: "torch.float32",
     torch.float64: "torch.float64",
diff --git a/thunder/core/dtypes.py b/thunder/core/dtypes.py
index af66e1fd3e..0a1e24701e 100644
--- a/thunder/core/dtypes.py
+++ b/thunder/core/dtypes.py
@@ -59,9 +59,10 @@ def __new__(cls, *args, **kwargs):
         return object.__new__(cls)
 
-    def __init__(self, *, python_type, name, shortname, bytes, is_weak):
+    def __init__(self, *, python_type, name, shortname, bytes, is_weak, variant=None):
         self._python_type = python_type
         self._name = name
+        self._variant = variant
         self._shortname = shortname
         self._bytes = bytes
         self._is_weak = is_weak
 
@@ -80,23 +81,30 @@ def is_weak(self):
         return self._is_weak
 
     def shortname(self):
-        return f"{self._shortname}{8 * self._bytes}"
+        return f"{self._shortname}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}"
 
     # TODO Fix name printing
     def __repr__(self):
-        return f"{self._name}{8 * self._bytes}{'_' if self._is_weak else ''}"
+        return (
+            f"{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}"
+        )
 
     def __str__(self):
         return self.__repr__()
 
     def __hash__(self) -> int:
-        return hash((self._name, self._bytes, self._is_weak))
+        return hash((self._name, self._bytes, self._is_weak, f"{self._variant if self._variant else ''}"))
 
     def __eq__(self, other) -> bool:
         if not isinstance(other, dtype):
             return False
 
-        return self._name == other._name and self._bytes == other._bytes and self._is_weak == other._is_weak
+        return (
+            self._name == other._name
+            and self._bytes == other._bytes
+            and self._is_weak == other._is_weak
+            and self._variant == other._variant
+        )
 
 
 class exact(dtype):
@@ -152,14 +160,24 @@ class inexact(dtype):
 
 
 class floating(inexact):
-    """Base class for the floating dtypes: bfloat16, float16, float32, float64."""
+    """Base class for the floating dtypes: float8, bfloat16, float16, float32, float64."""
 
-    def __init__(self, name, shortname, *, bytes, is_weak):
-        super().__init__(python_type=float, name=name, shortname=shortname, bytes=bytes, is_weak=is_weak)
+    def __init__(self, name, shortname, *, bytes, is_weak, variant=None):
+        super().__init__(
+            python_type=float, name=name, shortname=shortname, bytes=bytes, is_weak=is_weak, variant=variant
+        )
 
 
 bfloat16 = floating("bfloat", "bf", bytes=2, is_weak=False)
 bfloat16_ = floating("bfloat", "bf", bytes=2, is_weak=True)
+float8_e5m2 = floating("float", "f", bytes=1, is_weak=False, variant="e5m2")
+float8_e5m2_ = floating("float", "f", bytes=1, is_weak=True, variant="e5m2")
+float8_e5m2fnuz = floating("float", "f", bytes=1, is_weak=False, variant="e5m2fnuz")
+float8_e5m2fnuz_ = floating("float", "f", bytes=1, is_weak=True, variant="e5m2fnuz")
+float8_e4m3fn = floating("float", "f", bytes=1, is_weak=False, variant="e4m3fn")
+float8_e4m3fn_ = floating("float", "f", bytes=1, is_weak=True, variant="e4m3fn")
+float8_e4m3fnuz = floating("float", "f", bytes=1, is_weak=False, variant="e4m3fnuz")
+float8_e4m3fnuz_ = floating("float", "f", bytes=1, is_weak=True, variant="e4m3fnuz")
 float16 = floating("float", "f", bytes=2, is_weak=False)
 float16_ = floating("float", "f", bytes=2, is_weak=True)
 float32 = floating("float", "f", bytes=4, is_weak=False)
@@ -200,6 +218,14 @@ def __init__(self, name, shortname, *, bytes, is_weak):
     int64_,
     bfloat16,
     bfloat16_,
+    float8_e5m2,
+    float8_e5m2_,
+    float8_e5m2fnuz,
+    float8_e5m2fnuz_,
+    float8_e4m3fn,
+    float8_e4m3fn_,
+    float8_e4m3fnuz,
+    float8_e4m3fnuz_,
     float16,
     float16_,
     float32,
@@ -242,6 +268,10 @@ def __init__(self, name, shortname, *, bytes, is_weak):
 
 float_dtypes = {d for d in all_dtypes if isinstance(d, floating)} | {float}
 
+float_math_dtypes = {d for d in all_dtypes if isinstance(d, floating) and d.bytes >= 2}
+
+float_8bit_dtypes = {d for d in all_dtypes if (isinstance(d, floating) and d.bytes == 1)}
+
 complex_dtypes = {d for d in all_dtypes if isinstance(d, complexfloating)} | {complex}
 
 inexact_dtypes = float_dtypes | complex_dtypes
@@ -306,11 +336,12 @@ def has_subdtype(x, cls):
 
 
 # Translates a sequence of dtypes and dtype classes into a concrete set of corresponding (strong) dtypes
-def resolve_dtypes(args):
+def resolve_dtypes(args: Iterable) -> set[dtype]:
     dtypes = set()
     for arg in args:
         if isinstance(arg, dtype):
-            dtypes.add(arg)
+            if not arg.is_weak:
+                dtypes.add(arg)
             continue
 
         if isinstance(arg, Iterable):
@@ -320,7 +351,8 @@
                     lambda: f"Iterables passed to resolve_dtypes must only contain dtypes, but found an Iterable with {a}",
                     exception_type=NotImplementedError,
                 )
-                dtypes.add(a)
+                if not a.is_weak:
+                    dtypes.add(a)
 
         baseutils.check(
             arg in (dtype, exact, signedinteger, unsignedinteger, bool_, inexact, floating, complexfloating),
@@ -373,6 +405,10 @@ def corresponding_complex_dtype(dtype):
     int32: int32_,
     int64: int64_,
     bfloat16: bfloat16_,
+    float8_e5m2: float8_e5m2_,
+    float8_e5m2fnuz: float8_e5m2fnuz_,
+    float8_e4m3fn: float8_e4m3fn_,
+    float8_e4m3fnuz: float8_e4m3fnuz_,
     float16: float16_,
     float32: float32_,
     float64: float64_,
@@ -520,6 +556,14 @@ def are_same_dtypes(a, b, *, weak_and_strong_are_equivalent=True):
     int64: torch.int64,
     bfloat16_: torch.bfloat16,
     bfloat16: torch.bfloat16,
+    float8_e5m2: torch.float8_e5m2,
+    float8_e5m2_: torch.float8_e5m2,
+    float8_e5m2fnuz: torch.float8_e5m2fnuz,
+    float8_e5m2fnuz_: torch.float8_e5m2fnuz,
+    float8_e4m3fn: torch.float8_e4m3fn,
+    float8_e4m3fn_: torch.float8_e4m3fn,
+    float8_e4m3fnuz: torch.float8_e4m3fnuz,
+    float8_e4m3fnuz_: torch.float8_e4m3fnuz,
     float16_: torch.float16,
     float16: torch.float16,
     float32_: torch.float32,
@@ -551,7 +595,7 @@ def to_torch_dtype(x: None | torch.dtype | dtype) -> None | torch.dtype:
 
 
 # Converts NumPy dtypes to and from thunder dtypes
-# NOTE NumPy does not support the bfloat16 or complexhalf (complex32) datatypes
+# NOTE NumPy does not support the bfloat16, complexhalf (complex32) or float8 datatypes
 _thunder_to_numpy_dtype_map = {
     bool: np.bool_,
     int: np.int_,
diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py
index 478ecfc143..cfbe58432d 100644
--- a/thunder/tests/framework.py
+++ b/thunder/tests/framework.py
@@ -172,7 +172,7 @@ class nvFuserTestExecutor(TestExecutor):
     name = "nvfuser"
     supported_devicetypes = (devices.DeviceType.CUDA,)
     supported_dtypes = (
-        datatypes.floating,
+        *datatypes.float_math_dtypes,
        datatypes.bool8,
         datatypes.int32,
         datatypes.int64,
@@ -347,7 +347,9 @@ def __init__(
         self.supported_devicetypes = set(filter_ci_devicetypes(self.supported_devicetypes))
 
         self.supported_dtypes = (
-            datatypes.resolve_dtypes(supported_dtypes) if supported_dtypes is not None else datatypes.all_dtypes
+            datatypes.resolve_dtypes(supported_dtypes)
+            if supported_dtypes is not None
+            else datatypes.all_dtypes - datatypes.float_8bit_dtypes
         )
 
         if supported_dtypes == NOTHING:
diff --git a/thunder/tests/make_tensor.py b/thunder/tests/make_tensor.py
index 27d5ecf909..76442bec62 100644
--- a/thunder/tests/make_tensor.py
+++ b/thunder/tests/make_tensor.py
@@ -117,9 +117,15 @@ def clamp(a, l, h):
     shape = cast(tuple[int, ...], tuple(shape))
 
     _integral_types = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
+    _floating_8bit_types = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
     _floating_types = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
     _complex_types = [torch.complex32, torch.complex64, torch.complex128]
-    if requires_grad and dtype not in _floating_types and dtype not in _complex_types:
+    if (
+        requires_grad
+        and dtype not in _floating_types
+        and dtype not in _floating_8bit_types
+        and dtype not in _complex_types
+    ):
         raise ValueError("make_tensor: requires_grad must be False for integral dtype")
 
     if dtype is torch.bool:
@@ -145,10 +151,10 @@ def clamp(a, l, h):
         if low == high:
             return torch.full(shape, low, device=device, dtype=dtype)
         result = torch.randint(low, high, shape, device=device, dtype=dtype)  # type: ignore[call-overload]
-    elif dtype in _floating_types:
+    elif dtype in _floating_types + _floating_8bit_types:
         ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
         m_low, m_high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
-        result = torch.empty(shape, device=device, dtype=dtype)
+        result = torch.empty(shape, device=device, dtype=dtype if dtype not in _floating_8bit_types else torch.float32)
         _uniform_random(result, m_low, m_high)
     elif dtype in _complex_types:
         float_dtype = complex_to_corresponding_float_type_map[dtype]
@@ -175,6 +181,8 @@ def clamp(a, l, h):
             replace_with = torch.tensor(1, device=device, dtype=dtype)
         elif dtype in _floating_types:
             replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype)
+        elif dtype in _floating_8bit_types:
+            replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=torch.float32)
         else:  # dtype in _complex_types:
             float_dtype = complex_to_corresponding_float_type_map[dtype]
             float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype)
@@ -184,6 +192,11 @@
     if dtype in _floating_types + _complex_types:
         result.requires_grad = requires_grad
 
+    # NOTE This is a workaround: so many operations are unsupported for float8 that
+    # even creating the test tensors is hard.
+    if dtype in _floating_8bit_types:
+        result = result.to(dtype)
+
     return result
 
 
diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py
index 644179a989..f219b5d848 100644
--- a/thunder/tests/test_autocast.py
+++ b/thunder/tests/test_autocast.py
@@ -14,7 +14,7 @@
 # TODO This test currently ignores the "should_autocast" argument enumerated in it
 @instantiate(
-    dtypes=dtypes.float_dtypes - {float},
+    dtypes=dtypes.float_math_dtypes,
 )
 def test_thunder_autocast_transform(executor, device, dtype):
     from thunder.core.transforms import autocast
 
@@ -65,7 +65,7 @@ def h(a, b, c):
 
 @instantiate(
     executors=[TorchExecutor],
-    dtypes=dtypes.float_dtypes - {float},
+    dtypes=dtypes.float_math_dtypes,
 )
 def test_no_autocast(executor, device, dtype):
     from thunder.core.symbol import Symbol
@@ -112,7 +112,7 @@ def func():
 
 
 @instantiate(
-    dtypes=dtypes.float_dtypes - {float},
+    dtypes=dtypes.float_math_dtypes,
     decorators=(pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported"),),
 )
 def test_compile_autocast(executor, device, dtype):
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index c27d57381b..2a43195f74 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1269,7 +1269,7 @@ def upcast_tensors(x: Any) -> Any:
 
 @ops(
     tuple(op for op in opinfos if op.supports_grad and op.torch_reference is not None),
-    supported_dtypes=(dtypes.floating,),
+    supported_dtypes=dtypes.float_math_dtypes,
 )
 def test_phantom_grad_vs_torch_consistency(op, device: str, dtype: dtypes.dtype, executor, comp):
     if dtypes.is_complex_dtype(dtype):
diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py
index f98ba024e3..c7121cea5b 100644
--- a/thunder/tests/test_inplace_copy.py
+++ b/thunder/tests/test_inplace_copy.py
@@ -10,7 +10,7 @@
 from thunder.tests.framework import instantiate, nvFuserExecutor
 
 
-@instantiate()
+@instantiate(dtypes=datatypes.all_dtypes - datatypes.float_8bit_dtypes)
 def test_prim_inplace_copy_fwd(executor, device, dtype):
     def torch_foo(x, y):
         z = x + y
@@ -37,7 +37,7 @@ def foo(x, y):
     assert_close(a, a1)
 
 
-@instantiate(dtypes=(datatypes.floating,))
+@instantiate(dtypes=datatypes.float_math_dtypes)
 def test_prim_inplace_copy_bwd(executor, device, dtype):
     def torch_foo(x, y):
         z = x * y
diff --git a/thunder/tests/test_ops.py b/thunder/tests/test_ops.py
index 3130735f33..9535be8d3a 100644
--- a/thunder/tests/test_ops.py
+++ b/thunder/tests/test_ops.py
@@ -23,7 +23,7 @@ def snippet_errors(op, sample, ex_type, err_msg_match=None):
 
 
 @ops(tuple(op for op in opinfos if op.error_input_generator is not None))
-def test_errors(op, device, _, executor, comp):
+def test_errors(op, device, dtype, executor, comp):
     for sample, ex_type, err_msg in op.error_inputs(device):
         result = run_snippet(snippet_errors, op, device, None, executor.make_callable(op.op), sample, ex_type, err_msg)
         if result is not None:
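
A quick way to sanity-check the new dtypes once the patch is applied. This is an
illustrative sketch, not part of the commit; it assumes a PyTorch build that
exposes the torch.float8_* dtypes (roughly PyTorch 2.1 or newer) and only
exercises the mappings introduced above:

    import torch
    from thunder.core import dtypes

    # Weak and strong float8 variants map to the same torch dtype.
    assert dtypes.to_torch_dtype(dtypes.float8_e4m3fn) == torch.float8_e4m3fn
    assert dtypes.to_torch_dtype(dtypes.float8_e5m2_) == torch.float8_e5m2

    # The new `variant` field keeps otherwise identical 1-byte floats distinct.
    assert dtypes.float8_e5m2 != dtypes.float8_e5m2fnuz

    # float8 dtypes live in float_8bit_dtypes and are excluded from
    # float_math_dtypes, which the updated tests use to skip math-heavy cases.
    assert dtypes.float8_e5m2 in dtypes.float_8bit_dtypes
    assert dtypes.float8_e5m2 not in dtypes.float_math_dtypes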