From b437cb013b4f35c0e5a0e74f03c2ce28a8187523 Mon Sep 17 00:00:00 2001 From: Kaeun Kim <51257208+k223kim@users.noreply.github.com> Date: Mon, 1 Jul 2024 23:55:40 +0900 Subject: [PATCH] Detailed `__repr__` (#638) --- thunder/clang/__init__.py | 6 ++++- thunder/core/codeutils.py | 16 +++++++------ thunder/core/devices.py | 15 +++++++++--- thunder/core/dtypes.py | 4 +--- thunder/core/prims.py | 4 ++-- thunder/core/proxies.py | 6 ++--- thunder/distributed/tensor_parallel/common.py | 4 ++-- thunder/distributed/transforms/fsdp_v2.py | 2 +- thunder/executors/torchex.py | 2 +- thunder/extend/__init__.py | 2 +- thunder/tests/distributed/test_ddp.py | 2 +- thunder/tests/framework.py | 5 ++-- thunder/tests/make_tensor.py | 8 +++++-- thunder/tests/opinfos.py | 24 ++++++++----------- thunder/tests/test_core.py | 4 +++- thunder/torch/__init__.py | 4 ++-- thunder/transforms/quantization.py | 2 +- 17 files changed, 63 insertions(+), 47 deletions(-) diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index e9bbf7ac2a..c12365ba80 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -69,7 +69,11 @@ def __call__(self, fn: Callable) -> Callable: @clangop() def check_tensor_shape_and_metadata(t: TensorProxy, /) -> None: return prims.check_tensor_shape_and_metadata( - t, tuple(t.shape), str(t.device), dtypes.to_torch_dtype(t.dtype), t.requires_grad + t, + tuple(t.shape), + t.device.device_str(), + dtypes.to_torch_dtype(t.dtype), + t.requires_grad, ) diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 91b135d494..2c74cc8053 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -229,7 +229,7 @@ def prettyprint( if isinstance(x, dtypes.dtype): return m(f"dtypes.{str(x)}") if isinstance(x, devices.Device): - return m(f'devices.Device("{str(x)}")') + return m(f'devices.Device("{x.device_str()}")') if type(x) is type: return m(f"{baseutils.print_type(x, with_quotes=False)}") if dataclasses.is_dataclass(x): @@ -243,12 +243,14 @@ def prettyprint( # NOTE: The `class` packagename1_MyContainer will present in `import_ctx` and passed to the compiled function. # This is taken care of by function `to_printable`. name = _generate_dataclass_class_name(x) - instance_repr = str(x) - parens_idx = instance_repr.find("(") - call_repr = instance_repr[ - parens_idx: - ] # only keep the construction part of the repr (as we will use our generated name) - return m(f"{name + call_repr}") + call_repr = [] + for k, v in x.__dict__.items(): + try: + call_repr.append(f"{k}={v.name}") + except: + call_repr.append(f"{k}={v}") + call_repr_str = ",".join(call_repr) + return m(f"{name}({call_repr_str})") # Handles objects that this doesn't know how to serialize as a string return m(f"(object of type {print_type(type(x), with_quotes=False)})") diff --git a/thunder/core/devices.py b/thunder/core/devices.py index d08dd2826c..70c79a21fb 100644 --- a/thunder/core/devices.py +++ b/thunder/core/devices.py @@ -118,14 +118,21 @@ def __hash__(self) -> int: # converting Thunder devices to PyTorch devices def __repr__(self) -> str: if self.devicetype == DeviceType.CUDA: - return f"{devicetype_string(self.devicetype)}:{self.index}" + return f"thunder.devices.Device(type='{devicetype_string(self.devicetype)}:{self.index}')" # note: self.devicetype == DeviceType.CPU, .META - return devicetype_string(self.devicetype) + return f"thunder.devices.Device(type='{devicetype_string(self.devicetype)}')" # NOTE Because devices are singleton object, this has the luxury of using "is" def __eq__(self, other: Device) -> bool: return self is other + # NOTE this is needed when passing devices.Device to torch operators such as torch.testing.make_tensor + def device_str(self) -> str: + if self.devicetype == DeviceType.CUDA: + return f"{devicetype_string(self.devicetype)}:{self.index}" + # note: self.devicetype == DeviceType.CPU, .META + return devicetype_string(self.devicetype) + cpu = Device(DeviceType.CPU, None) @@ -185,4 +192,6 @@ def to_torch_device(x: None | str | torch.device | Device, /) -> None | torch.de return x baseutils.check_type(x, (Device, str)) - return torch.device(str(x)) + if isinstance(x, Device): + return torch.device(x.device_str()) + return torch.device(x) diff --git a/thunder/core/dtypes.py b/thunder/core/dtypes.py index b180aa3127..c021975bb6 100644 --- a/thunder/core/dtypes.py +++ b/thunder/core/dtypes.py @@ -85,9 +85,7 @@ def shortname(self): # TODO Fix name printing def __repr__(self): - return ( - f"{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}" - ) + return f"thunder.dtypes.{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}" def __str__(self): return self.__repr__() diff --git a/thunder/core/prims.py b/thunder/core/prims.py index d84536ff7f..018ccf152f 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -330,14 +330,14 @@ def assert_tensor_metadata_impl( if ( type(t) in (torch.Tensor, torch.nn.Parameter) and tuple(t.shape) == shape - and str(t.device) == str(device) + and str(t.device) == device.device_str() and t.dtype == dtype and t.requires_grad == requires_grad ): return raise AssertionError( - f"Object had unexpected metadata. Expected type Tensor/nn.Parameter (without subclass), shape {shape}, device {str(device)}, dtype {dtype}, and {requires_grad=}, but found type {type(t)}, shape {tuple(t.shape)}, device {str(t.device)}, and requires_grad {t.requires_grad}" + f"Object had unexpected metadata. Expected type Tensor/nn.Parameter (without subclass), shape {shape}, device {str(device.device_str())}, dtype {dtype}, and {requires_grad=}, but found type {type(t)}, shape {tuple(t.shape)}, device {str(t.device)}, and requires_grad {t.requires_grad}" ) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index c0ff0e90f1..89479bc43f 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -134,7 +134,7 @@ def replace_name(self, name: str | None = None): return self.__class__(name=name) def __repr__(self) -> str: - return f"{self.name}" + return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape}>' def type_string(self) -> str: return "Any" @@ -1610,7 +1610,7 @@ def real(self): def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy: - device = devices.device_from_string(str(t.device)) + device = devices.to_device(t.device) dtype = dtypes.to_dtype(t.dtype) # See Note [DistributedDataParallel and distparallel_type] distparallel_type = getattr(t, "distparallel_type", None) @@ -1631,7 +1631,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = def futuretensorproxy( t: torch.Tensor | TensorProxy | FutureTensorProxy, /, *, name: None | str, history: None | tuple = None ) -> FutureTensorProxy: - device = devices.device_from_string(str(t.device)) + device = devices.to_device(t.device) dtype = dtypes.to_dtype(t.dtype) # NOTE Without tuple(t.shape) then the shape would be a torch.Size object return FutureTensorProxy( diff --git a/thunder/distributed/tensor_parallel/common.py b/thunder/distributed/tensor_parallel/common.py index 9b7770f94b..2f6141c92a 100644 --- a/thunder/distributed/tensor_parallel/common.py +++ b/thunder/distributed/tensor_parallel/common.py @@ -238,7 +238,7 @@ def transform_traces( if c.sym is prims.check_tensor_shape_and_metadata: # TODO have a more principled way to update this? a0, _, _, *a2pp = c.args - c.args = (a0, tuple(new_shape), str(a0.device), *a2pp) + c.args = (a0, tuple(new_shape), a0.device.device_str(), *a2pp) for bsym in prologue_trace.bound_symbols: if bsym.sym is prims.check_tensor_shape_and_metadata and prologue_producers[bsym.args[0]].sym in ( @@ -249,7 +249,7 @@ def transform_traces( assert param_thunder_module is thunder_module_proxy if name not in self.chunked_param_name_to_layer_type: a0, shape, _, *a2pp = bsym.args - bsym.args = (a0, shape, str(a0.device), *a2pp) + bsym.args = (a0, shape, a0.device.device_str(), *a2pp) if len(modules_and_thunder_modules) != 1: raise NotImplementedError("cannot deal with modules other than the compiled module") diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index 71c51e7698..bc1236e0c1 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -61,7 +61,7 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, ** param_name_to_comp_trc_proxy[param_name] = comp_inp_p old_shape, new_shape, new_torch_device = self.sharded_params[param_name] thunder_device = devices.to_device(new_torch_device) - thunder_device_str = str(thunder_device) + thunder_device_str = thunder_device.device_str() pro_out_p._distparallel_type = DistParallelType.FULLY_SHARDED pro_out_p._shape = tuple(new_shape) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index aa4d741052..c49ff95358 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -132,7 +132,7 @@ def _to_transform( def _device_put_transform(a: TensorProxy, device: devices.Device) -> TensorProxy: - torch_device: str = str(device) + torch_device: str = device.device_str() return to(a, torch_device) diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index dac6ff5953..e525c2e8a4 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -74,7 +74,7 @@ def implmap(self) -> dict[Hashable, ImplInfo]: return self._implmap def __repr__(self) -> str: - return str(self.name) + return f"thunder.extend.OperatorExecutor('{str(self.name)}')" def __hash__(self) -> int: return hash(self.name) diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index e255a0b7ac..a2f086062e 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -1110,7 +1110,6 @@ def create_per_process_dataloader( sampler = tudata.SequentialSampler(dataset) collate_fn = None - if devicetype is not devices.DeviceType.CPU: assert devicetype is devices.DeviceType.CUDA, f"Unknown devicetype {devicetype}" device = torch.device("cuda", rank) @@ -1214,6 +1213,7 @@ def _test_native_ddp_helper(input_data): torch_dtype = ltorch.to_torch_dtype(dtype) pg = init_per_process_distributed(init_method, devicetype, world_size, rank) + tdist.barrier(pg) dataloader = create_per_process_dataloader( diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py index 6826e81960..adaae6d897 100644 --- a/thunder/tests/framework.py +++ b/thunder/tests/framework.py @@ -276,14 +276,15 @@ def _instantiate_executor_test_template( ) -> Callable: devicetype: devices.DeviceType device_str: str | list[str] + devicetype = device_or_devices if isinstance(device_or_devices, devices.Device): devicetype = device_or_devices.devicetype - device_str = str(device_or_devices) + device_str = device_or_devices.device_str() else: devicetype = device_or_devices[0].devicetype device_str = [] for device in device_or_devices: - device_str.append(str(device)) + device_str.append(device.device_str()) devicetype_str = devices.devicetype_string(devicetype) template_name = as_name if as_name is not None else template.__name__ diff --git a/thunder/tests/make_tensor.py b/thunder/tests/make_tensor.py index 76442bec62..763b7a842b 100644 --- a/thunder/tests/make_tensor.py +++ b/thunder/tests/make_tensor.py @@ -3,6 +3,7 @@ from typing import cast, List, Optional, Tuple, Union import torch +import thunder # adapted from https://github.com/pytorch/pytorch/blob/master/torch/testing/_creation.py # Changes: @@ -32,7 +33,7 @@ def _uniform_random(t: torch.Tensor, low: float, high: float): def make_tensor( *shape: int | torch.Size | list[int] | tuple[int, ...], dtype: torch.dtype, - device: str | torch.device, + device: str | torch.device | thunder.devices.Device, low: float | None = None, high: float | None = None, requires_grad: bool = False, @@ -62,7 +63,7 @@ def make_tensor( Args: shape (Tuple[int, ...]): Single integer or a sequence of integers defining the shape of the output tensor. dtype (:class:`torch.dtype`): The data type of the returned tensor. - device (Union[str, torch.device]): The device of the returned tensor. + device (Union[str, torch.device, thunder.devices.Device]): The device of the returned tensor. low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is clamped to the least representable finite value of the given dtype. When ``None`` (default), this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``. @@ -112,6 +113,9 @@ def clamp(a, l, h): return low, high + if isinstance(device, thunder.devices.Device): + device = device.device_str() + if len(shape) == 1 and isinstance(shape[0], collections.abc.Sequence): shape = shape[0] # type: ignore[assignment] shape = cast(tuple[int, ...], tuple(shape)) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index e2ffb18dfc..8841b0d2aa 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -304,7 +304,7 @@ def is_active( # Acquires devicetype devicetype_: devices.DeviceType if isinstance(device_or_devicetype, str): - devicetype_ = devices.device_from_string(device_or_devicetype).devicetype + devicetype_ = devices.to_device(device_or_devicetype).devicetype elif isinstance(device_or_devicetype, devices.Device): devicetype_ = device_or_devicetype.devicetype else: @@ -392,26 +392,22 @@ def sample_inputs( self, device: str | devices.Device, dtype: datatypes.dtype, *, requires_grad: bool = False, **kwargs ) -> Generator: torch_dtype = to_torch_dtype(dtype) - torch_device = str(device) - return self.sample_input_generator(self, torch_device, torch_dtype, requires_grad, **kwargs) + return self.sample_input_generator(self, device, torch_dtype, requires_grad, **kwargs) def reference_inputs( self, device: str | devices.Device, dtype: datatypes.dtype, *, requires_grad: bool = False, **kwargs ) -> Generator: torch_dtype = to_torch_dtype(dtype) - torch_device = str(device) - return self.reference_input_generator(self, torch_device, torch_dtype, requires_grad, **kwargs) + return self.reference_input_generator(self, device, torch_dtype, requires_grad, **kwargs) def error_inputs(self, device: devices.Device, **kwargs): - torch_device = str(device) - return self.error_input_generator(self, torch_device, **kwargs) + return self.error_input_generator(self, device, **kwargs) # NOTE Today all benchmarks are generated with PyTorch, so Thunder objects, # like dtypes, need to be translated into PyTorch objects def benchmarks(self, device: devices.Device, dtype: datatypes.dtype, *, requires_grad: bool = False, **kwargs): torch_dtype = to_torch_dtype(dtype) - torch_device = str(device) - return self.benchmark_generator(self, torch_device, dtype, requires_grad, **kwargs) + return self.benchmark_generator(self, device, dtype, requires_grad, **kwargs) def devicetypes(self): return set(self._devicetypes) @@ -5565,7 +5561,7 @@ def full_sample_generator(op, device, dtype, requires_grad, **kwargs): def full_error_generator(op, device, **kwargs): - err_msg = "Can't safely cast fill_value of numbertype to dtype float32" + err_msg = "Can't safely cast fill_value of numbertype to dtype thunder.dtypes.float32" yield (SampleInput((1, 2), 1j, device=device, dtype=torch.float), RuntimeError, err_msg) @@ -5744,7 +5740,7 @@ def bernoulli_sample_generator(op, device, dtype, requires_grad, **kwargs): def bernoulli_error_generator(op, device, **kwargs): - err_msg = "bernoulli only supports floating point dtypes, got int64" + err_msg = "bernoulli only supports floating point dtypes, got thunder.dtypes.int64" yield (SampleInput(torch.ones(3, 3, device=device, dtype=torch.long)), RuntimeError, err_msg) err_msg = "generator is not None which is currently unsupported" @@ -5903,13 +5899,13 @@ def tensor_constructor_error_generator(op, device, **kwargs): err_msg = "Expected sequences of numbers, but found type " yield (SampleInput([[1], [[6, 2]]]), ValueError, err_msg) - err_msg = "Can't safely cast sequence with numbertype to dtype int32" + err_msg = "Can't safely cast sequence with numbertype to dtype thunder.dtypes.int32" yield (SampleInput([[1, 2.0], [6, 2]], dtype=torch.int32), RuntimeError, err_msg) - err_msg = "Can't safely cast sequence with numbertype to dtype int32" + err_msg = "Can't safely cast sequence with numbertype to dtype thunder.dtypes.int32" yield (SampleInput([[1, 2j], [6, 2]], dtype=torch.int32), RuntimeError, err_msg) - err_msg = "Can't safely cast sequence with numbertype to dtype float64" + err_msg = "Can't safely cast sequence with numbertype to dtype thunder.dtypes.float64" yield (SampleInput([[1, 2j], [6, 2]], dtype=torch.float64), RuntimeError, err_msg) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 09c5e81baa..f66318c3e2 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -1243,7 +1243,9 @@ def foo(x, y, z): consumers = thunder.core.utils.consumers(trace) region_bsyms = trace.bound_symbols[:3] region = Region(producers, consumers, region_bsyms) - assert len(region.inputs) == 0 and sorted(str(v) for v in region.outputs) == ["t0"] + assert len(region.inputs) == 0 and sorted(str(v) for v in region.outputs) == [ + '' + ] # This test ensures that calls to torch functions are recorded in the trace diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 613443d942..7cc1f17ac3 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -311,7 +311,7 @@ def type(a: TensorLike, /, dtype: None | str | dtypeLike = None, non_blocking: b # 2. When a tensor is on a CPU device and the device type string is omitted, the tensor remains on the CPU device. dev = a.device else: - dev = device_from_string(devtype) + dev = to_device(devtype) else: # dtype is assumed to be torch.dtype (e.g. torch.int32) dev = a.device @@ -458,7 +458,7 @@ def cuda( device = to_device(device) utils.check( device.devicetype == devices.DeviceType.CUDA, - lambda: f"cuda(): Invalid device {device}, must be cuda device", + lambda: f"cuda(): Invalid device {device.device_str()}, must be cuda device", ) return to(a, device=device, memory_format=memory_format) diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index c62f90c16c..140ad78683 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -116,7 +116,7 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, ** # check has args: tensor, shape, device, dtype, requires_grad proxy, _, _, _, requires_grad = check.args thunder_device = thunder.devices.to_device(param.device) - thunder_device_str = str(thunder_device) + thunder_device_str = thunder_device.device_str() check.args = (proxy, (*param.shape,), thunder_device_str, param.dtype, False) output_idx = output_idxes.get(id(get_param.output))