Skip to content

Commit

Permalink
Detailed __repr__ (Lightning-AI#638)
Browse files Browse the repository at this point in the history
  • Loading branch information
k223kim authored Jul 1, 2024
1 parent 0136c67 commit b437cb0
Show file tree
Hide file tree
Showing 17 changed files with 63 additions and 47 deletions.
6 changes: 5 additions & 1 deletion thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def __call__(self, fn: Callable) -> Callable:
@clangop()
def check_tensor_shape_and_metadata(t: TensorProxy, /) -> None:
return prims.check_tensor_shape_and_metadata(
t, tuple(t.shape), str(t.device), dtypes.to_torch_dtype(t.dtype), t.requires_grad
t,
tuple(t.shape),
t.device.device_str(),
dtypes.to_torch_dtype(t.dtype),
t.requires_grad,
)


Expand Down
16 changes: 9 additions & 7 deletions thunder/core/codeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def prettyprint(
if isinstance(x, dtypes.dtype):
return m(f"dtypes.{str(x)}")
if isinstance(x, devices.Device):
return m(f'devices.Device("{str(x)}")')
return m(f'devices.Device("{x.device_str()}")')
if type(x) is type:
return m(f"{baseutils.print_type(x, with_quotes=False)}")
if dataclasses.is_dataclass(x):
Expand All @@ -243,12 +243,14 @@ def prettyprint(
# NOTE: The `class` packagename1_MyContainer will be present in `import_ctx` and be passed to the compiled function.
# This is taken care of by function `to_printable`.
name = _generate_dataclass_class_name(x)
instance_repr = str(x)
parens_idx = instance_repr.find("(")
call_repr = instance_repr[
parens_idx:
] # only keep the construction part of the repr (as we will use our generated name)
return m(f"{name + call_repr}")
call_repr = []
for k, v in x.__dict__.items():
try:
call_repr.append(f"{k}={v.name}")
except:
call_repr.append(f"{k}={v}")
call_repr_str = ",".join(call_repr)
return m(f"{name}({call_repr_str})")

# Handles objects that this doesn't know how to serialize as a string
return m(f"(object of type {print_type(type(x), with_quotes=False)})")
Expand Down
15 changes: 12 additions & 3 deletions thunder/core/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,21 @@ def __hash__(self) -> int:
# converting Thunder devices to PyTorch devices
def __repr__(self) -> str:
if self.devicetype == DeviceType.CUDA:
return f"{devicetype_string(self.devicetype)}:{self.index}"
return f"thunder.devices.Device(type='{devicetype_string(self.devicetype)}:{self.index}')"
# note: self.devicetype == DeviceType.CPU, .META
return devicetype_string(self.devicetype)
return f"thunder.devices.Device(type='{devicetype_string(self.devicetype)}')"

# NOTE Because devices are singleton objects, this has the luxury of using "is"
def __eq__(self, other: Device) -> bool:
# Equality is object identity: True only when `other` is this very object.
return self is other

# NOTE this is needed when passing devices.Device to torch operators such as torch.testing.make_tensor
def device_str(self) -> str:
    """Return the plain string form of this device.

    CUDA devices render as ``"<devicetype string>:<index>"``; any other
    device type renders as the bare device-type string.
    """
    type_name = devicetype_string(self.devicetype)
    if self.devicetype == DeviceType.CUDA:
        return f"{type_name}:{self.index}"
    # note: self.devicetype == DeviceType.CPU, .META
    return type_name


cpu = Device(DeviceType.CPU, None)

Expand Down Expand Up @@ -185,4 +192,6 @@ def to_torch_device(x: None | str | torch.device | Device, /) -> None | torch.de
return x

baseutils.check_type(x, (Device, str))
return torch.device(str(x))
if isinstance(x, Device):
return torch.device(x.device_str())
return torch.device(x)
4 changes: 1 addition & 3 deletions thunder/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ def shortname(self):

# TODO Fix name printing
def __repr__(self):
return (
f"{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}"
)
return f"thunder.dtypes.{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}"

def __str__(self):
# str() output is delegated to __repr__, so both render the same text.
return self.__repr__()
Expand Down
4 changes: 2 additions & 2 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,14 @@ def assert_tensor_metadata_impl(
if (
type(t) in (torch.Tensor, torch.nn.Parameter)
and tuple(t.shape) == shape
and str(t.device) == str(device)
and str(t.device) == device.device_str()
and t.dtype == dtype
and t.requires_grad == requires_grad
):
return

raise AssertionError(
f"Object had unexpected metadata. Expected type Tensor/nn.Parameter (without subclass), shape {shape}, device {str(device)}, dtype {dtype}, and {requires_grad=}, but found type {type(t)}, shape {tuple(t.shape)}, device {str(t.device)}, and requires_grad {t.requires_grad}"
f"Object had unexpected metadata. Expected type Tensor/nn.Parameter (without subclass), shape {shape}, device {str(device.device_str())}, dtype {dtype}, and {requires_grad=}, but found type {type(t)}, shape {tuple(t.shape)}, device {str(t.device)}, and requires_grad {t.requires_grad}"
)


Expand Down
6 changes: 3 additions & 3 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def replace_name(self, name: str | None = None):
return self.__class__(name=name)

def __repr__(self) -> str:
return f"{self.name}"
return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape}>'

def type_string(self) -> str:
    """Return the constant type label ``"Any"``."""
    label = "Any"
    return label
Expand Down Expand Up @@ -1610,7 +1610,7 @@ def real(self):


def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy:
device = devices.device_from_string(str(t.device))
device = devices.to_device(t.device)
dtype = dtypes.to_dtype(t.dtype)
# See Note [DistributedDataParallel and distparallel_type]
distparallel_type = getattr(t, "distparallel_type", None)
Expand All @@ -1631,7 +1631,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
def futuretensorproxy(
t: torch.Tensor | TensorProxy | FutureTensorProxy, /, *, name: None | str, history: None | tuple = None
) -> FutureTensorProxy:
device = devices.device_from_string(str(t.device))
device = devices.to_device(t.device)
dtype = dtypes.to_dtype(t.dtype)
# NOTE Without tuple(t.shape) then the shape would be a torch.Size object
return FutureTensorProxy(
Expand Down
4 changes: 2 additions & 2 deletions thunder/distributed/tensor_parallel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def transform_traces(
if c.sym is prims.check_tensor_shape_and_metadata:
# TODO have a more principled way to update this?
a0, _, _, *a2pp = c.args
c.args = (a0, tuple(new_shape), str(a0.device), *a2pp)
c.args = (a0, tuple(new_shape), a0.device.device_str(), *a2pp)

for bsym in prologue_trace.bound_symbols:
if bsym.sym is prims.check_tensor_shape_and_metadata and prologue_producers[bsym.args[0]].sym in (
Expand All @@ -249,7 +249,7 @@ def transform_traces(
assert param_thunder_module is thunder_module_proxy
if name not in self.chunked_param_name_to_layer_type:
a0, shape, _, *a2pp = bsym.args
bsym.args = (a0, shape, str(a0.device), *a2pp)
bsym.args = (a0, shape, a0.device.device_str(), *a2pp)

if len(modules_and_thunder_modules) != 1:
raise NotImplementedError("cannot deal with modules other than the compiled module")
Expand Down
2 changes: 1 addition & 1 deletion thunder/distributed/transforms/fsdp_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **
param_name_to_comp_trc_proxy[param_name] = comp_inp_p
old_shape, new_shape, new_torch_device = self.sharded_params[param_name]
thunder_device = devices.to_device(new_torch_device)
thunder_device_str = str(thunder_device)
thunder_device_str = thunder_device.device_str()

pro_out_p._distparallel_type = DistParallelType.FULLY_SHARDED
pro_out_p._shape = tuple(new_shape)
Expand Down
2 changes: 1 addition & 1 deletion thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _to_transform(


def _device_put_transform(a: TensorProxy, device: devices.Device) -> TensorProxy:
torch_device: str = str(device)
torch_device: str = device.device_str()
return to(a, torch_device)


Expand Down
2 changes: 1 addition & 1 deletion thunder/extend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def implmap(self) -> dict[Hashable, ImplInfo]:
return self._implmap

def __repr__(self) -> str:
return str(self.name)
return f"thunder.extend.OperatorExecutor('{str(self.name)}')"

def __hash__(self) -> int:
# The hash depends only on self.name, so objects with equal names hash alike.
return hash(self.name)
Expand Down
2 changes: 1 addition & 1 deletion thunder/tests/distributed/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,6 @@ def create_per_process_dataloader(
sampler = tudata.SequentialSampler(dataset)

collate_fn = None

if devicetype is not devices.DeviceType.CPU:
assert devicetype is devices.DeviceType.CUDA, f"Unknown devicetype {devicetype}"
device = torch.device("cuda", rank)
Expand Down Expand Up @@ -1214,6 +1213,7 @@ def _test_native_ddp_helper(input_data):
torch_dtype = ltorch.to_torch_dtype(dtype)

pg = init_per_process_distributed(init_method, devicetype, world_size, rank)

tdist.barrier(pg)

dataloader = create_per_process_dataloader(
Expand Down
5 changes: 3 additions & 2 deletions thunder/tests/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,15 @@ def _instantiate_executor_test_template(
) -> Callable:
devicetype: devices.DeviceType
device_str: str | list[str]
devicetype = device_or_devices
if isinstance(device_or_devices, devices.Device):
devicetype = device_or_devices.devicetype
device_str = str(device_or_devices)
device_str = device_or_devices.device_str()
else:
devicetype = device_or_devices[0].devicetype
device_str = []
for device in device_or_devices:
device_str.append(str(device))
device_str.append(device.device_str())

devicetype_str = devices.devicetype_string(devicetype)
template_name = as_name if as_name is not None else template.__name__
Expand Down
8 changes: 6 additions & 2 deletions thunder/tests/make_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import cast, List, Optional, Tuple, Union

import torch
import thunder

# adapted from https://github.com/pytorch/pytorch/blob/master/torch/testing/_creation.py
# Changes:
Expand Down Expand Up @@ -32,7 +33,7 @@ def _uniform_random(t: torch.Tensor, low: float, high: float):
def make_tensor(
*shape: int | torch.Size | list[int] | tuple[int, ...],
dtype: torch.dtype,
device: str | torch.device,
device: str | torch.device | thunder.devices.Device,
low: float | None = None,
high: float | None = None,
requires_grad: bool = False,
Expand Down Expand Up @@ -62,7 +63,7 @@ def make_tensor(
Args:
shape (Tuple[int, ...]): Single integer or a sequence of integers defining the shape of the output tensor.
dtype (:class:`torch.dtype`): The data type of the returned tensor.
device (Union[str, torch.device]): The device of the returned tensor.
device (Union[str, torch.device, thunder.devices.Device]): The device of the returned tensor.
low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is
clamped to the least representable finite value of the given dtype. When ``None`` (default),
this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
Expand Down Expand Up @@ -112,6 +113,9 @@ def clamp(a, l, h):

return low, high

if isinstance(device, thunder.devices.Device):
device = device.device_str()

if len(shape) == 1 and isinstance(shape[0], collections.abc.Sequence):
shape = shape[0] # type: ignore[assignment]
shape = cast(tuple[int, ...], tuple(shape))
Expand Down
24 changes: 10 additions & 14 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def is_active(
# Acquires devicetype
devicetype_: devices.DeviceType
if isinstance(device_or_devicetype, str):
devicetype_ = devices.device_from_string(device_or_devicetype).devicetype
devicetype_ = devices.to_device(device_or_devicetype).devicetype
elif isinstance(device_or_devicetype, devices.Device):
devicetype_ = device_or_devicetype.devicetype
else:
Expand Down Expand Up @@ -392,26 +392,22 @@ def sample_inputs(
self, device: str | devices.Device, dtype: datatypes.dtype, *, requires_grad: bool = False, **kwargs
) -> Generator:
torch_dtype = to_torch_dtype(dtype)
torch_device = str(device)
return self.sample_input_generator(self, torch_device, torch_dtype, requires_grad, **kwargs)
return self.sample_input_generator(self, device, torch_dtype, requires_grad, **kwargs)

def reference_inputs(
self, device: str | devices.Device, dtype: datatypes.dtype, *, requires_grad: bool = False, **kwargs
) -> Generator:
torch_dtype = to_torch_dtype(dtype)
torch_device = str(device)
return self.reference_input_generator(self, torch_device, torch_dtype, requires_grad, **kwargs)
return self.reference_input_generator(self, device, torch_dtype, requires_grad, **kwargs)

def error_inputs(self, device: devices.Device, **kwargs):
torch_device = str(device)
return self.error_input_generator(self, torch_device, **kwargs)
return self.error_input_generator(self, device, **kwargs)

# NOTE Today all benchmarks are generated with PyTorch, so Thunder objects,
# like dtypes, need to be translated into PyTorch objects
def benchmarks(self, device: devices.Device, dtype: datatypes.dtype, *, requires_grad: bool = False, **kwargs):
torch_dtype = to_torch_dtype(dtype)
torch_device = str(device)
return self.benchmark_generator(self, torch_device, dtype, requires_grad, **kwargs)
return self.benchmark_generator(self, device, dtype, requires_grad, **kwargs)

def devicetypes(self):
    """Return a newly built set containing the entries of ``self._devicetypes``."""
    # A fresh set is created on every call, so callers never receive the
    # underlying ``_devicetypes`` container itself.
    return {*self._devicetypes}
Expand Down Expand Up @@ -5565,7 +5561,7 @@ def full_sample_generator(op, device, dtype, requires_grad, **kwargs):


def full_error_generator(op, device, **kwargs):
err_msg = "Can't safely cast fill_value of numbertype <class 'complex'> to dtype float32"
err_msg = "Can't safely cast fill_value of numbertype <class 'complex'> to dtype thunder.dtypes.float32"
yield (SampleInput((1, 2), 1j, device=device, dtype=torch.float), RuntimeError, err_msg)


Expand Down Expand Up @@ -5744,7 +5740,7 @@ def bernoulli_sample_generator(op, device, dtype, requires_grad, **kwargs):


def bernoulli_error_generator(op, device, **kwargs):
err_msg = "bernoulli only supports floating point dtypes, got int64"
err_msg = "bernoulli only supports floating point dtypes, got thunder.dtypes.int64"
yield (SampleInput(torch.ones(3, 3, device=device, dtype=torch.long)), RuntimeError, err_msg)

err_msg = "generator is not None which is currently unsupported"
Expand Down Expand Up @@ -5903,13 +5899,13 @@ def tensor_constructor_error_generator(op, device, **kwargs):
err_msg = "Expected sequences of numbers, but found type <class 'list'>"
yield (SampleInput([[1], [[6, 2]]]), ValueError, err_msg)

err_msg = "Can't safely cast sequence with numbertype <class 'float'> to dtype int32"
err_msg = "Can't safely cast sequence with numbertype <class 'float'> to dtype thunder.dtypes.int32"
yield (SampleInput([[1, 2.0], [6, 2]], dtype=torch.int32), RuntimeError, err_msg)

err_msg = "Can't safely cast sequence with numbertype <class 'complex'> to dtype int32"
err_msg = "Can't safely cast sequence with numbertype <class 'complex'> to dtype thunder.dtypes.int32"
yield (SampleInput([[1, 2j], [6, 2]], dtype=torch.int32), RuntimeError, err_msg)

err_msg = "Can't safely cast sequence with numbertype <class 'complex'> to dtype float64"
err_msg = "Can't safely cast sequence with numbertype <class 'complex'> to dtype thunder.dtypes.float64"
yield (SampleInput([[1, 2j], [6, 2]], dtype=torch.float64), RuntimeError, err_msg)


Expand Down
4 changes: 3 additions & 1 deletion thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,9 @@ def foo(x, y, z):
consumers = thunder.core.utils.consumers(trace)
region_bsyms = trace.bound_symbols[:3]
region = Region(producers, consumers, region_bsyms)
assert len(region.inputs) == 0 and sorted(str(v) for v in region.outputs) == ["t0"]
assert len(region.inputs) == 0 and sorted(str(v) for v in region.outputs) == [
'<TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(1,)>'
]


# This test ensures that calls to torch functions are recorded in the trace
Expand Down
4 changes: 2 additions & 2 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def type(a: TensorLike, /, dtype: None | str | dtypeLike = None, non_blocking: b
# 2. When a tensor is on a CPU device and the device type string is omitted, the tensor remains on the CPU device.
dev = a.device
else:
dev = device_from_string(devtype)
dev = to_device(devtype)
else:
# dtype is assumed to be torch.dtype (e.g. torch.int32)
dev = a.device
Expand Down Expand Up @@ -458,7 +458,7 @@ def cuda(
device = to_device(device)
utils.check(
device.devicetype == devices.DeviceType.CUDA,
lambda: f"cuda(): Invalid device {device}, must be cuda device",
lambda: f"cuda(): Invalid device {device.device_str()}, must be cuda device",
)

return to(a, device=device, memory_format=memory_format)
Expand Down
2 changes: 1 addition & 1 deletion thunder/transforms/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **
# check has args: tensor, shape, device, dtype, requires_grad
proxy, _, _, _, requires_grad = check.args
thunder_device = thunder.devices.to_device(param.device)
thunder_device_str = str(thunder_device)
thunder_device_str = thunder_device.device_str()
check.args = (proxy, (*param.shape,), thunder_device_str, param.dtype, False)

output_idx = output_idxes.get(id(get_param.output))
Expand Down

0 comments on commit b437cb0

Please sign in to comment.