From 212bda19f07bd15f964a357dd651d9070db85f12 Mon Sep 17 00:00:00 2001 From: hyunp2 <42776897+hyunp2@users.noreply.github.com> Date: Wed, 8 Nov 2023 08:34:16 -0600 Subject: [PATCH] Update test_generator.py --- tests/test_generator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_generator.py b/tests/test_generator.py index 059a50f7..ee712fc5 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -8,6 +8,7 @@ @fixture() def load_model(): + model = device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ddpm = DDPM.load_from_checkpoint(model, map_location=device).eval().to(device) return ddpm