diff --git a/CHANGELOG.md b/CHANGELOG.md index ebb578ae2..90b17cb1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [3.1.10] + +### Fixed + +- When loading parameters, SockeyeModel now ignores false positive missing parameters for traced modules. These modules use the same parameters as their original non-traced versions. + ## [3.1.9] ### Changed diff --git a/sockeye/__init__.py b/sockeye/__init__.py index e0ef9df6c..8bed14975 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '3.1.9' +__version__ = '3.1.10' diff --git a/sockeye/model.py b/sockeye/model.py index 148d0ae6e..8cb7d2bed 100644 --- a/sockeye/model.py +++ b/sockeye/model.py @@ -362,6 +362,11 @@ def load_parameters(self, # Earlier versions of Sockeye may have saved parameters for traced # modules. These parameters can be safely ignored. unexpected = [key for key in unexpected if 'traced' not in key] + # We also ignore cases where traced modules exist and appear to be + # missing parameters. These modules actually use the same parameters as + # their original non-traced versions so there are no separate parameters + # to load. + missing = [key for key in missing if 'traced' not in key] if not allow_missing: utils.check_condition(not missing, f"missing keys: {missing}") if not ignore_extra: diff --git a/test/integration/test_seq_copy_int.py b/test/integration/test_seq_copy_int.py index a7e252d96..c698510c4 100644 --- a/test/integration/test_seq_copy_int.py +++ b/test/integration/test_seq_copy_int.py @@ -231,7 +231,10 @@ def _test_parameter_averaging(model_path: str): def _test_checkpoint_decoder(dev_source_path: str, dev_target_path: str, model_path: str): """ - Runs checkpoint decoder on 10% of the dev data and checks whether metric keys are present in the result dict. + Runs checkpoint decoder on 10% of the dev data and checks whether metric + keys are present in the result dict. Also checks that we can reload model + parameters after running the checkpoint decoder (case when using the + plateau-reduce scheduler). """ with open(dev_source_path) as dev_fd: num_dev_sent = sum(1 for _ in dev_fd) @@ -254,3 +257,5 @@ def _test_checkpoint_decoder(dev_source_path: str, dev_target_path: str, model_p assert 'bleu' in cp_metrics assert 'chrf' in cp_metrics assert 'decode-walltime' in cp_metrics + + model.load_parameters(os.path.join(model_path, C.PARAMS_BEST_NAME), device=pt.device('cpu'))