Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More logging #116

Merged
merged 20 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion swmmanywhere/geospatial_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from shapely import ops as sops
from shapely.errors import GEOSException
from shapely.strtree import STRtree
from tqdm import tqdm

from swmmanywhere.logging import tqdm

TransformerFromCRS = lru_cache(pyproj.transformer.Transformer.from_crs)

Expand Down
3 changes: 1 addition & 2 deletions swmmanywhere/graph_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
import osmnx as ox
import pandas as pd
import shapely
from tqdm import tqdm

from swmmanywhere import geospatial_utilities as go
from swmmanywhere import parameters
from swmmanywhere.logging import logger
from swmmanywhere.logging import logger, tqdm


def load_graph(fid: Path) -> nx.Graph:
Expand Down
28 changes: 28 additions & 0 deletions swmmanywhere/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,36 @@
import sys

import loguru
from tqdm import tqdm as tqdm_original
barneydobson marked this conversation as resolved.
Show resolved Hide resolved


def tqdm(*args, **kwargs):
barneydobson marked this conversation as resolved.
Show resolved Hide resolved
"""Custom tqdm function.

A custom tqdm function that checks for the verbosity. If verbose, it
returns the actual tqdm progress bar. Otherwise, in the case of a provided
argument, it returns that argument (i.e., iterator), or, in the case of no
arguments an empty object that mocks the functions of a progress bar.
"""
verbose = os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true"

if verbose:
return tqdm_original(*args,
**kwargs)
else:
if not args:
# i.e., a progress bar rather than an iterator
# This is just an empty object that can have 'update' called, as a
# progress bar would be
return type('Obj',
(object,),
{'update': lambda self, _: None,
'close': lambda self, _: None},
)()

iterator = args[0]
return iterator

def dynamic_filter(record):
"""A dynamic filter."""
return os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true"
Expand Down
13 changes: 7 additions & 6 deletions swmmanywhere/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pandas as pd
import yaml

from swmmanywhere.logging import logger
from swmmanywhere.parameters import FilePaths


Expand Down Expand Up @@ -171,10 +172,10 @@ def overwrite_section(data: np.ndarray,
i += 1

example_line = lines[ix + i]
print('example_line {1}: {0}'.format(
example_line.replace('\n', ''), section))
print('note - this line must have at least as many column')
print('entries as all other rows in this section\n')
# print('example_line {1}: {0}'.format(
# example_line.replace('\n', ''), section))
# print('note - this line must have at least as many column')
# print('entries as all other rows in this section\n')
pattern = r'(\s+)'

# Find all matches of the pattern in the input line
Expand All @@ -185,7 +186,7 @@ def overwrite_section(data: np.ndarray,
for x, y in zip(matches, example_line.split())]
if not space_counts:
if data.shape[0] != 0:
print('no template for data?')
logger.warning('no template for data?')
continue

space_counts[-1] -= 1
Expand Down Expand Up @@ -251,7 +252,7 @@ def data_dict_to_inp(data_dict: dict[str, np.ndarray],

# Write the inp file
for key, data in data_dict.items():
print(key)
# print(key)
barneydobson marked this conversation as resolved.
Show resolved Hide resolved
start_section = '[{0}]'.format(key)

overwrite_section(data, start_section, new_input_file)
Expand Down
67 changes: 64 additions & 3 deletions swmmanywhere/swmmanywhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import swmmanywhere.geospatial_utilities as go
from swmmanywhere import parameters, preprocessing
from swmmanywhere.graph_utilities import iterate_graphfcns, load_graph, save_graph
from swmmanywhere.logging import logger
from swmmanywhere.logging import logger, tqdm
from swmmanywhere.metric_utilities import iterate_metrics
from swmmanywhere.post_processing import synthetic_write

Expand Down Expand Up @@ -42,9 +42,19 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]:
config.get('model_number',None)
)

