From d1c7cba62e8a44b2fd24b47c6437d1fdc75bfd33 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Thu, 12 Sep 2024 16:59:47 +0800 Subject: [PATCH] Update utils.py --- model_compression_toolkit/core/pytorch/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/model_compression_toolkit/core/pytorch/utils.py b/model_compression_toolkit/core/pytorch/utils.py index bb8a15a8e..29f4b74b8 100644 --- a/model_compression_toolkit/core/pytorch/utils.py +++ b/model_compression_toolkit/core/pytorch/utils.py @@ -38,10 +38,7 @@ def set_model(model: torch.nn.Module, train_mode: bool = False): model.eval() device = get_working_device() - if device.type == CPU: - model.cpu() - else: - model.cuda() + model.to(device) def to_torch_tensor(tensor, @@ -112,4 +109,4 @@ def clip_inf_values_float16(tensor: Tensor) -> Tensor: # Replace inf values with max float16 value tensor[inf_mask] = MAX_FLOAT16 * torch.sign(tensor[inf_mask]) - return tensor \ No newline at end of file + return tensor