Add support for torch FP8 dtypes (Lightning-AI#445)
riccardofelluga authored May 27, 2024
1 parent 7e23c5a commit a4dcd89
Showing 9 changed files with 95 additions and 24 deletions.
8 changes: 8 additions & 0 deletions thunder/__init__.py
@@ -92,6 +92,10 @@
"int32",
"int64",
"bfloat16",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e4m3fn",
"float8_e4m3fnuz",
"float16",
"float32",
"float64",
@@ -129,6 +133,10 @@ def __version__():
int32 = dtypes.int32
int64 = dtypes.int64
bfloat16 = dtypes.bfloat16
float8_e5m2 = dtypes.float8_e5m2
float8_e5m2fnuz = dtypes.float8_e5m2fnuz
float8_e4m3fn = dtypes.float8_e4m3fn
float8_e4m3fnuz = dtypes.float8_e4m3fnuz
float16 = dtypes.float16
float32 = dtypes.float32
float64 = dtypes.float64
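The hunks above re-export the four new fp8 dtypes from the top-level thunder namespace. A minimal sketch of what that gives the user, assuming a PyTorch build that ships the torch.float8_* dtypes (this snippet is illustrative and not part of the commit):

import thunder

# The new aliases sit next to the existing floating-point dtypes.
fp8_dtypes = (
    thunder.float8_e5m2,
    thunder.float8_e5m2fnuz,
    thunder.float8_e4m3fn,
    thunder.float8_e4m3fnuz,
)
for dt in fp8_dtypes:
    print(dt)  # e.g. "float8_e5m2"; the repr comes from thunder.core.dtypes below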
4 changes: 4 additions & 0 deletions thunder/core/baseutils.py
@@ -304,6 +304,10 @@ def indent(level):
torch.int32: "torch.int32",
torch.int64: "torch.int64",
torch.bfloat16: "torch.bfloat16",
torch.float8_e4m3fn: "torch.float8_e4m3fn",
torch.float8_e4m3fnuz: "torch.float8_e4m3fnuz",
torch.float8_e5m2: "torch.float8_e5m2",
torch.float8_e5m2fnuz: "torch.float8_e5m2fnuz",
torch.float16: "torch.float16",
torch.float32: "torch.float32",
torch.float64: "torch.float64",
68 changes: 56 additions & 12 deletions thunder/core/dtypes.py
@@ -59,9 +59,10 @@ def __new__(cls, *args, **kwargs):

return object.__new__(cls)

def __init__(self, *, python_type, name, shortname, bytes, is_weak):
def __init__(self, *, python_type, name, shortname, bytes, is_weak, variant=None):
self._python_type = python_type
self._name = name
self._variant = variant
self._shortname = shortname
self._bytes = bytes
self._is_weak = is_weak
@@ -80,23 +81,30 @@ def is_weak(self):
return self._is_weak

def shortname(self):
return f"{self._shortname}{8 * self._bytes}"
return f"{self._shortname}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}"

# TODO Fix name printing
def __repr__(self):
return f"{self._name}{8 * self._bytes}{'_' if self._is_weak else ''}"
return (
f"{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}"
)

def __str__(self):
return self.__repr__()

def __hash__(self) -> int:
return hash((self._name, self._bytes, self._is_weak))
return hash((self._name, self._bytes, self._is_weak, f"{self._variant if self._variant else ''}"))

def __eq__(self, other) -> bool:
if not isinstance(other, dtype):
return False

return self._name == other._name and self._bytes == other._bytes and self._is_weak == other._is_weak
return (
self._name == other._name
and self._bytes == other._bytes
and self._is_weak == other._is_weak
and self._variant == other._variant
)


class exact(dtype):
@@ -152,14 +160,24 @@ class inexact(dtype):


class floating(inexact):
"""Base class for the floating dtypes: bfloat16, float16, float32, float64."""
"""Base class for the floating dtypes: float8, bfloat16, float16, float32, float64."""

def __init__(self, name, shortname, *, bytes, is_weak):
super().__init__(python_type=float, name=name, shortname=shortname, bytes=bytes, is_weak=is_weak)
def __init__(self, name, shortname, *, bytes, is_weak, variant=None):
super().__init__(
python_type=float, name=name, shortname=shortname, bytes=bytes, is_weak=is_weak, variant=variant
)


bfloat16 = floating("bfloat", "bf", bytes=2, is_weak=False)
bfloat16_ = floating("bfloat", "bf", bytes=2, is_weak=True)
float8_e5m2 = floating("float", "f", bytes=1, is_weak=False, variant="e5m2")
float8_e5m2_ = floating("float", "f", bytes=1, is_weak=True, variant="e5m2")
float8_e5m2fnuz = floating("float", "f", bytes=1, is_weak=False, variant="e5m2fnuz")
float8_e5m2fnuz_ = floating("float", "f", bytes=1, is_weak=True, variant="e5m2fnuz")
float8_e4m3fn = floating("float", "f", bytes=1, is_weak=False, variant="e4m3fn")
float8_e4m3fn_ = floating("float", "f", bytes=1, is_weak=True, variant="e4m3fn")
float8_e4m3fnuz = floating("float", "f", bytes=1, is_weak=False, variant="e4m3fnuz")
float8_e4m3fnuz_ = floating("float", "f", bytes=1, is_weak=True, variant="e4m3fnuz")
float16 = floating("float", "f", bytes=2, is_weak=False)
float16_ = floating("float", "f", bytes=2, is_weak=True)
float32 = floating("float", "f", bytes=4, is_weak=False)
@@ -200,6 +218,14 @@ def __init__(self, name, shortname, *, bytes, is_weak):
int64_,
bfloat16,
bfloat16_,
float8_e5m2,
float8_e5m2_,
float8_e5m2fnuz,
float8_e5m2fnuz_,
float8_e4m3fn,
float8_e4m3fn_,
float8_e4m3fnuz,
float8_e4m3fnuz_,
float16,
float16_,
float32,
@@ -242,6 +268,10 @@ def __init__(self, name, shortname, *, bytes, is_weak):

float_dtypes = {d for d in all_dtypes if isinstance(d, floating)} | {float}

float_math_dtypes = {d for d in all_dtypes if isinstance(d, floating) and d.bytes >= 2}

float_8bit_dtypes = {d for d in all_dtypes if (isinstance(d, floating) and d.bytes == 1)}

complex_dtypes = {d for d in all_dtypes if isinstance(d, complexfloating)} | {complex}

inexact_dtypes = float_dtypes | complex_dtypes
@@ -306,11 +336,12 @@ def has_subdtype(x, cls):


# Translates a sequence of dtypes and dtype classes into a concrete set of corresponding (strong) dtypes
def resolve_dtypes(args):
def resolve_dtypes(args: Iterable) -> set[dtype]:
dtypes = set()
for arg in args:
if isinstance(arg, dtype):
dtypes.add(arg)
if not arg.is_weak:
dtypes.add(arg)
continue

if isinstance(arg, Iterable):
@@ -320,7 +351,8 @@ def resolve_dtypes(args):
lambda: f"Iterables passed to resolve_dtypes must only contain dtypes, but found an Iterable with {a}",
exception_type=NotImplementedError,
)
dtypes.add(a)
if not a.is_weak:
dtypes.add(a)

baseutils.check(
arg in (dtype, exact, signedinteger, unsignedinteger, bool_, inexact, floating, complexfloating),
@@ -373,6 +405,10 @@ def corresponding_complex_dtype(dtype):
int32: int32_,
int64: int64_,
bfloat16: bfloat16_,
float8_e5m2: float8_e5m2_,
float8_e5m2fnuz: float8_e5m2fnuz_,
float8_e4m3fn: float8_e4m3fn_,
float8_e4m3fnuz: float8_e4m3fnuz_,
float16: float16_,
float32: float32_,
float64: float64_,
@@ -520,6 +556,14 @@ def are_same_dtypes(a, b, *, weak_and_strong_are_equivalent=True):
int64: torch.int64,
bfloat16_: torch.bfloat16,
bfloat16: torch.bfloat16,
float8_e5m2: torch.float8_e5m2,
float8_e5m2_: torch.float8_e5m2,
float8_e5m2fnuz: torch.float8_e5m2fnuz,
float8_e5m2fnuz_: torch.float8_e5m2fnuz,
float8_e4m3fn: torch.float8_e4m3fn,
float8_e4m3fn_: torch.float8_e4m3fn,
float8_e4m3fnuz: torch.float8_e4m3fnuz,
float8_e4m3fnuz_: torch.float8_e4m3fnuz,
float16_: torch.float16,
float16: torch.float16,
float32_: torch.float32,
@@ -551,7 +595,7 @@ def to_torch_dtype(x: None | torch.dtype | dtype) -> None | torch.dtype:

# Converts NumPy dtypes to and from thunder dtypes

# NOTE NumPy does not support the bfloat16 or complexhalf (complex32) datatypes
# NOTE NumPy does not support the bfloat16, complexhalf (complex32) or float8 datatypes
_thunder_to_numpy_dtype_map = {
bool: np.bool_,
int: np.int_,
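The new variant field is what distinguishes the four 1-byte floating dtypes, which all share name "float" and shortname "f"; it participates in shortname(), __repr__, __hash__ and __eq__, and each variant comes in the usual strong/weak pair that maps onto the same torch dtype. An illustrative sketch (not part of the commit), assuming a PyTorch build with the torch.float8_* dtypes:

import torch
from thunder.core import dtypes

dt = dtypes.float8_e4m3fn               # strong 1-byte floating dtype, variant="e4m3fn"
print(dt.shortname())                   # "f8_e4m3fn": shortname + bits + "_" + variant
print(repr(dt))                         # "float8_e4m3fn"
print(repr(dtypes.float8_e4m3fn_))      # "float8_e4m3fn_": trailing underscore marks the weak dtype

# Weak and strong variants of the same fp8 format convert to the same torch dtype.
assert dtypes.to_torch_dtype(dtypes.float8_e4m3fn) is torch.float8_e4m3fn
assert dtypes.to_torch_dtype(dtypes.float8_e4m3fn_) is torch.float8_e4m3fn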
6 changes: 4 additions & 2 deletions thunder/tests/framework.py
@@ -172,7 +172,7 @@ class nvFuserTestExecutor(TestExecutor):
name = "nvfuser"
supported_devicetypes = (devices.DeviceType.CUDA,)
supported_dtypes = (
datatypes.floating,
*datatypes.float_math_dtypes,
datatypes.bool8,
datatypes.int32,
datatypes.int64,
@@ -347,7 +347,9 @@ def __init__(
self.supported_devicetypes = set(filter_ci_devicetypes(self.supported_devicetypes))

self.supported_dtypes = (
datatypes.resolve_dtypes(supported_dtypes) if supported_dtypes is not None else datatypes.all_dtypes
datatypes.resolve_dtypes(supported_dtypes)
if supported_dtypes is not None
else datatypes.all_dtypes - datatypes.float_8bit_dtypes
)

if supported_dtypes == NOTHING:
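framework.py now distinguishes the 2-byte-and-wider "math" floats from the 1-byte fp8 dtypes: the nvFuser executor advertises float_math_dtypes instead of the whole floating class, and the default supported set for test instantiation excludes fp8. A small sketch of what the two new sets contain (illustrative, not part of the commit):

from thunder.core import dtypes as datatypes

# The 1-byte set holds the four fp8 formats in both weak and strong form.
assert datatypes.float8_e5m2 in datatypes.float_8bit_dtypes
assert datatypes.float8_e4m3fn_ in datatypes.float_8bit_dtypes

# The "math" set keeps everything 2 bytes and wider, so fp8 is excluded.
assert datatypes.bfloat16 in datatypes.float_math_dtypes
assert datatypes.float8_e5m2 not in datatypes.float_math_dtypes

# The default supported dtypes for test instantiation now drop fp8 entirely.
default_supported = datatypes.all_dtypes - datatypes.float_8bit_dtypes
assert not (default_supported & datatypes.float_8bit_dtypes)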
19 changes: 16 additions & 3 deletions thunder/tests/make_tensor.py
@@ -117,9 +117,15 @@ def clamp(a, l, h):
shape = cast(tuple[int, ...], tuple(shape))

_integral_types = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
_floating_8bit_types = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
_floating_types = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
_complex_types = [torch.complex32, torch.complex64, torch.complex128]
if requires_grad and dtype not in _floating_types and dtype not in _complex_types:
if (
requires_grad
and dtype not in _floating_types
and dtype not in _floating_8bit_types
and dtype not in _complex_types
):
raise ValueError("make_tensor: requires_grad must be False for integral dtype")

if dtype is torch.bool:
@@ -145,10 +151,10 @@ def clamp(a, l, h):
if low == high:
return torch.full(shape, low, device=device, dtype=dtype)
result = torch.randint(low, high, shape, device=device, dtype=dtype) # type: ignore[call-overload]
elif dtype in _floating_types:
elif dtype in _floating_types + _floating_8bit_types:
ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
m_low, m_high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
result = torch.empty(shape, device=device, dtype=dtype)
result = torch.empty(shape, device=device, dtype=dtype if dtype not in _floating_8bit_types else torch.float32)
_uniform_random(result, m_low, m_high)
elif dtype in _complex_types:
float_dtype = complex_to_corresponding_float_type_map[dtype]
@@ -175,6 +181,8 @@ def clamp(a, l, h):
replace_with = torch.tensor(1, device=device, dtype=dtype)
elif dtype in _floating_types:
replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype)
elif dtype in _floating_8bit_types:
replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=torch.float32)
else: # dtype in _complex_types:
float_dtype = complex_to_corresponding_float_type_map[dtype]
float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype)
@@ -184,6 +192,11 @@ def clamp(a, l, h):
if dtype in _floating_types + _complex_types:
result.requires_grad = requires_grad

# NOTE This is a workaround: so many operations are unsupported for the fp8 dtypes
# that even creating the test tensors directly in them is hard.
if dtype in _floating_8bit_types:
result = result.to(dtype)

return result


Expand Down
6 changes: 3 additions & 3 deletions thunder/tests/test_autocast.py
@@ -14,7 +14,7 @@

# TODO This test currently ignores the "should_autocast" argument enumerated in it
@instantiate(
dtypes=dtypes.float_dtypes - {float},
dtypes=dtypes.float_math_dtypes,
)
def test_thunder_autocast_transform(executor, device, dtype):
from thunder.core.transforms import autocast
@@ -65,7 +65,7 @@ def h(a, b, c):

@instantiate(
executors=[TorchExecutor],
dtypes=dtypes.float_dtypes - {float},
dtypes=dtypes.float_math_dtypes,
)
def test_no_autocast(executor, device, dtype):
from thunder.core.symbol import Symbol
@@ -112,7 +112,7 @@ def func():


@instantiate(
dtypes=dtypes.float_dtypes - {float},
dtypes=dtypes.float_math_dtypes,
decorators=(pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported"),),
)
def test_compile_autocast(executor, device, dtype):
2 changes: 1 addition & 1 deletion thunder/tests/test_grad.py
@@ -1269,7 +1269,7 @@ def upcast_tensors(x: Any) -> Any:

@ops(
tuple(op for op in opinfos if op.supports_grad and op.torch_reference is not None),
supported_dtypes=(dtypes.floating,),
supported_dtypes=dtypes.float_math_dtypes,
)
def test_phantom_grad_vs_torch_consistency(op, device: str, dtype: dtypes.dtype, executor, comp):
if dtypes.is_complex_dtype(dtype):
4 changes: 2 additions & 2 deletions thunder/tests/test_inplace_copy.py
@@ -10,7 +10,7 @@
from thunder.tests.framework import instantiate, nvFuserExecutor


@instantiate()
@instantiate(dtypes=datatypes.all_dtypes - datatypes.float_8bit_dtypes)
def test_prim_inplace_copy_fwd(executor, device, dtype):
def torch_foo(x, y):
z = x + y
@@ -37,7 +37,7 @@ def foo(x, y):
assert_close(a, a1)


@instantiate(dtypes=(datatypes.floating,))
@instantiate(dtypes=datatypes.float_math_dtypes)
def test_prim_inplace_copy_bwd(executor, device, dtype):
def torch_foo(x, y):
z = x * y
2 changes: 1 addition & 1 deletion thunder/tests/test_ops.py
@@ -23,7 +23,7 @@ def snippet_errors(op, sample, ex_type, err_msg_match=None):


@ops(tuple(op for op in opinfos if op.error_input_generator is not None))
def test_errors(op, device, _, executor, comp):
def test_errors(op, device, dtype, executor, comp):
for sample, ex_type, err_msg in op.error_inputs(device):
result = run_snippet(snippet_errors, op, device, None, executor.make_callable(op.op), sample, ex_type, err_msg)
if result is not None:
