From eb1995594364c2c292523a58ead24c72350c6bd5 Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 23 May 2024 13:28:56 +0100 Subject: [PATCH 01/15] make and test save_config --- swmmanywhere/swmmanywhere.py | 22 ++++++++++++++++++++++ tests/test_swmmanywhere.py | 20 ++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index eb84d837..41aa7b40 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -45,6 +45,10 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: for key, val in config.get('address_overrides', {}).items(): setattr(addresses, key, val) + # Save config file + if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true": + save_config(config, addresses.model / 'config.yml') + # Run downloads logger.info("Running downloads.") api_keys = yaml.safe_load(config['api_keys'].open('r')) @@ -65,6 +69,7 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: params = parameters.get_full_parameters() for category, overrides in config.get('parameter_overrides', {}).items(): for key, val in overrides.items(): + logger.info(f"Setting {category} {key} to {val}") setattr(params[category], key, val) # Iterate the graph functions @@ -238,6 +243,23 @@ def check_starting_graph(config: dict): return config +# Define a custom Dumper class to write Path +class _CustomDumper(yaml.SafeDumper): + def represent_data(self, data): + if isinstance(data, Path): + return self.represent_scalar('tag:yaml.org,2002:str', str(data)) + return super().represent_data(data) + +def save_config(config: dict, config_path: Path): + """Save the configuration to a file. + + Args: + config (dict): The configuration. + config_path (Path): The path to save the configuration. + """ + with config_path.open('w') as f: + yaml.dump(config, f, Dumper=_CustomDumper, default_flow_style=False) + def load_config(config_path: Path, validation: bool = True): """Load, validate, and convert Paths in a configuration file. diff --git a/tests/test_swmmanywhere.py b/tests/test_swmmanywhere.py index 7dedd78c..4f8f6a7d 100644 --- a/tests/test_swmmanywhere.py +++ b/tests/test_swmmanywhere.py @@ -181,4 +181,24 @@ def test_check_parameters_to_sample(): swmmanywhere.load_config(base_dir / 'test_config.yml') assert "not_a_parameter" in str(exc_info.value) +def test_save_config(): + """Test the save_config function.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = Path(temp_dir) + test_data_dir = Path(__file__).parent / 'test_data' + defs_dir = Path(__file__).parent.parent / 'swmmanywhere' / 'defs' + + with (test_data_dir / 'demo_config.yml').open('r') as f: + config = yaml.safe_load(f) + + # Correct and avoid filevalidation errors + config['real'] = None + + # Fill with unused paths to avoid filevalidation errors + config['base_dir'] = str(defs_dir / 'storm.dat') + config['api_keys'] = str(defs_dir / 'storm.dat') + + swmmanywhere.save_config(config, temp_dir / 'test.yml') + # Reload to check OK + config = swmmanywhere.load_config(temp_dir / 'test.yml') From dec462e798ba2f5196a75bbf17806ade426a0920 Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 23 May 2024 13:33:20 +0100 Subject: [PATCH 02/15] Update swmmanywhere.py logger already checks verbosity --- swmmanywhere/swmmanywhere.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index 41aa7b40..3030b413 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -94,8 +94,8 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: logger.info("Running the synthetic model.") synthetic_results = run(addresses.inp, **config['run_settings']) + logger.info("Writing synthetic results.") if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true": - logger.info("Writing synthetic results.") synthetic_results.to_parquet(addresses.model /\ f'results.{addresses.extension}') From 5d11fbb7e9811f699b22fb2e4c06cdd82354912d Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 23 May 2024 14:21:08 +0100 Subject: [PATCH 03/15] control tqdm with verbosity --- swmmanywhere/geospatial_utilities.py | 3 +- swmmanywhere/graph_utilities.py | 3 +- swmmanywhere/logging.py | 16 +++++++ tests/test_logging.py | 62 +++++++++++++++++++++++++++- 4 files changed, 79 insertions(+), 5 deletions(-) diff --git a/swmmanywhere/geospatial_utilities.py b/swmmanywhere/geospatial_utilities.py index 712e775a..9cd0e8b4 100644 --- a/swmmanywhere/geospatial_utilities.py +++ b/swmmanywhere/geospatial_utilities.py @@ -29,7 +29,8 @@ from shapely import ops as sops from shapely.errors import GEOSException from shapely.strtree import STRtree -from tqdm import tqdm + +from swmmanywhere.logging import tqdm TransformerFromCRS = lru_cache(pyproj.transformer.Transformer.from_crs) diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index 9a121d0f..5fdaad66 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -21,11 +21,10 @@ import osmnx as ox import pandas as pd import shapely -from tqdm import tqdm from swmmanywhere import geospatial_utilities as go from swmmanywhere import parameters -from swmmanywhere.logging import logger +from swmmanywhere.logging import logger, tqdm def load_graph(fid: Path) -> nx.Graph: diff --git a/swmmanywhere/logging.py b/swmmanywhere/logging.py index 1906886e..579e3789 100644 --- a/swmmanywhere/logging.py +++ b/swmmanywhere/logging.py @@ -16,8 +16,24 @@ import sys import loguru +from tqdm import tqdm as tqdm_original +def tqdm(*args, **kwargs): + """Custom tqdm function. + + A custom tqdm function that checks for the verbosity. If verbose, it + returns the actual tqdm progress bar. Otherwise, it returns a simple + iterator without the progress bar. + """ + verbose = os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" + + if verbose: + return tqdm_original(*args, **kwargs) + else: + iterator = args[0] + return iterator + def dynamic_filter(record): """A dynamic filter.""" return os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" diff --git a/tests/test_logging.py b/tests/test_logging.py index fb1f212c..9c2f23f4 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -8,8 +8,11 @@ import os from pathlib import Path from tempfile import NamedTemporaryFile +from unittest.mock import patch -from swmmanywhere.logging import logger +from tqdm import tqdm as tqdm_original + +from swmmanywhere.logging import logger, tqdm def test_logger(): @@ -71,4 +74,59 @@ def test_logger_again(): assert temp_file.read() != b"" logger.remove() fid.unlink() - os.environ["SWMMANYWHERE_VERBOSE"] = "false" \ No newline at end of file + os.environ["SWMMANYWHERE_VERBOSE"] = "false" + +def test_tqdm(): + """Test custom tqdm with true verbose.""" + # Set SWMMANYWHERE_VERBOSE to True + os.environ["SWMMANYWHERE_VERBOSE"] = "true" + + # Create a mock iterator + mock_iterator = iter(range(10)) + + # Patch the original tqdm function + with patch("swmmanywhere.logging.tqdm_original", + wraps=tqdm_original) as mock_tqdm: + # Call the custom tqdm function + result = [i for i in tqdm(mock_iterator)] + + # Check if the original tqdm was called + mock_tqdm.assert_called() + + # Check if the progress_bar is the same as the mocked tqdm + assert result == list(range(10)) + +def test_tqdm_not_verbose(): + """Test custom tqdm with false verbose.""" + # Set SWMMANYWHERE_VERBOSE to False + os.environ["SWMMANYWHERE_VERBOSE"] = "false" + + # Create a mock iterator + mock_iterator = iter(range(10)) + with patch("swmmanywhere.logging.tqdm_original") as mock_tqdm: + # Call the custom tqdm function + result = [i for i in tqdm(mock_iterator)] + + mock_tqdm.assert_not_called() + + # Check if the progress_bar is the same as the mock_iterator + assert result == list(range(10)) + +def test_tqdm_verbose_unset(): + """Test custom tqdm with no verbose.""" + # Unset SWMMANYWHERE_VERBOSE + os.environ["SWMMANYWHERE_VERBOSE"] = "true" + if "SWMMANYWHERE_VERBOSE" in os.environ: + del os.environ["SWMMANYWHERE_VERBOSE"] + + # Create a mock iterator + mock_iterator = iter(range(10)) + + with patch("swmmanywhere.logging.tqdm_original") as mock_tqdm: + # Call the custom tqdm function + result = [i for i in tqdm(mock_iterator)] + + mock_tqdm.assert_not_called() + + # Check if the progress_bar is the same as the mock_iterator + assert result == list(range(10)) \ No newline at end of file From 7b05e24d86016023a5dca59faf74caf4ec6b9be6 Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 23 May 2024 14:21:23 +0100 Subject: [PATCH 04/15] Update swmmanywhere.py print more info --- swmmanywhere/swmmanywhere.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index 3030b413..c3e48ab0 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -42,7 +42,13 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: config.get('model_number',None) ) + logger.info(f"Project structure created at {addresses.base_dir}") + logger.info(f"Project name: {config['project']}") + logger.info(f"Bounding box: {config['bbox']}, number: {addresses.bbox_number}") + logger.info(f"Model number: {addresses.model_number}") + for key, val in config.get('address_overrides', {}).items(): + logger.info(f"Setting {key} to {val}") setattr(addresses, key, val) # Save config file @@ -350,6 +356,7 @@ def run(model: Path, results = [] t_ = sim.current_time ind = 0 + logger.info(f"Starting simulation for: {model}") while ((sim.current_time - t_).total_seconds() <= duration) & \ (sim.current_time < sim.end_time) & (not sim._terminate_request): From 2d73ca24f5299daa110e00d03ed30903916f8acb Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 23 May 2024 14:23:35 +0100 Subject: [PATCH 05/15] Update post_processing.py --- swmmanywhere/post_processing.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/swmmanywhere/post_processing.py b/swmmanywhere/post_processing.py index 8d1c9e51..290a1e6c 100644 --- a/swmmanywhere/post_processing.py +++ b/swmmanywhere/post_processing.py @@ -15,6 +15,7 @@ import pandas as pd import yaml +from swmmanywhere.logging import logger from swmmanywhere.parameters import FilePaths @@ -171,10 +172,10 @@ def overwrite_section(data: np.ndarray, i += 1 example_line = lines[ix + i] - print('example_line {1}: {0}'.format( - example_line.replace('\n', ''), section)) - print('note - this line must have at least as many column') - print('entries as all other rows in this section\n') + # print('example_line {1}: {0}'.format( + # example_line.replace('\n', ''), section)) + # print('note - this line must have at least as many column') + # print('entries as all other rows in this section\n') pattern = r'(\s+)' # Find all matches of the pattern in the input line @@ -185,7 +186,7 @@ def overwrite_section(data: np.ndarray, for x, y in zip(matches, example_line.split())] if not space_counts: if data.shape[0] != 0: - print('no template for data?') + logger.warning('no template for data?') continue space_counts[-1] -= 1 @@ -251,7 +252,7 @@ def data_dict_to_inp(data_dict: dict[str, np.ndarray], # Write the inp file for key, data in data_dict.items(): - print(key) + # print(key) start_section = '[{0}]'.format(key) overwrite_section(data, start_section, new_input_file) From 02391fdd5f854bc729565ad84ec92ce6125d48f9 Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 23 May 2024 14:53:28 +0100 Subject: [PATCH 06/15] Enable progress bar for run --- swmmanywhere/logging.py | 13 ++++++++++++- swmmanywhere/swmmanywhere.py | 10 ++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/swmmanywhere/logging.py b/swmmanywhere/logging.py index 579e3789..3c6c42eb 100644 --- a/swmmanywhere/logging.py +++ b/swmmanywhere/logging.py @@ -29,8 +29,19 @@ def tqdm(*args, **kwargs): verbose = os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" if verbose: - return tqdm_original(*args, **kwargs) + return tqdm_original(*args, + **kwargs) else: + if not args: + # i.e., a progress bar rather than an iterator + # This is just an empty object that can have 'update' called, as a + # progress bar would be + return type('Obj', + (object,), + {'update': lambda self, _: None, + 'close': lambda self, _: None}, + )() + iterator = args[0] return iterator diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index c3e48ab0..dd72da66 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -13,7 +13,7 @@ import swmmanywhere.geospatial_utilities as go from swmmanywhere import parameters, preprocessing from swmmanywhere.graph_utilities import iterate_graphfcns, load_graph, save_graph -from swmmanywhere.logging import logger +from swmmanywhere.logging import logger, tqdm from swmmanywhere.metric_utilities import iterate_metrics from swmmanywhere.post_processing import synthetic_write @@ -357,9 +357,15 @@ def run(model: Path, t_ = sim.current_time ind = 0 logger.info(f"Starting simulation for: {model}") - while ((sim.current_time - t_).total_seconds() <= duration) & \ + + progress_bar = tqdm(total=duration) + offset = 0 + while (offset <= duration) & \ (sim.current_time < sim.end_time) & (not sim._terminate_request): + progress_bar.update((sim.current_time - t_).total_seconds() - offset) + offset = (sim.current_time - t_).total_seconds() + ind+=1 # Iterate the main model timestep From c14be5776c2d3770cc3c1ef41639bbc23d0af6b8 Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 23 May 2024 14:57:27 +0100 Subject: [PATCH 07/15] Update logging.py improve docstring --- swmmanywhere/logging.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/swmmanywhere/logging.py b/swmmanywhere/logging.py index 3c6c42eb..3db68b4d 100644 --- a/swmmanywhere/logging.py +++ b/swmmanywhere/logging.py @@ -23,8 +23,9 @@ def tqdm(*args, **kwargs): """Custom tqdm function. A custom tqdm function that checks for the verbosity. If verbose, it - returns the actual tqdm progress bar. Otherwise, it returns a simple - iterator without the progress bar. + returns the actual tqdm progress bar. Otherwise, in the case of a provided + argument, it returns that argument (i.e., iterator), or, in the case of no + arguments an empty object that mocks the functions of a progress bar. """ verbose = os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" From 8cf77c86953c6af57a020d6b3c2cdb7e54a09bfe Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 23 May 2024 15:00:25 +0100 Subject: [PATCH 08/15] Update swmmanywhere.py check_parameter_overrides --- swmmanywhere/swmmanywhere.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index dd72da66..a074b616 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -249,6 +249,29 @@ def check_starting_graph(config: dict): return config +def check_parameter_overrides(config: dict): + """Check the parameter overrides in the config. + + Args: + config (dict): The configuration. + + Raises: + ValueError: If a parameter override is not in the parameters + dictionary. + """ + params = parameters.get_full_parameters() + for category, overrides in config.get('parameter_overrides',{}).items(): + if category not in params: + raise ValueError(f"""{category} not a category of parameter. Must + be one of {params.keys()}.""") + + for key, val in overrides.items(): + # Check that the parameter is available + if key not in params[category].model_json_schema()['properties']: + raise ValueError(f"{key} not found in {category}.") + + return config + # Define a custom Dumper class to write Path class _CustomDumper(yaml.SafeDumper): def represent_data(self, data): @@ -306,6 +329,9 @@ def load_config(config_path: Path, validation: bool = True): # Check starting graph config = check_starting_graph(config) + # Check parameter overrides + config = check_parameter_overrides(config) + return config From 9ab6ecf3e3cba7fcc5b2cdf2da2fe34ffa28d46b Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 23 May 2024 15:10:08 +0100 Subject: [PATCH 09/15] Update test_swmmanywhere.py --- tests/test_swmmanywhere.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/test_swmmanywhere.py b/tests/test_swmmanywhere.py index 4f8f6a7d..57aa787a 100644 --- a/tests/test_swmmanywhere.py +++ b/tests/test_swmmanywhere.py @@ -179,7 +179,24 @@ def test_check_parameters_to_sample(): # Test parameter validation with pytest.raises(ValueError) as exc_info: swmmanywhere.load_config(base_dir / 'test_config.yml') - assert "not_a_parameter" in str(exc_info.value) + assert "not_a_parameter" in str(exc_info.value) + + # Test parameter_overrides invalid category + config['parameter_overrides'] = {'fake_category' : {'fake_parameter' : 0}} + with pytest.raises(ValueError) as exc_info: + swmmanywhere.check_parameter_overrides(config) + assert "fake_category not a category" in str(exc_info.value) + + # Test parameter_overrides invalid parameter + config['parameter_overrides'] = {'hydraulic_design' : {'fake_parameter' : 0}} + with pytest.raises(ValueError) as exc_info: + swmmanywhere.check_parameter_overrides(config) + assert "fake_parameter not found" in str(exc_info.value) + + # Test parameter_overrides valid + config['parameter_overrides'] = {'hydraulic_design' : {'min_v' : 1.0}} + _ = swmmanywhere.check_parameter_overrides(config) + def test_save_config(): """Test the save_config function.""" @@ -187,7 +204,7 @@ def test_save_config(): temp_dir = Path(temp_dir) test_data_dir = Path(__file__).parent / 'test_data' defs_dir = Path(__file__).parent.parent / 'swmmanywhere' / 'defs' - + with (test_data_dir / 'demo_config.yml').open('r') as f: config = yaml.safe_load(f) From b620e7ce3103f5196f8af480df9a3273d26ea2d7 Mon Sep 17 00:00:00 2001 From: barneydobson Date: Fri, 24 May 2024 09:03:08 +0100 Subject: [PATCH 10/15] Update swmmanywhere.py don't check categories for every override item --- swmmanywhere/swmmanywhere.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index a074b616..e5485994 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -265,9 +265,12 @@ def check_parameter_overrides(config: dict): raise ValueError(f"""{category} not a category of parameter. Must be one of {params.keys()}.""") + # Get the available properties for a category + cat_properties = params[category].model_json_schema()['properties'] + for key, val in overrides.items(): # Check that the parameter is available - if key not in params[category].model_json_schema()['properties']: + if key not in cat_properties: raise ValueError(f"{key} not found in {category}.") return config From f92e8e5f2ace2e5419dbe59c0bd63ddc4824369c Mon Sep 17 00:00:00 2001 From: barneydobson Date: Fri, 24 May 2024 09:03:28 +0100 Subject: [PATCH 11/15] Update logging.py use `auto` --- swmmanywhere/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swmmanywhere/logging.py b/swmmanywhere/logging.py index 3db68b4d..b51d0ed0 100644 --- a/swmmanywhere/logging.py +++ b/swmmanywhere/logging.py @@ -16,7 +16,7 @@ import sys import loguru -from tqdm import tqdm as tqdm_original +from tqdm.auto import tqdm as tqdm_original def tqdm(*args, **kwargs): From 8deb7c36aaa604c7ca879013bfc89e9c02c94aa2 Mon Sep 17 00:00:00 2001 From: barneydobson Date: Fri, 24 May 2024 10:09:57 +0100 Subject: [PATCH 12/15] Revise - Add utilities for yaml - Remove custom tqdm - Add verbosity variable --- swmmanywhere/geospatial_utilities.py | 10 +++-- swmmanywhere/graph_utilities.py | 12 +++--- swmmanywhere/logging.py | 32 ++------------- swmmanywhere/prepare_data.py | 5 +-- swmmanywhere/swmmanywhere.py | 37 +++++++---------- swmmanywhere/utilities.py | 49 ++++++++++++++++++++++ tests/test_logging.py | 61 ++++------------------------ 7 files changed, 89 insertions(+), 117 deletions(-) create mode 100644 swmmanywhere/utilities.py diff --git a/swmmanywhere/geospatial_utilities.py b/swmmanywhere/geospatial_utilities.py index 9cd0e8b4..f7e10c81 100644 --- a/swmmanywhere/geospatial_utilities.py +++ b/swmmanywhere/geospatial_utilities.py @@ -29,8 +29,9 @@ from shapely import ops as sops from shapely.errors import GEOSException from shapely.strtree import STRtree +from tqdm.auto import tqdm -from swmmanywhere.logging import tqdm +from swmmanywhere.logging import verbose TransformerFromCRS = lru_cache(pyproj.transformer.Transformer.from_crs) @@ -417,7 +418,9 @@ def delineate_catchment(grid: pysheds.sgrid.sGrid, """ polys = [] # Iterate over the nodes in the graph - for id, data in tqdm(G.nodes(data=True), total=len(G.nodes)): + for id, data in tqdm(G.nodes(data=True), + total=len(G.nodes), + disable = not verbose()): # Snap the node to the nearest grid cell x, y = data['x'], data['y'] grid_ = deepcopy(grid) @@ -476,7 +479,8 @@ def remove_intersections(polys: gpd.GeoDataFrame) -> gpd.GeoDataFrame: # with the smallest area polygon minimal_geom = result_polygons.iloc[0]['geometry'] for idx, row in tqdm(result_polygons.iloc[1:].iterrows(), - total=result_polygons.shape[0] - 1): + total=result_polygons.shape[0] - 1, + disable = not verbose()): # Trim the polygon by the combined geometry result_polygons.at[idx, 'geometry'] = row['geometry'].difference( diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index 5fdaad66..2fa36223 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -6,7 +6,6 @@ from __future__ import annotations import json -import os import tempfile from abc import ABC, abstractmethod from collections import defaultdict @@ -21,10 +20,11 @@ import osmnx as ox import pandas as pd import shapely +from tqdm.auto import tqdm from swmmanywhere import geospatial_utilities as go from swmmanywhere import parameters -from swmmanywhere.logging import logger, tqdm +from swmmanywhere.logging import logger, verbose def load_graph(fid: Path) -> nx.Graph: @@ -178,11 +178,10 @@ def iterate_graphfcns(G: nx.Graph, not_exists = [g for g in graphfcn_list if g not in graphfcns] if not_exists: raise ValueError(f"Graphfcns are not registered:\n{', '.join(not_exists)}") - verbose = os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" for function in graphfcn_list: G = graphfcns[function](G, addresses = addresses, **params) logger.info(f"graphfcn: {function} completed.") - if verbose: + if verbose(): save_graph(G, addresses.model / f"{function}_graph.json") return G @@ -534,7 +533,7 @@ def __call__(self, G: nx.Graph, # Derive subs_gdf = go.derive_subcatchments(G,temp_fid) - if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true": + if verbose(): subs_gdf.to_file(addresses.subcatchments, driver='GeoJSON') # Calculate runoff coefficient (RC) @@ -1147,7 +1146,8 @@ def __call__(self, chamber_floor = {} edge_diams: dict[tuple[Hashable,Hashable,int],float] = {} # Iterate over nodes in topological order - for node in tqdm(topological_order): + for node in tqdm(topological_order, + disable = not verbose()): # Check if there's any nodes upstream, if not set the depth to min_depth if len(nx.ancestors(G,node)) == 0: chamber_floor[node] = surface_elevations[node] - \ diff --git a/swmmanywhere/logging.py b/swmmanywhere/logging.py index b51d0ed0..f3b332cf 100644 --- a/swmmanywhere/logging.py +++ b/swmmanywhere/logging.py @@ -16,39 +16,15 @@ import sys import loguru -from tqdm.auto import tqdm as tqdm_original -def tqdm(*args, **kwargs): - """Custom tqdm function. - - A custom tqdm function that checks for the verbosity. If verbose, it - returns the actual tqdm progress bar. Otherwise, in the case of a provided - argument, it returns that argument (i.e., iterator), or, in the case of no - arguments an empty object that mocks the functions of a progress bar. - """ - verbose = os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" +def verbose() -> bool: + """Get the verbosity.""" + return os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" - if verbose: - return tqdm_original(*args, - **kwargs) - else: - if not args: - # i.e., a progress bar rather than an iterator - # This is just an empty object that can have 'update' called, as a - # progress bar would be - return type('Obj', - (object,), - {'update': lambda self, _: None, - 'close': lambda self, _: None}, - )() - - iterator = args[0] - return iterator - def dynamic_filter(record): """A dynamic filter.""" - return os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" + return verbose() def get_logger() -> loguru.logger: """Get a logger.""" diff --git a/swmmanywhere/prepare_data.py b/swmmanywhere/prepare_data.py index e87b4703..e7118fc6 100644 --- a/swmmanywhere/prepare_data.py +++ b/swmmanywhere/prepare_data.py @@ -14,10 +14,10 @@ import pandas as pd import requests import xarray as xr -import yaml from geopy.geocoders import Nominatim from swmmanywhere.logging import logger +from swmmanywhere.utilities import yaml_load def get_country(x: float, @@ -46,8 +46,7 @@ def get_country(x: float, geolocator = Nominatim(user_agent="get_iso") # Load ISO code mapping from YAML file - with iso_path.open("r") as file: - data = yaml.safe_load(file) + data = yaml_load(iso_path.open("r")) # Get country ISO code from coordinates location = geolocator.reverse(f"{y}, {x}", exactly_one=True) diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index e5485994..c4aa31d1 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -1,21 +1,21 @@ """The main SWMManywhere module to generate and run a synthetic network.""" from __future__ import annotations -import os from pathlib import Path import geopandas as gpd import jsonschema import pandas as pd import pyswmm -import yaml +from tqdm.auto import tqdm import swmmanywhere.geospatial_utilities as go from swmmanywhere import parameters, preprocessing from swmmanywhere.graph_utilities import iterate_graphfcns, load_graph, save_graph -from swmmanywhere.logging import logger, tqdm +from swmmanywhere.logging import logger, verbose from swmmanywhere.metric_utilities import iterate_metrics from swmmanywhere.post_processing import synthetic_write +from swmmanywhere.utilities import yaml_dump, yaml_load def swmmanywhere(config: dict) -> tuple[Path, dict | None]: @@ -52,12 +52,12 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: setattr(addresses, key, val) # Save config file - if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true": + if verbose(): save_config(config, addresses.model / 'config.yml') # Run downloads logger.info("Running downloads.") - api_keys = yaml.safe_load(config['api_keys'].open('r')) + api_keys = yaml_load(config['api_keys'].open('r')) preprocessing.run_downloads(config['bbox'], addresses, api_keys @@ -101,7 +101,7 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: synthetic_results = run(addresses.inp, **config['run_settings']) logger.info("Writing synthetic results.") - if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true": + if verbose(): synthetic_results.to_parquet(addresses.model /\ f'results.{addresses.extension}') @@ -114,7 +114,7 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: logger.info("Running the real model.") real_results = run(config['real']['inp'], **config['run_settings']) - if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true": + if verbose(): real_results.to_parquet(config['real']['inp'].parent /\ f'real_results.{addresses.extension}') else: @@ -274,13 +274,6 @@ def check_parameter_overrides(config: dict): raise ValueError(f"{key} not found in {category}.") return config - -# Define a custom Dumper class to write Path -class _CustomDumper(yaml.SafeDumper): - def represent_data(self, data): - if isinstance(data, Path): - return self.represent_scalar('tag:yaml.org,2002:str', str(data)) - return super().represent_data(data) def save_config(config: dict, config_path: Path): """Save the configuration to a file. @@ -289,8 +282,7 @@ def save_config(config: dict, config_path: Path): config (dict): The configuration. config_path (Path): The path to save the configuration. """ - with config_path.open('w') as f: - yaml.dump(config, f, Dumper=_CustomDumper, default_flow_style=False) + yaml_dump(config, config_path.open('w')) def load_config(config_path: Path, validation: bool = True): """Load, validate, and convert Paths in a configuration file. @@ -304,12 +296,10 @@ def load_config(config_path: Path, validation: bool = True): """ # Load the schema schema_fid = Path(__file__).parent / 'defs' / 'schema.yml' - with schema_fid.open('r') as file: - schema = yaml.safe_load(file) + schema = yaml_load(schema_fid.open('r')) - with config_path.open('r') as f: - # Load the config - config = yaml.safe_load(f) + # Load the config + config = yaml_load(config_path.open('r')) if not validation: return config @@ -386,8 +376,9 @@ def run(model: Path, t_ = sim.current_time ind = 0 logger.info(f"Starting simulation for: {model}") - - progress_bar = tqdm(total=duration) + + progress_bar = tqdm(total=duration, disable = not verbose()) + offset = 0 while (offset <= duration) & \ (sim.current_time < sim.end_time) & (not sim._terminate_request): diff --git a/swmmanywhere/utilities.py b/swmmanywhere/utilities.py new file mode 100644 index 00000000..1c832bab --- /dev/null +++ b/swmmanywhere/utilities.py @@ -0,0 +1,49 @@ +"""Utilities for YAML save/load. + +Author: cheginit +""" +from __future__ import annotations + +import functools +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import yaml + +if TYPE_CHECKING: + SafeDumper = yaml.SafeDumper + from yaml.nodes import Node +else: + Node = Any + SafeDumper = getattr(yaml, "CSafeDumper", yaml.SafeDumper) + +yaml_load = functools.partial(yaml.load, + Loader=getattr(yaml, "CSafeLoader", + yaml.SafeLoader)) + +class PathDumper(SafeDumper): + """A dumper that can represent pathlib.Path objects as strings.""" + def represent_data(self, data: Any)-> Node: + """Represent data.""" + if isinstance(data, Path): + return self.represent_scalar('tag:yaml.org,2002:str', str(data)) + return super().represent_data(data) + +def yaml_dump(o: Any, + stream: Any = None, + **kwargs: Any) -> str: + """Dump YAML. + + Notes: + ----- + When python/mypy#1484 is solved, this can be ``functools.partial`` + """ + return yaml.dump( + o, + Dumper=PathDumper, + stream=stream, + default_flow_style=False, + indent=2, + sort_keys=False, + **kwargs, + ) \ No newline at end of file diff --git a/tests/test_logging.py b/tests/test_logging.py index 9c2f23f4..092beccd 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -8,11 +8,8 @@ import os from pathlib import Path from tempfile import NamedTemporaryFile -from unittest.mock import patch -from tqdm import tqdm as tqdm_original - -from swmmanywhere.logging import logger, tqdm +from swmmanywhere.logging import logger, verbose def test_logger(): @@ -76,57 +73,13 @@ def test_logger_again(): fid.unlink() os.environ["SWMMANYWHERE_VERBOSE"] = "false" -def test_tqdm(): - """Test custom tqdm with true verbose.""" - # Set SWMMANYWHERE_VERBOSE to True +def test_verbose(): + """Test the verbose function.""" os.environ["SWMMANYWHERE_VERBOSE"] = "true" + assert verbose() - # Create a mock iterator - mock_iterator = iter(range(10)) - - # Patch the original tqdm function - with patch("swmmanywhere.logging.tqdm_original", - wraps=tqdm_original) as mock_tqdm: - # Call the custom tqdm function - result = [i for i in tqdm(mock_iterator)] - - # Check if the original tqdm was called - mock_tqdm.assert_called() - - # Check if the progress_bar is the same as the mocked tqdm - assert result == list(range(10)) - -def test_tqdm_not_verbose(): - """Test custom tqdm with false verbose.""" - # Set SWMMANYWHERE_VERBOSE to False os.environ["SWMMANYWHERE_VERBOSE"] = "false" + assert not verbose() - # Create a mock iterator - mock_iterator = iter(range(10)) - with patch("swmmanywhere.logging.tqdm_original") as mock_tqdm: - # Call the custom tqdm function - result = [i for i in tqdm(mock_iterator)] - - mock_tqdm.assert_not_called() - - # Check if the progress_bar is the same as the mock_iterator - assert result == list(range(10)) - -def test_tqdm_verbose_unset(): - """Test custom tqdm with no verbose.""" - # Unset SWMMANYWHERE_VERBOSE - os.environ["SWMMANYWHERE_VERBOSE"] = "true" - if "SWMMANYWHERE_VERBOSE" in os.environ: - del os.environ["SWMMANYWHERE_VERBOSE"] - - # Create a mock iterator - mock_iterator = iter(range(10)) - - with patch("swmmanywhere.logging.tqdm_original") as mock_tqdm: - # Call the custom tqdm function - result = [i for i in tqdm(mock_iterator)] - - mock_tqdm.assert_not_called() - - # Check if the progress_bar is the same as the mock_iterator - assert result == list(range(10)) \ No newline at end of file + del os.environ["SWMMANYWHERE_VERBOSE"] + assert not verbose() \ No newline at end of file From 8844bcab7ea5e177e2fd23d6fb51fe889454d463 Mon Sep 17 00:00:00 2001 From: barneydobson Date: Fri, 24 May 2024 10:11:55 +0100 Subject: [PATCH 13/15] Update post_processing.py --- swmmanywhere/post_processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/swmmanywhere/post_processing.py b/swmmanywhere/post_processing.py index 290a1e6c..5e437195 100644 --- a/swmmanywhere/post_processing.py +++ b/swmmanywhere/post_processing.py @@ -252,7 +252,6 @@ def data_dict_to_inp(data_dict: dict[str, np.ndarray], # Write the inp file for key, data in data_dict.items(): - # print(key) start_section = '[{0}]'.format(key) overwrite_section(data, start_section, new_input_file) From 01e8bd4450ece406930f5318ea760f9b89176259 Mon Sep 17 00:00:00 2001 From: barneydobson Date: Tue, 28 May 2024 14:52:13 +0100 Subject: [PATCH 14/15] Update swmmanywhere.py move to correct pos --- swmmanywhere/swmmanywhere.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index 8ef7744d..8245eec4 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -51,15 +51,16 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: logger.info(f"Setting {key} to {val}") setattr(addresses, key, val) - # Save config file - if verbose(): - save_config(config, addresses.model / 'config.yml') # Load the parameters and perform any manual overrides logger.info("Loading and setting parameters.") params = parameters.get_full_parameters() for category, overrides in config.get('parameter_overrides', {}).items(): for key, val in overrides.items(): setattr(params[category], key, val) + + # Save config file + if verbose(): + save_config(config, addresses.model / 'config.yml') # Run downloads logger.info("Running downloads.") From 217b1f67ea7093c4e035e1bde8192ab6f4189c01 Mon Sep 17 00:00:00 2001 From: barneydobson Date: Tue, 28 May 2024 14:55:50 +0100 Subject: [PATCH 15/15] convert to read_text --- swmmanywhere/prepare_data.py | 2 +- swmmanywhere/swmmanywhere.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/swmmanywhere/prepare_data.py b/swmmanywhere/prepare_data.py index 4338f709..ced87610 100644 --- a/swmmanywhere/prepare_data.py +++ b/swmmanywhere/prepare_data.py @@ -46,7 +46,7 @@ def get_country(x: float, geolocator = Nominatim(user_agent="get_iso") # Load ISO code mapping from YAML file - data = yaml_load(iso_path.open("r")) + data = yaml_load(iso_path.read_text()) # Get country ISO code from coordinates location = geolocator.reverse(f"{y}, {x}", exactly_one=True) diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index 8245eec4..3df599b1 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -64,7 +64,7 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: # Run downloads logger.info("Running downloads.") - api_keys = yaml_load(config['api_keys'].open('r')) + api_keys = yaml_load(config['api_keys'].read_text()) preprocessing.run_downloads(config['bbox'], addresses, api_keys, @@ -304,10 +304,10 @@ def load_config(config_path: Path, validation: bool = True): """ # Load the schema schema_fid = Path(__file__).parent / 'defs' / 'schema.yml' - schema = yaml_load(schema_fid.open('r')) + schema = yaml_load(schema_fid.read_text()) # Load the config - config = yaml_load(config_path.open('r')) + config = yaml_load(config_path.read_text()) if not validation: return config