From 76557d8593914967a9ff6732e3ac00122efdf178 Mon Sep 17 00:00:00 2001 From: Francesco Mattioli Date: Wed, 28 Aug 2024 17:19:46 +0200 Subject: [PATCH 1/4] fixed device selection --- model_compression_toolkit/core/pytorch/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/model_compression_toolkit/core/pytorch/utils.py b/model_compression_toolkit/core/pytorch/utils.py index fc98e3b58..bb8a15a8e 100644 --- a/model_compression_toolkit/core/pytorch/utils.py +++ b/model_compression_toolkit/core/pytorch/utils.py @@ -20,7 +20,7 @@ from model_compression_toolkit.core.pytorch.constants import MAX_FLOAT16, MIN_FLOAT16 from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device from model_compression_toolkit.logger import Logger - +from model_compression_toolkit.core.pytorch.constants import CPU def set_model(model: torch.nn.Module, train_mode: bool = False): """ @@ -38,7 +38,10 @@ def set_model(model: torch.nn.Module, train_mode: bool = False): model.eval() device = get_working_device() - model.to(device) + if device.type == CPU: + model.cpu() + else: + model.cuda() def to_torch_tensor(tensor, From e073ad5feb5d8940b242d554da8526d24740c329 Mon Sep 17 00:00:00 2001 From: Francesco Mattioli Date: Wed, 28 Aug 2024 17:23:11 +0200 Subject: [PATCH 2/4] added faster import --- model_compression_toolkit/core/pytorch/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model_compression_toolkit/core/pytorch/__init__.py b/model_compression_toolkit/core/pytorch/__init__.py index 424152e58..a2d9c0f55 100644 --- a/model_compression_toolkit/core/pytorch/__init__.py +++ b/model_compression_toolkit/core/pytorch/__init__.py @@ -11,4 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== \ No newline at end of file +# ============================================================================== +from model_compression_toolkit.core.pytorch.pytorch_device_config import DeviceManager \ No newline at end of file From 234d8dda15bfdef44ff774e6a31d5b17fe50fd45 Mon Sep 17 00:00:00 2001 From: Francesco Mattioli Date: Wed, 28 Aug 2024 17:29:15 +0200 Subject: [PATCH 3/4] fixed import --- model_compression_toolkit/core/pytorch/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_compression_toolkit/core/pytorch/__init__.py b/model_compression_toolkit/core/pytorch/__init__.py index a2d9c0f55..a4d991ff0 100644 --- a/model_compression_toolkit/core/pytorch/__init__.py +++ b/model_compression_toolkit/core/pytorch/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from model_compression_toolkit.core.pytorch.pytorch_device_config import DeviceManager \ No newline at end of file +from model_compression_toolkit.core.pytorch.pytorch_device_config import set_working_device \ No newline at end of file From 9bfc58f5a13bba8dddd5eeeeefe4128ad39605df Mon Sep 17 00:00:00 2001 From: Francesco Mattioli Date: Mon, 2 Sep 2024 10:22:18 +0200 Subject: [PATCH 4/4] Update model_compression_toolkit/core/pytorch/__init__.py --- model_compression_toolkit/core/pytorch/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/model_compression_toolkit/core/pytorch/__init__.py b/model_compression_toolkit/core/pytorch/__init__.py index a4d991ff0..424152e58 100644 --- a/model_compression_toolkit/core/pytorch/__init__.py +++ b/model_compression_toolkit/core/pytorch/__init__.py @@ -11,5 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== -from model_compression_toolkit.core.pytorch.pytorch_device_config import set_working_device \ No newline at end of file +# ============================================================================== \ No newline at end of file