This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit df0a7e4

Merge pull request #215 from delira-dev/register_logger
Register logger
mibaumgartner authored Oct 8, 2019
2 parents 47d2a14 + 84432a4 commit df0a7e4
Showing 5 changed files with 107 additions and 9 deletions.
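
In short, this PR threads a `train` flag through the trainer and predictor callback calls and registers each trainer's logging backend under the trainer's name, so that module-level `log` calls reach the backend without a handle on the trainer. A minimal sketch of the resulting usage, taken from the new test at the bottom of this commit (the message format is the same dict-based one the logging callback uses):

from delira.logging import log

# once a trainer has (re)initialized its logging, a bare log() call
# is routed to the backend registered under the trainer's name
log({"scalar": {"tag": "dummy", "scalar_value": 1234}})
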
delira/training/base_trainer.py: 21 changes (16 additions, 5 deletions)
@@ -13,6 +13,7 @@
 from .predictor import Predictor
 from ..data_loading import Augmenter, DataManager
 from ..models import AbstractNetwork
+from ..logging import register_logger, make_logger
 
 logger = logging.getLogger(__name__)
 
@@ -322,8 +323,11 @@ def _at_iter_begin(self, iter_num, epoch=0, **kwargs):
         """
         for cb in self._callbacks:
             self._update_state(cb.at_iter_begin(
-                self, iter_num=iter_num, curr_epoch=epoch,
-                global_iter_num=self._global_iter_num, **kwargs,
+                self, iter_num=iter_num,
+                curr_epoch=epoch,
+                global_iter_num=self._global_iter_num,
+                train=True,
+                **kwargs,
             ))
 
     def _at_iter_end(self, iter_num, data_dict, metrics, epoch=0, **kwargs):
@@ -347,9 +351,12 @@ def _at_iter_end(self, iter_num, data_dict, metrics, epoch=0, **kwargs):
 
         for cb in self._callbacks:
             self._update_state(cb.at_iter_end(
-                self, iter_num=iter_num, data_dict=data_dict,
-                metrics=metrics, curr_epoch=epoch,
+                self, iter_num=iter_num,
+                data_dict=data_dict,
+                metrics=metrics,
+                curr_epoch=epoch,
                 global_iter_num=self._global_iter_num,
+                train=True,
                 **kwargs,
             ))
 
@@ -833,12 +840,16 @@ def _reinitialize_logging(self, logging_type, logging_kwargs: dict,
 
         level = _logging_kwargs.pop("level")
 
+        logger = backend_cls(_logging_kwargs)
+
         self.register_callback(
             logging_callback_cls(
-                backend_cls(logging_kwargs), level=level,
+                logger, level=level,
                 logging_frequencies=logging_frequencies,
                 reduce_types=reduce_types))
 
+        register_logger(self._callbacks[-1]._logger, self.name)
+
     @staticmethod
     def _search_for_prev_state(path, extensions=None):
         """
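
The `_reinitialize_logging` hunk above now builds the backend once from the level-stripped kwargs, hands it to the logging callback, and publishes the callback's logger under the trainer's name via `register_logger`. A hedged, self-contained sketch of that publish-then-log pattern; `DummyLogger` is an illustrative stand-in for a real backend, and the sketch assumes the module-level `log` simply forwards the message dict to a registered logger's `log` method, as the callback does:

from delira.logging import register_logger, log


class DummyLogger:
    """Illustrative stand-in for a delira logging backend."""

    def log(self, msg: dict):
        print("received:", msg)


# publish the logger under a name, as the trainer now does with self.name
register_logger(DummyLogger(), "my_trainer")

# any code in the process can now emit without a trainer handle
log({"scalar": {"tag": "loss", "scalar_value": 0.25}})
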
delira/training/callbacks/abstract_callback.py: 5 changes (3 additions, 2 deletions)
@@ -20,7 +20,7 @@ def __init__(self, *args, **kwargs):
             keyword arguments
 
         """
-        pass
+        super().__init__(*args, **kwargs)
 
     def at_epoch_begin(self, trainer, *args, **kwargs):
         """
@@ -124,7 +124,7 @@ def at_iter_begin(self, trainer, *args, **kwargs):
         Notes
         -----
         The predictor calls the callbacks with the following additional
-        arguments: `iter_num`(int)
+        arguments: `iter_num`(int), `train`(bool)
         The basetrainer adds following arguments (wrt the predictor):
         `curr_epoch`(int), `global_iter_num`(int)
 
@@ -153,6 +153,7 @@ def at_iter_end(self, trainer, *args, **kwargs):
         The predictor calls the callbacks with the following additional
         arguments: `iter_num`(int), `metrics`(dict),
         `data_dict`(dict, contains prediction and input data),
+        `train`(bool)
         The basetrainer adds following arguments (wrt the predictor):
         `curr_epoch`(int), `global_iter_num`(int)
 
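
Downstream, the new flag lets a single callback distinguish training iterations from prediction/validation iterations. A hedged sketch of a user callback consuming it (assuming `AbstractCallback` is importable from `delira.training.callbacks`, matching this file's location):

from delira.training.callbacks import AbstractCallback


class PhaseAwareCallback(AbstractCallback):
    """Minimal sketch: reacts to the ``train`` flag added in this commit."""

    def at_iter_end(self, trainer, iter_num=None, train=False, **kwargs):
        phase = "train" if train else "predict/val"
        print("iter %s finished in %s phase" % (iter_num, phase))
        return {}  # the trainer merges returned dicts into its state
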
delira/training/callbacks/logging_callback.py: 14 changes (12 additions, 2 deletions)
@@ -50,7 +50,8 @@ def __init__(self, backend: BaseBackend, max_queue_size: int = None,
             logging_frequencies=logging_frequencies,
             reduce_types=reduce_types, level=level)
 
-    def at_iter_end(self, trainer, iter_num=None, data_dict=None, **kwargs):
+    def at_iter_end(self, trainer, iter_num=None, data_dict=None, train=False,
+                    **kwargs):
         """
         Function logging the metrics at the end of each iteration
 
@@ -63,6 +64,8 @@ def at_iter_end(self, trainer, iter_num=None, data_dict=None, train=False,
             (unused in this callback)
         data_dict : dict
             the current data dict (including predictions)
+        train: bool
+            signals if callback is called by trainer or predictor
         **kwargs :
             additional keyword arguments
 
@@ -76,7 +79,14 @@ def at_iter_end(self, trainer, iter_num=None, data_dict=None, train=False,
         global_step = kwargs.get("global_iter_num", None)
 
         for k, v in metrics.items():
-            self._logger.log({"scalar": {"tag": k, "scalar_value": v,
+            self._logger.log({"scalar": {"tag": self.create_tag(k, train),
+                                         "scalar_value": v,
                                          "global_step": global_step}})
 
         return {}
+
+    @staticmethod
+    def create_tag(tag: str, train: bool):
+        if train:
+            tag = tag + "_val"
+        return tag
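
For reference, the new helper only decorates the tag string; the committed behavior is illustrated below, with `LoggingCallback` standing in for the callback class defined in this file (its name is truncated in the diff header):

# the "_val" suffix is attached when train is True
LoggingCallback.create_tag("loss", True)   # -> "loss_val"
LoggingCallback.create_tag("loss", False)  # -> "loss"
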
delira/training/predictor.py: 2 changes (2 additions, 0 deletions)
@@ -177,6 +177,7 @@ def _at_iter_begin(self, iter_num, **kwargs):
         for cb in self._callbacks:
             return_dict.update(cb.at_iter_begin(self,
                                                 iter_num=iter_num,
+                                                train=False,
                                                 **kwargs))
 
         return return_dict
@@ -208,6 +209,7 @@ def _at_iter_end(self, iter_num, data_dict, metrics, **kwargs):
                                              iter_num=iter_num,
                                              data_dict=data_dict,
                                              metrics=metrics,
+                                             train=False,
                                              **kwargs))
 
         return return_dict
tests/logging/test_logging_outside_trainer.py: 74 changes (74 additions, 0 deletions; new file)
@@ -0,0 +1,74 @@
import unittest
from delira.logging import log
from delira.training import BaseNetworkTrainer
from delira.models import AbstractNetwork
import os
from tests.utils import check_for_tf_graph_backend

try:
import tensorflow as tf
except ImportError:
tf = None


class LoggingOutsideTrainerTestCase(unittest.TestCase):

@unittest.skipUnless(check_for_tf_graph_backend(),
"TF Backend not installed")
def test_logging_freq(self):
save_path = os.path.abspath("./logs")
config = {
"num_epochs": 2,
"losses": {},
"optimizer_cls": None,
"optimizer_params": {"learning_rate": 1e-3},
"metrics": {},
"lr_scheduler_cls": None,
"lr_scheduler_params": {}
}
trainer = BaseNetworkTrainer(
AbstractNetwork(),
save_path,
**config,
gpu_ids=[],
save_freq=1,
optim_fn=None,
key_mapping={},
logging_type="tensorboardx",
logging_kwargs={
'logdir': save_path
})

trainer._setup(
AbstractNetwork(),
lr_scheduler_cls=None,
lr_scheduler_params={},
gpu_ids=[],
key_mapping={},
convert_batch_to_npy_fn=None,
prepare_batch_fn=None,
callbacks=[])

tag = 'dummy'

log({"scalar": {"scalar_value": 1234, "tag": tag}})

file = [os.path.join(save_path, x)
for x in os.listdir(save_path)
if os.path.isfile(os.path.join(save_path, x))][0]

ret_val = False
if tf is not None:
for e in tf.train.summary_iterator(file):
for v in e.summary.value:
if v.tag == tag:
ret_val = True
break
if ret_val:
break

self.assertTrue(ret_val)


if __name__ == '__main__':
unittest.main()
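
The test can be run on its own with `python -m unittest tests.logging.test_logging_outside_trainer` from the repository root. It requires the TF graph backend; note that `tf.train.summary_iterator` is the TensorFlow 1.x API (TensorFlow 2 exposes it as `tf.compat.v1.train.summary_iterator`).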
