Skip to content

Commit

Permalink
Test validation separately
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobson committed Mar 15, 2024
1 parent 1ac5562 commit beff056
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 19 deletions.
4 changes: 2 additions & 2 deletions swmmanywhere/defs/schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ properties:
type: string
enum: [flooding, flow, depth, runoff]
real:
type: object
type: ['object', 'null']
properties:
inp: {type: string}
graph: {type: string}
Expand All @@ -30,4 +30,4 @@ properties:
metric_list: {type: array, items: {type: string}}
address_overrides: {type: ['object', 'null']}
parameter_overrides: {type: ['object', 'null']}
required: [base_dir, project, bbox, api_keys, run_settings, real, graphfcn_list, metric_list, parameter_overrides]
required: [base_dir, project, bbox, api_keys, graphfcn_list]
79 changes: 63 additions & 16 deletions swmmanywhere/swmmanywhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,64 @@ def swmmanywhere(config: dict):

return metrics

def check_top_level_paths(config: dict):
"""Check the top level paths in the config.
Args:
config (dict): The configuration.
Raises:
FileNotFoundError: If a top level path does not exist.
"""
for key in ['base_dir', 'api_keys']:
if not Path(config[key]).exists():
raise FileNotFoundError(f"{key} not found at {config[key]}")
config[key] = Path(config[key])
return config

def check_address_overrides(config: dict):
"""Check the address overrides in the config.
Args:
config (dict): The configuration.
Raises:
FileNotFoundError: If an address override path does not exist.
"""
overrides = config.get('address_overrides', None)

if not overrides:
return config

for key, path in overrides.items():
if not Path(path).exists():
raise FileNotFoundError(f"{key} not found at {path}")
config['address_overrides'][key] = Path(path)
return config

def check_real_network_paths(config: dict):
"""Check the paths to the real network in the config.
Args:
config (dict): The configuration.
Raises:
FileNotFoundError: If a real network path does not exist.
"""
real = config.get('real', None)

if not real:
return config

for key, path in real.items():
if not isinstance(path, str):
continue
if not Path(path).exists():
raise FileNotFoundError(f"{key} not found at {path}")
config['real'][key] = Path(path)

return config

def load_config(config_path: Path):
"""Load, validate, and convert Paths in a configuration file.
Expand All @@ -125,25 +183,14 @@ def load_config(config_path: Path):
jsonschema.validate(instance = config, schema = schema)

# Check top level paths
for key in ['base_dir', 'api_keys']:
if not Path(config[key]).exists():
raise FileNotFoundError(f"{key} not found at {config[key]}")
config[key] = Path(config[key])

# Check real network paths
for key, path in config['real'].items():
if not isinstance(path, str):
continue
if not Path(path).exists():
raise FileNotFoundError(f"{key} not found at {path}")
config['real'][key] = Path(path)
config = check_top_level_paths(config)

# Check address overrides
for key, path in config.get('address_overrides', {}).items():
if not Path(path).exists():
raise FileNotFoundError(f"{key} not found at {path}")
config['address_overrides'][key] = Path(path)
config = check_address_overrides(config)

# Check real network paths
config = check_real_network_paths(config)

return config


Expand Down
53 changes: 52 additions & 1 deletion tests/test_swmmanywhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import tempfile
from pathlib import Path

import jsonschema
import pytest
import yaml

from swmmanywhere import __version__, swmmanywhere
Expand Down Expand Up @@ -81,4 +83,53 @@ def test_swmmanywhere():

# Run swmmanywhere
swmmanywhere.swmmanywhere(config)


def test_load_config_file_validation():
"""Test the file validation of the config."""
with tempfile.TemporaryDirectory() as temp_dir:
test_data_dir = Path(__file__).parent / 'test_data'
defs_dir = Path(__file__).parent.parent / 'swmmanywhere' / 'defs'
base_dir = Path(temp_dir)

# Test file not found
with pytest.raises(FileNotFoundError) as exc_info:
swmmanywhere.load_config(base_dir / 'test_config.yml')
assert "test_config.yml" in str(exc_info.value)

with (test_data_dir / 'demo_config.yml').open('r') as f:
config = yaml.safe_load(f)

# Correct and avoid filevalidation errors
config['real'] = None

# Fill with unused paths to avoid filevalidation errors
config['base_dir'] = str(defs_dir / 'storm.dat')
config['api_keys'] = str(defs_dir / 'storm.dat')

with open(base_dir / 'test_config.yml', 'w') as f:
yaml.dump(config, f)

config = swmmanywhere.load_config(base_dir / 'test_config.yml')
assert isinstance(config, dict)

def test_load_config_schema_validation():
"""Test the schema validation of the config."""
with tempfile.TemporaryDirectory() as temp_dir:
test_data_dir = Path(__file__).parent / 'test_data'
base_dir = Path(temp_dir)

# Load the config
with (test_data_dir / 'demo_config.yml').open('r') as f:
config = yaml.safe_load(f)

# Make an edit not to schema
config['base_dir'] = 1

with open(base_dir / 'test_config.yml', 'w') as f:
yaml.dump(config, f)

# Test schema validation
with pytest.raises(jsonschema.exceptions.ValidationError) as exc_info:
swmmanywhere.load_config(base_dir / 'test_config.yml')
assert "null" in str(exc_info.value)

0 comments on commit beff056

Please sign in to comment.