-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #45 from HealthML/feature/improve_metric_calculations
Feature/improve metric calculations
- Loading branch information
Showing
19 changed files
with
693 additions
and
271 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
"""Module defining hooks that each dataset class should implement""" | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Iterable, List, Union | ||
|
||
|
||
class DatasetHooks(ABC): | ||
""" | ||
Class that defines hooks that should be implemented by each dataset class. | ||
""" | ||
|
||
@abstractmethod | ||
def image_ids(self) -> Iterable[str]: | ||
""" | ||
Returns: | ||
List of all image IDs included in the dataset. | ||
""" | ||
|
||
@abstractmethod | ||
def slices_per_image(self, **kwargs) -> Union[int, List[int]]: | ||
""" | ||
Args: | ||
kwargs: Dataset specific parameters. | ||
Returns: | ||
Union[int, List[int]]: Number of slices that each image of the dataset contains. If a single integer | ||
value is provided, it is assumed that all images of the dataset have the same number of slices. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Module to load and batch Pascal VOC segmentation dataset""" | ||
|
||
from typing import Any, Iterable, Tuple | ||
|
||
from torch.utils.data import Dataset | ||
from torchvision import datasets | ||
|
||
from .dataset_hooks import DatasetHooks | ||
|
||
|
||
class PascalVOCDataset(Dataset, DatasetHooks): | ||
""" | ||
Wrapper class for the VOCSegmentation dataset class from the torchvision package. | ||
Args: | ||
root (str): Root directory of the VOC Dataset. | ||
kwargs: Additional keyword arguments as defined in the VOCSegmentation class from the torchvision package. | ||
""" | ||
|
||
def __init__(self, root: str, **kwargs): | ||
|
||
self.pascal_voc_datset = datasets.VOCSegmentation(root, **kwargs) | ||
|
||
def __getitem__(self, index: int) -> Tuple[Any, Any]: | ||
return self.pascal_voc_datset.__getitem__(index) | ||
|
||
def __len__(self) -> int: | ||
""" | ||
Returns: | ||
int: Length of the dataset. | ||
""" | ||
|
||
return self.pascal_voc_datset.__len__() | ||
|
||
def image_ids(self) -> Iterable[str]: | ||
""" | ||
Returns: | ||
List of all image IDs included in the dataset. | ||
""" | ||
|
||
return range(self.__len__) | ||
|
||
def slices_per_image(self, **kwargs) -> int: | ||
""" | ||
Returns: | ||
int: Number of slices that each image of the dataset contains. | ||
""" | ||
|
||
return 1 |
Oops, something went wrong.