logger.info(f"Project structure created at {addresses.base_dir}")
logger.info(f"Project name: {config['project']}")
logger.info(f"Bounding box: {config['bbox']}, number: {addresses.bbox_number}")
logger.info(f"Model number: {addresses.model_number}")

for key, val in config.get('address_overrides', {}).items():
logger.info(f"Setting {key} to {val}")
setattr(addresses, key, val)

# Save config file
if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true":
save_config(config, addresses.model / 'config.yml')

# Run downloads
logger.info("Running downloads.")
api_keys = yaml.safe_load(config['api_keys'].open('r'))
Expand All @@ -65,6 +75,7 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]:
params = parameters.get_full_parameters()
for category, overrides in config.get('parameter_overrides', {}).items():
for key, val in overrides.items():
logger.info(f"Setting {category} {key} to {val}")
setattr(params[category], key, val)

# Iterate the graph functions
Expand All @@ -89,8 +100,8 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]:
logger.info("Running the synthetic model.")
synthetic_results = run(addresses.inp,
**config['run_settings'])
logger.info("Writing synthetic results.")
if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true":
logger.info("Writing synthetic results.")
synthetic_results.to_parquet(addresses.model /\
f'results.{addresses.extension}')

Expand Down Expand Up @@ -238,6 +249,46 @@ def check_starting_graph(config: dict):

return config

def check_parameter_overrides(config: dict):
"""Check the parameter overrides in the config.

Args:
config (dict): The configuration.

Raises:
ValueError: If a parameter override is not in the parameters
dictionary.
"""
params = parameters.get_full_parameters()
for category, overrides in config.get('parameter_overrides',{}).items():
if category not in params:
raise ValueError(f"""{category} not a category of parameter. Must
be one of {params.keys()}.""")

for key, val in overrides.items():
# Check that the parameter is available
if key not in params[category].model_json_schema()['properties']:
barneydobson marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"{key} not found in {category}.")

return config

# Define a custom Dumper class to write Path
class _CustomDumper(yaml.SafeDumper):
barneydobson marked this conversation as resolved.
Show resolved Hide resolved
def represent_data(self, data):
if isinstance(data, Path):
return self.represent_scalar('tag:yaml.org,2002:str', str(data))
return super().represent_data(data)

def save_config(config: dict, config_path: Path):
"""Save the configuration to a file.

Args:
config (dict): The configuration.
config_path (Path): The path to save the configuration.
"""
with config_path.open('w') as f:
yaml.dump(config, f, Dumper=_CustomDumper, default_flow_style=False)

