diff --git a/tests/test_generator.py b/tests/test_generator.py index f5a2aed7..90f998fe 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -36,7 +36,13 @@ def test_training(): ... # https://docs.pytest.org/en/7.1.x/how-to/parametrize.html -# @mark.parametrize('n_atoms', [3, 4]) -# def test_sampling_num_atoms(n_atoms): -# run_generator(n_atoms=n_atoms) +@mark.parametrize('n_atoms', [3, 4]) +def test_sampling_num_atoms(n_atoms): + run_generator(n_atoms=n_atoms) + +@mark.parametrize('n_atoms', [3]) +@mark.parametrize('node', ['CuCu', 'ZnZn', 'ZnOZnZnZn']) +@mark.parametrize('n_samples', [1, 3]) +def test_sampling_num_atoms(n_atoms, node, n_samples): + run_generator(n_atoms=n_atoms, node=node, n_samples=n_samples)