Merge pull request #45 from HealthML/feature/improve_metric_calculations
Feature/improve metric calculations
josafatburmeister authored Dec 18, 2021
2 parents acacb80 + 1e4a05f commit afd18b9
Showing 19 changed files with 693 additions and 271 deletions.
8 changes: 7 additions & 1 deletion brats_example_config.json
@@ -8,7 +8,12 @@
"input_shape": [
240,
240
]
],
"model_selection_criterion": "mean_dice_score_0.5",
"train_metrics": ["dice_score"],
"train_metric_confidence_levels": [0.25, 0.5, 0.75],
"test_metrics": ["dice_score", "sensitivity", "specificity", "hausdorff95"],
"test_metric_confidence_levels": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
},
"dataset_config": {
"dataset": "brats",
@@ -17,6 +22,7 @@
"pin_memory": true
},
"wandb_project_name": "active-segmentation-tests",
"checkpoint_dir": "/dhc/groups/mpws2021cl1/Models",
"strategy": "base",
"experiment_name": "test-experiment",
"experiment_tags": [],
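The keys added above configure which metrics are computed during training and testing, the confidence levels they are evaluated at, and which validation metric drives model selection. Below is a minimal sketch of how a training script might read these values; the parent key "model_config" is an assumption, since the enclosing section name is elided in this diff.

import json

# Load the example config shown above.
with open("brats_example_config.json") as config_file:
    config = json.load(config_file)

# "model_config" is a guess at the elided parent key of these settings.
criterion = config["model_config"]["model_selection_criterion"]  # "mean_dice_score_0.5"

# Mirrors the monitoring-mode rule added in src/active_learning.py below:
# loss-like criteria are minimized, score-like criteria are maximized.
mode = "min" if "loss" in criterion else "max"
print(criterion, mode)  # mean_dice_score_0.5 max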
23 changes: 22 additions & 1 deletion src/active_learning.py
@@ -1,10 +1,11 @@
""" Module containing the active learning pipeline """

from typing import Iterable, Union
from typing import Iterable, Optional, Union

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

from query_strategies import QueryStrategy
from datasets import ActiveLearningDataModule
@@ -20,6 +21,7 @@ class ActiveLearningPipeline:
strategy (QueryStrategy): An active learning strategy to query for new labels.
epochs (int): The number of epochs the model should be trained.
gpus (int): Number of GPUS to use for model training.
checkpoint_dir (str, optional): Directory where the model checkpoints are to be saved.
early_stopping (bool, optional): Enable/disable early stopping when the model
is no longer learning (default = False).
logger: A logger object as defined by PyTorch Lightning.
@@ -35,9 +37,11 @@ def __init__(
strategy: QueryStrategy,
epochs: int,
gpus: int,
checkpoint_dir: Optional[str] = None,
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
early_stopping: bool = False,
lr_scheduler: Optional[str] = None,
model_selection_criterion: str = "loss",
) -> None:

self.data_module = data_module
@@ -51,6 +55,20 @@ def __init__(
if early_stopping:
callbacks.append(EarlyStopping("validation/loss"))

monitoring_mode = "min" if "loss" in model_selection_criterion else "max"

self.checkpoint_callback = ModelCheckpoint(
dirpath=checkpoint_dir,
filename="best_model_epoch_{epoch}",
auto_insert_metric_name=False,
monitor=f"val/{model_selection_criterion}",
mode=monitoring_mode,
save_last=True,
every_n_epochs=1,
save_on_train_epoch_end=False,
)
callbacks.append(self.checkpoint_callback)

self.model_trainer = Trainer(
deterministic=True,
profiler="simple",
@@ -77,3 +95,6 @@ def run(self) -> None:
self.data_module.label_items(items_to_label)

self.model_trainer.fit(self.model, self.data_module)

# compute metrics for the best model on the validation set
self.model_trainer.validate()
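The net effect of these changes: a ModelCheckpoint callback now saves the best model according to a configurable validation metric, and run() finishes by validating so the selected model's metrics are logged. A usage sketch under stated assumptions — data_module, model, and strategy are placeholders built elsewhere, and the model/data_module keyword names are inferred from the attribute assignments above rather than from the visible signature.

from active_learning import ActiveLearningPipeline

# data_module, model, and strategy are assumed to be constructed elsewhere
# (e.g. a BraTSDataModule, a segmentation model, and a QueryStrategy).
pipeline = ActiveLearningPipeline(
    data_module=data_module,
    model=model,
    strategy=strategy,
    epochs=50,
    gpus=1,
    checkpoint_dir="/tmp/checkpoints",  # placeholder path
    model_selection_criterion="mean_dice_score_0.5",
)
# The checkpoint callback monitors "val/mean_dice_score_0.5" with mode="max"
# and keeps both the best and the last checkpoint (save_last=True).
pipeline.run()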
1 change: 1 addition & 0 deletions src/datasets/__init__.py
@@ -3,3 +3,4 @@
from datasets.brats_data_module import BraTSDataModule
from datasets.pascal_voc_data_module import PascalVOCDataModule
from datasets.brats_dataset import BraTSDataset
from datasets.pascal_voc_dataset import PascalVOCDataset
2 changes: 2 additions & 0 deletions src/datasets/brats_data_module.py
@@ -89,11 +89,13 @@ def __init__(

def label_items(self, ids: List[str], labels: Optional[Any] = None) -> None:
"""TBD"""

# ToDo: implement labeling logic
return None

def _create_training_set(self) -> Optional[Dataset]:
"""Creates a training dataset."""

train_image_paths, train_annotation_paths = BraTSDataModule.discover_paths(
os.path.join(self.data_folder, "train")
)
21 changes: 19 additions & 2 deletions src/datasets/brats_dataset.py
@@ -1,5 +1,5 @@
""" Module to load and batch brats dataset """
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Iterable, List, Optional, Tuple
import math
from multiprocessing import Manager
import os
@@ -9,9 +9,10 @@
import torch
from torch.utils.data import IterableDataset

from .dataset_hooks import DatasetHooks

# pylint: disable=too-many-instance-attributes,abstract-method
class BraTSDataset(IterableDataset):
class BraTSDataset(IterableDataset, DatasetHooks):
"""
The BraTS dataset is published in the course of the annual Multimodal Brain Tumor Segmentation Challenge (BraTS)
held since 2012. It is composed of 3T multimodal MRI scans from patients affected by glioblastoma or lower grade
@@ -322,3 +323,19 @@ def remove_image(self, image_path: str, annotation_path: str) -> None:
self.num_images -= 1
else:
raise ValueError("Image does not belong to this dataset.")

def image_ids(self) -> Iterable[str]:
"""
Returns:
List of all image IDs included in the dataset.
"""

return [self.__get_case_id(image_path) for image_path in self.image_paths]

def slices_per_image(self, **kwargs) -> int:
"""
Returns:
int: Number of slices that each image of the dataset contains.
"""

return BraTSDataset.IMAGE_DIMENSIONS[0]
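A brief sketch of the two new hooks on a BraTSDataset instance; the constructor arguments and the concrete value of IMAGE_DIMENSIONS are not visible in this diff, so both are placeholders.

# image_paths and annotation_paths are placeholders; the real constructor
# signature is elided in this diff.
dataset = BraTSDataset(image_paths, annotation_paths)

case_ids = list(dataset.image_ids())      # one case ID per image path
num_slices = dataset.slices_per_image()   # BraTSDataset.IMAGE_DIMENSIONS[0]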
27 changes: 27 additions & 0 deletions src/datasets/dataset_hooks.py
@@ -0,0 +1,27 @@
"""Module defining hooks that each dataset class should implement"""

from abc import ABC, abstractmethod
from typing import Iterable, List, Union


class DatasetHooks(ABC):
"""
Class that defines hooks that should be implemented by each dataset class.
"""

@abstractmethod
def image_ids(self) -> Iterable[str]:
"""
Returns:
List of all image IDs included in the dataset.
"""

@abstractmethod
def slices_per_image(self, **kwargs) -> Union[int, List[int]]:
"""
Args:
kwargs: Dataset specific parameters.
Returns:
Union[int, List[int]]: Number of slices that each image of the dataset contains. If a single integer
value is provided, it is assumed that all images of the dataset have the same number of slices.
"""
7 changes: 4 additions & 3 deletions src/datasets/pascal_voc_data_module.py
@@ -3,10 +3,11 @@
from typing import Any, List, Optional
import numpy as np
from torch.utils.data import Dataset, random_split
from torchvision import datasets, transforms
from torchvision import transforms
import torch

from datasets.data_module import ActiveLearningDataModule
from .pascal_voc_dataset import PascalVOCDataset


class PILMaskToTensor:
@@ -65,7 +66,7 @@ def label_items(self, ids: List[str], labels: Optional[Any] = None) -> None:

def _create_training_set(self) -> Optional[Dataset]:
"""Creates a training dataset."""
training_set = datasets.VOCSegmentation(
training_set = PascalVOCDataset(
self.data_folder,
year="2012",
image_set="train",
@@ -80,7 +81,7 @@ def _create_validation_set(self) -> Optional[Dataset]:

def _create_validation_set(self) -> Optional[Dataset]:
"""Creates a validation dataset."""
validation_set = datasets.VOCSegmentation(
validation_set = PascalVOCDataset(
self.data_folder,
year="2012",
image_set="val",
49 changes: 49 additions & 0 deletions src/datasets/pascal_voc_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Module to load and batch Pascal VOC segmentation dataset"""

from typing import Any, Iterable, Tuple

from torch.utils.data import Dataset
from torchvision import datasets

from .dataset_hooks import DatasetHooks


class PascalVOCDataset(Dataset, DatasetHooks):
"""
Wrapper class for the VOCSegmentation dataset class from the torchvision package.
Args:
root (str): Root directory of the VOC Dataset.
kwargs: Additional keyword arguments as defined in the VOCSegmentation class from the torchvision package.
"""

def __init__(self, root: str, **kwargs):

self.pascal_voc_dataset = datasets.VOCSegmentation(root, **kwargs)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
return self.pascal_voc_dataset[index]

def __len__(self) -> int:
"""
Returns:
int: Length of the dataset.
"""

return len(self.pascal_voc_dataset)

def image_ids(self) -> Iterable[str]:
"""
Returns:
List of all image IDs included in the dataset.
"""

# The wrapped VOC dataset has no explicit image IDs, so dataset indices serve as IDs.
return [str(index) for index in range(len(self))]

def slices_per_image(self, **kwargs) -> int:
"""
Returns:
int: Number of slices that each image of the dataset contains.
"""

return 1
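A hedged usage example for the wrapper; the root path is a placeholder, and the keyword arguments mirror those visible in pascal_voc_data_module.py above:

from datasets.pascal_voc_dataset import PascalVOCDataset

dataset = PascalVOCDataset(
    "/path/to/voc",  # placeholder root directory
    year="2012",
    image_set="val",
)
print(len(dataset))                # delegates to the wrapped VOCSegmentation
print(dataset.slices_per_image())  # 1 — Pascal VOC images are 2-D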