diff --git a/tests/test_generator.py b/tests/test_generator.py index e2cf3c19..4f0c90e6 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -2,10 +2,14 @@ from pytest import fixture from mofa.generator import train_generator, run_generator import numpy as np +import torch +from @fixture() def load_model(): - ... + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ddpm = DDPM.load_from_checkpoint(model, map_location=device).eval().to(device) + return ddpm def test_training(): ...