def load_config(config_path: Path, validation: bool = True):
"""Load, validate, and convert Paths in a configuration file.

Expand Down Expand Up @@ -278,6 +329,9 @@ def load_config(config_path: Path, validation: bool = True):
# Check starting graph
config = check_starting_graph(config)

# Check parameter overrides
config = check_parameter_overrides(config)

return config


Expand Down Expand Up @@ -328,9 +382,16 @@ def run(model: Path,
results = []
t_ = sim.current_time
ind = 0
while ((sim.current_time - t_).total_seconds() <= duration) & \
logger.info(f"Starting simulation for: {model}")

progress_bar = tqdm(total=duration)
offset = 0
while (offset <= duration) & \
(sim.current_time < sim.end_time) & (not sim._terminate_request):

progress_bar.update((sim.current_time - t_).total_seconds() - offset)
offset = (sim.current_time - t_).total_seconds()

ind+=1

# Iterate the main model timestep
Expand Down
62 changes: 60 additions & 2 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest.mock import patch

from swmmanywhere.logging import logger
from tqdm import tqdm as tqdm_original

from swmmanywhere.logging import logger, tqdm


def test_logger():
Expand Down Expand Up @@ -71,4 +74,59 @@ def test_logger_again():
assert temp_file.read() != b""
logger.remove()
fid.unlink()
os.environ["SWMMANYWHERE_VERBOSE"] = "false"
os.environ["SWMMANYWHERE_VERBOSE"] = "false"

def test_tqdm():
"""Test custom tqdm with true verbose."""
# Set SWMMANYWHERE_VERBOSE to True
os.environ["SWMMANYWHERE_VERBOSE"] = "true"

# Create a mock iterator
mock_iterator = iter(range(10))

# Patch the original tqdm function
with patch("swmmanywhere.logging.tqdm_original",
wraps=tqdm_original) as mock_tqdm:
# Call the custom tqdm function
result = [i for i in tqdm(mock_iterator)]

# Check if the original tqdm was called
mock_tqdm.assert_called()

# Check if the progress_bar is the same as the mocked tqdm
assert result == list(range(10))

def test_tqdm_not_verbose():
"""Test custom tqdm with false verbose."""
# Set SWMMANYWHERE_VERBOSE to False
os.environ["SWMMANYWHERE_VERBOSE"] = "false"

# Create a mock iterator
mock_iterator = iter(range(10))
with patch("swmmanywhere.logging.tqdm_original") as mock_tqdm:
# Call the custom tqdm function
result = [i for i in tqdm(mock_iterator)]

mock_tqdm.assert_not_called()

# Check if the progress_bar is the same as the mock_iterator
assert result == list(range(10))

def test_tqdm_verbose_unset():
"""Test custom tqdm with no verbose."""
# Unset SWMMANYWHERE_VERBOSE
os.environ["SWMMANYWHERE_VERBOSE"] = "true"
if "SWMMANYWHERE_VERBOSE" in os.environ:
del os.environ["SWMMANYWHERE_VERBOSE"]

# Create a mock iterator
mock_iterator = iter(range(10))

with patch("swmmanywhere.logging.tqdm_original") as mock_tqdm:
# Call the custom tqdm function
result = [i for i in tqdm(mock_iterator)]

mock_tqdm.assert_not_called()

# Check if the progress_bar is the same as the mock_iterator
assert result == list(range(10))
39 changes: 38 additions & 1 deletion tests/test_swmmanywhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,43 @@ def test_check_parameters_to_sample():
# Test parameter validation
with pytest.raises(ValueError) as exc_info:
swmmanywhere.load_config(base_dir / 'test_config.yml')
assert "not_a_parameter" in str(exc_info.value)
assert "not_a_parameter" in str(exc_info.value)

# Test parameter_overrides invalid category
config['parameter_overrides'] = {'fake_category' : {'fake_parameter' : 0}}
with pytest.raises(ValueError) as exc_info:
swmmanywhere.check_parameter_overrides(config)
assert "fake_category not a category" in str(exc_info.value)

# Test parameter_overrides invalid parameter
config['parameter_overrides'] = {'hydraulic_design' : {'fake_parameter' : 0}}
with pytest.raises(ValueError) as exc_info:
swmmanywhere.check_parameter_overrides(config)
assert "fake_parameter not found" in str(exc_info.value)

# Test parameter_overrides valid
config['parameter_overrides'] = {'hydraulic_design' : {'min_v' : 1.0}}
_ = swmmanywhere.check_parameter_overrides(config)


def test_save_config():
"""Test the save_config function."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir)
test_data_dir = Path(__file__).parent / 'test_data'
defs_dir = Path(__file__).parent.parent / 'swmmanywhere' / 'defs'

with (test_data_dir / 'demo_config.yml').open('r') as f:
config = yaml.safe_load(f)

# Correct and avoid filevalidation errors
config['real'] = None

# Fill with unused paths to avoid filevalidation errors
config['base_dir'] = str(defs_dir / 'storm.dat')
config['api_keys'] = str(defs_dir / 'storm.dat')

swmmanywhere.save_config(config, temp_dir / 'test.yml')

# Reload to check OK
config = swmmanywhere.load_config(temp_dir / 'test.yml')