Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

remove timestamp from save path #200

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
17 changes: 10 additions & 7 deletions delira/training/base_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import pickle
import os
from datetime import datetime
import warnings

import copy
Expand All @@ -17,6 +16,7 @@
from delira.models import AbstractNetwork

from delira.utils import DeliraConfig
from delira.training.utils import generate_save_path
from delira.training.base_trainer import BaseNetworkTrainer
from delira.training.predictor import Predictor

Expand Down Expand Up @@ -50,6 +50,7 @@ def __init__(self,
checkpoint_freq=1,
trainer_cls=BaseNetworkTrainer,
predictor_cls=Predictor,
unique_name=True,
**kwargs):
"""

Expand Down Expand Up @@ -87,6 +88,9 @@ def __init__(self,
the trainer class to use for training the model
predictor_cls : subclass of :class:`Predictor`
the predictor class to use for testing the model
unique_name : boolean
if the name is not unique an experiment with the same
name will be continued
**kwargs :
additional keyword arguments

Expand All @@ -109,12 +113,11 @@ def __init__(self,
if save_path is None:
save_path = os.path.abspath(".")

self.save_path = os.path.join(save_path, name,
str(datetime.now().strftime(
"%y-%m-%d_%H-%M-%S")))

if os.path.isdir(self.save_path):
logger.warning("Save Path %s already exists")
if unique_name:
self.save_path = generate_save_path(
os.path.join(save_path, name))
else:
self.save_path = os.path.join(save_path, name)

os.makedirs(self.save_path, exist_ok=True)

Expand Down
18 changes: 18 additions & 0 deletions delira/training/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections
import numpy as np
import os
from datetime import datetime


def recursively_convert_elements(element, check_type, conversion_fn):
Expand Down Expand Up @@ -98,3 +100,19 @@ def convert_to_numpy_identity(*args, **kwargs):
_correct_zero_shape)

return args, kwargs


def generate_save_path(save_path):
i = 0
gedoensmax marked this conversation as resolved.
Show resolved Hide resolved
now = datetime.now()
date_str = '{}_{:02d}_{:02d}_'.format(
now.year, now.month, now.day)
while True:
new_path = os.path.join(save_path, '{}{:03d}'.format(date_str, i))
i += 1
if not os.path.isdir(new_path):
break
if i:
print('Save path is a duplicate and got changed to {}'
.format(new_path))
return new_path