Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Laughing-q committed Sep 12, 2024
1 parent 1b3854f commit d1c7cba
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions model_compression_toolkit/core/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ def set_model(model: torch.nn.Module, train_mode: bool = False):
model.eval()

device = get_working_device()
if device.type == CPU:
model.cpu()
else:
model.cuda()
model.to(device)


def to_torch_tensor(tensor,
Expand Down Expand Up @@ -112,4 +109,4 @@ def clip_inf_values_float16(tensor: Tensor) -> Tensor:
# Replace inf values with max float16 value
tensor[inf_mask] = MAX_FLOAT16 * torch.sign(tensor[inf_mask])

return tensor
return tensor

0 comments on commit d1c7cba

Please sign in to comment.