Merge pull request #145 from justusschock/bug_fixes
Bug fixes
justusschock authored Jun 16, 2019
2 parents 3c16d7e + 600bd04 commit 71368ce
Showing 9 changed files with 145 additions and 83 deletions.
68 changes: 32 additions & 36 deletions delira/data_loading/data_loader.py
@@ -1,8 +1,11 @@
import numpy as np
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from queue import Empty
import logging

logger = logging.getLogger(__name__)

from .dataset import AbstractDataset
from .sampler import AbstractSampler, SequentialSampler


class BaseDataLoader(SlimDataLoaderBase):
@@ -12,8 +15,8 @@ class BaseDataLoader(SlimDataLoaderBase):
"""

def __init__(self, dataset: AbstractDataset,
batch_size=1, num_batches=None, seed=1,
sampler=None):
sampler_queues: list,
batch_size=1, num_batches=None, seed=1):
"""
Parameters
@@ -22,13 +25,13 @@ def __init__(self, dataset: AbstractDataset,
dataset to perform sample loading
batch_size : int
number of samples per batch
sampler_queues : list of :class:`multiprocessing.Queue`
the queues the sample indices to load will be put to.
Necessary for interprocess communication
num_batches : int
number of batches to load
seed : int
seed for Random Number Generator
sampler : AbstractSampler or None
class defining the sampling strategy;
if None: SequentialSampler will be used
Raises
------
@@ -45,24 +48,16 @@ class defining the sampling strategy;
# store dataset in self._data
super().__init__(dataset, batch_size)

assert isinstance(sampler, AbstractSampler) or sampler is None, \
"Sampler must be instance of subclass of AbstractSampler of None"

if sampler is None:
sampler = SequentialSampler(list(range(len(dataset))))

self.sampler = sampler
self.sampler_queues = sampler_queues

self.n_samples = len(sampler)
self.n_samples = len(dataset)
if num_batches is None:
num_batches = len(sampler) // batch_size
num_batches = len(dataset) // batch_size

self.num_batches = num_batches
self._seed = seed
np.random.seed(seed)

self._batches_generated = 0

def generate_train_batch(self):
"""
Generate Indices which behavior based on self.sampling gets data based
@@ -79,30 +74,31 @@ def generate_train_batch(self):
If the maximum number of batches has been generated
"""

if self._batches_generated >= self.num_batches:
raise StopIteration
else:
self._batches_generated += 1

idxs = self.sampler(self.batch_size)
idxs = None
sampler_queue = self.sampler_queues[self.thread_id]
while idxs is None:
try:
idxs = sampler_queue.get(timeout=0.2)

result = [self._get_sample(_idx) for _idx in idxs]
result = [self._get_sample(_idx) for _idx in idxs]

result_dict = {}
result_dict = {}

# concatenate dict entities by keys
for _result_dict in result:
for key, val in _result_dict.items():
if key in result_dict.keys():
result_dict[key].append(val)
else:
result_dict[key] = [val]
# concatenate dict entities by keys
for _result_dict in result:
for key, val in _result_dict.items():
if key in result_dict.keys():
result_dict[key].append(val)
else:
result_dict[key] = [val]

# convert list to numpy arrays
for key, val_list in result_dict.items():
result_dict[key] = np.asarray(val_list)
# convert list to numpy arrays
for key, val_list in result_dict.items():
result_dict[key] = np.asarray(val_list)

return result_dict
return result_dict
except Empty:
pass

def _get_sample(self, index):
"""
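To make the new control flow in generate_train_batch easier to follow, here is a minimal, self-contained sketch of the consumer side of the pattern this file now implements: a per-process loader blocks on its assigned multiprocessing queue (with the same short timeout as in the diff) until the parent process hands it a batch of indices, then loads and collates the samples. ToyDataLoader and dummy_dataset are illustrative names for this sketch, not part of delira.

import numpy as np
from multiprocessing import Queue
from queue import Empty


class ToyDataLoader:
    """Illustrative stand-in for BaseDataLoader: one index queue per worker."""

    def __init__(self, dataset, sampler_queues, batch_size=1, thread_id=0):
        self._data = dataset
        self.sampler_queues = sampler_queues
        self.batch_size = batch_size
        self.thread_id = thread_id  # assigned to each worker process

    def generate_train_batch(self):
        idxs = None
        sampler_queue = self.sampler_queues[self.thread_id]
        # poll the queue with a short timeout so the worker never blocks forever
        while idxs is None:
            try:
                idxs = sampler_queue.get(timeout=0.2)
            except Empty:
                continue

        # load the samples belonging to the received indices
        result = [self._data[_idx] for _idx in idxs]

        # collate the per-sample dicts into one dict of lists ...
        result_dict = {}
        for _result_dict in result:
            for key, val in _result_dict.items():
                result_dict.setdefault(key, []).append(val)

        # ... and convert each list to a numpy array
        return {key: np.asarray(val) for key, val in result_dict.items()}


# usage: the parent process puts an index list into the worker's queue
dummy_dataset = [{"data": np.random.rand(3), "label": i % 2} for i in range(10)]
queue = Queue()
queue.put([0, 1, 2, 3])
loader = ToyDataLoader(dummy_dataset, sampler_queues=[queue], batch_size=4)
batch = loader.generate_train_batch()
print(batch["data"].shape, batch["label"])  # (4, 3) [0 1 0 1]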
94 changes: 69 additions & 25 deletions delira/data_loading/data_manager.py
@@ -5,6 +5,9 @@
SingleThreadedAugmenter, SlimDataLoaderBase
from batchgenerators.transforms import AbstractTransform

from multiprocessing import Queue
from queue import Full

from delira import get_current_debug_mode
from .data_loader import BaseDataLoader
from .dataset import AbstractDataset, BaseCacheDataset, BaseLazyDataset
@@ -24,8 +27,8 @@ class Augmenter(object):
"""

def __init__(self, data_loader: BaseDataLoader, transforms,
n_process_augmentation=None, num_cached_per_queue=2,
seeds=None, **kwargs):
n_process_augmentation, sampler, sampler_queues: list,
num_cached_per_queue=2, seeds=None, **kwargs):
"""
Parameters
@@ -37,6 +40,12 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
n_process_augmentation : int
the number of processes to use for augmentation (only necessary if
not in debug mode)
sampler : :class:`AbstractSampler`
the sampler to use; must be used here instead of inside the
dataloader to avoid duplications and oversampling due to
multiprocessing
sampler_queues : list of :class:`multiprocessing.Queue`
queues to pass the sample indices to the actual dataloader
num_cached_per_queue : int
the number of samples to cache per queue (only necessary if not in
debug mode)
@@ -45,6 +54,9 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
**kwargs :
additional keyword arguments
"""

self._batchsize = data_loader.batch_size

# don't use multiprocessing in debug mode
if get_current_debug_mode():
augmenter = SingleThreadedAugmenter(data_loader, transforms)
@@ -72,42 +84,59 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
**kwargs)

self._augmenter = augmenter
self._sampler = sampler
self._sampler_queues = sampler_queues
self._queue_id = 0

@property
def __iter__(self):
"""
Property returning the augmenters ``__iter__``
Function returning an iterator
Returns
-------
Callable
the augmenters ``__iter__``
Augmenter
self
"""
return self._augmenter.__iter__
return self

def _next_queue(self):
idx = self._queue_id
self._queue_id = (self._queue_id + 1) % len(self._sampler_queues)
return self._sampler_queues[idx]

@property
def __next__(self):
"""
Property returning the augmenters ``__next__``
Function to sample and load the next batch
Returns
-------
Callable
the augmenters ``__next__``
dict
the next batch
"""
return self._augmenter.__next__
idxs = self._sampler(self._batchsize)
queue = self._next_queue()

# don't wait forever. Release this after a short timeout and try again
# to avoid deadlock
while True:
try:
queue.put(idxs, timeout=0.2)
break
except Full:
continue

return next(self._augmenter)

@property
def next(self):
"""
Property returning the augmenters ``next``
Function to sample and load
Returns
-------
Callable
the augmenters ``next``
dict
the next batch
"""
return self._augmenter.next
return next(self)

@staticmethod
def __identity_fn(*args, **kwargs):
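The producer side added in the hunk above can be sketched in isolation as follows: the sampler now lives in exactly one place (the Augmenter in the parent process), and every drawn batch of indices is pushed into one of the worker queues in round-robin fashion, retrying on a full queue instead of blocking indefinitely. Sampling centrally like this is what avoids the duplication and oversampling mentioned in the docstring. ToySampler and ToyIndexProducer are made-up names for this sketch, not delira API.

from multiprocessing import Queue
from queue import Full


class ToySampler:
    """Illustrative sequential sampler: draws the next n indices."""

    def __init__(self, n_samples):
        self._indices = list(range(n_samples))
        self._pos = 0

    def __call__(self, n):
        idxs = self._indices[self._pos:self._pos + n]
        self._pos = (self._pos + n) % len(self._indices)
        return idxs


class ToyIndexProducer:
    """Sketch of the Augmenter-side hand-off: sample once, distribute round-robin."""

    def __init__(self, sampler, sampler_queues, batchsize):
        self._sampler = sampler
        self._sampler_queues = sampler_queues
        self._batchsize = batchsize
        self._queue_id = 0

    def _next_queue(self):
        queue = self._sampler_queues[self._queue_id]
        self._queue_id = (self._queue_id + 1) % len(self._sampler_queues)
        return queue

    def push_next_indices(self):
        idxs = self._sampler(self._batchsize)
        queue = self._next_queue()
        # don't wait forever on a full queue; retry so the hand-off cannot deadlock
        while True:
            try:
                queue.put(idxs, timeout=0.2)
                break
            except Full:
                continue
        return idxs


# usage: two worker queues, batches alternate between them
queues = [Queue(), Queue()]
producer = ToyIndexProducer(ToySampler(8), queues, batchsize=4)
producer.push_next_indices()  # goes to queues[0]
producer.push_next_indices()  # goes to queues[1]
print(queues[0].get(), queues[1].get())  # [0, 1, 2, 3] [4, 5, 6, 7]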
@@ -179,7 +208,6 @@ def restart(self):
"""
return self._fn_checker("restart")

@property
def _finish(self):
"""
Property to provide uniform API of ``_finish``
@@ -190,7 +218,12 @@ def _finish(self):
either the augmenter's ``_finish`` method (if available) or
``__identity_fn`` (if not available)
"""
return self._fn_checker("_finish")
ret_val = self._fn_checker("_finish")()
for queue in self._sampler_queues:
queue.close()
queue.join_thread()

return ret_val

@property
def num_batches(self):
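The _finish change in this hunk also determines how the queues are torn down once training ends: each queue is closed and its feeder thread joined. A minimal illustration of that shutdown order (an assumed sketch, not delira code):

from multiprocessing import Queue

queues = [Queue() for _ in range(2)]
# ... queues are used for the index hand-off while the augmenter runs ...
for q in queues:
    q.close()        # signal that nothing more will be put into this queue
    q.join_thread()  # wait for the background feeder thread to flush and exit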
@@ -230,6 +263,7 @@ def __del__(self):
Function defining what to do, if object should be deleted
"""
self._finish()
del self._augmenter


@@ -360,15 +394,23 @@ def get_batchgen(self, seed=1):
"""
assert self.n_batches > 0

data_loader = self.data_loader_cls(self.dataset,
batch_size=self.batch_size,
num_batches=self.n_batches,
seed=seed,
sampler=self.sampler
)
sampler_queues = []

for idx in range(self.n_process_augmentation):
sampler_queues.append(Queue())

data_loader = self.data_loader_cls(
self.dataset,
batch_size=self.batch_size,
num_batches=self.n_batches,
seed=seed,
sampler_queues=sampler_queues
)

return Augmenter(data_loader, self.transforms,
self.n_process_augmentation,
sampler=self.sampler,
sampler_queues=sampler_queues,
num_cached_per_queue=2,
seeds=self.n_process_augmentation * [seed])

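The wiring in the new get_batchgen above boils down to one multiprocessing queue per augmentation process, with the same queue list handed to the worker-side data loader and to the parent-side augmenter. A rough sketch, reusing the hypothetical ToyDataLoader and ToyIndexProducer classes from the earlier snippets:

from multiprocessing import Queue


def make_batchgen(dataset, sampler, batch_size, n_process_augmentation):
    # one queue per augmentation worker, so each worker has its own index feed
    sampler_queues = [Queue() for _ in range(n_process_augmentation)]

    # the loader (run inside the worker processes) only consumes indices
    loader = ToyDataLoader(dataset, sampler_queues=sampler_queues,
                           batch_size=batch_size)

    # the producer (run in the parent process) owns the sampler and the queues
    producer = ToyIndexProducer(sampler, sampler_queues, batch_size)
    return loader, producer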
@@ -528,6 +570,8 @@ def n_process_augmentation(self):
number of augmentation processes
"""

if get_current_debug_mode():
return 1
return self._n_process_augmentation

@n_process_augmentation.setter
2 changes: 1 addition & 1 deletion delira/training/base_trainer.py
@@ -454,7 +454,7 @@ def reduce_fn(batch):
if "val_" + val_score_key not in total_metrics:
logger.warning(
"val_score_key '%s' not a valid key for \
validation metrics ")
validation metrics" % str(val_score_key))

new_val_score = best_val_score

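The fix above interpolates the offending key into the warning, which the original format string never did. A common alternative (shown here as an assumption, not what the commit does) is to let the logging module format the message lazily:

import logging

logger = logging.getLogger(__name__)
val_score_key = "val_loss"  # hypothetical key missing from the metrics

# the message is only formatted if the warning is actually emitted
logger.warning("val_score_key '%s' not a valid key for validation metrics",
               val_score_key)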
7 changes: 4 additions & 3 deletions delira/training/experiment.py
@@ -5,6 +5,7 @@
import os
from datetime import datetime
from functools import partial
import copy

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, \
@@ -547,8 +548,8 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
train_data = data.get_subset(train_idxs)
test_data = data.get_subset(test_idxs)

train_data.update_state_from_dict(train_kwargs)
test_data.update_state_from_dict(test_kwargs)
train_data.update_state_from_dict(copy.deepcopy(train_kwargs))
test_data.update_state_from_dict(copy.deepcopy(test_kwargs))

val_data = None
if val_split is not None:
@@ -572,7 +573,7 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
for _train_idxs, _val_idxs in _val_split.split(train_idxs,
train_labels):
val_data = train_data.get_subset(_val_idxs)
val_data.update_state_from_dict(test_kwargs)
val_data.update_state_from_dict(copy.deepcopy(test_kwargs))

train_data = train_data.get_subset(_train_idxs)

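The switch to copy.deepcopy above prevents the folds from sharing, and mutating, the same kwargs objects. A small self-contained illustration of the failure mode and the fix; DummyManager and the kwargs contents are made up for this example and are not delira code:

import copy


class DummyManager:
    """Minimal stand-in that, like a data manager, stores the dict it receives."""

    def __init__(self):
        self.state = {}

    def update_state_from_dict(self, new_state):
        self.state.update(new_state)
        # imagine the manager later mutates its state in place, e.g.:
        self.state.setdefault("transforms", []).append("fold_specific_transform")


train_kwargs = {"transforms": []}

# without deepcopy: every fold appends to the *same* list object
for _ in range(3):
    DummyManager().update_state_from_dict(train_kwargs)
print(len(train_kwargs["transforms"]))  # 3 -- state leaked across folds

# with deepcopy: each fold works on its own copy and the template stays clean
train_kwargs = {"transforms": []}
for _ in range(3):
    DummyManager().update_state_from_dict(copy.deepcopy(train_kwargs))
print(len(train_kwargs["transforms"]))  # 0 -- the original kwargs are untouched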