Merge pull request #141 from Deci-AI/feature/SG-34_regseg_transfer_learning
Feature/sg 34 regseg transfer learning
Showing 21 changed files with 1,886 additions and 331 deletions.
1,194 changes: 1,194 additions & 0 deletions
src/super_gradients/examples/SegmentationTransferLearning.ipynb
Large diffs are not rendered by default.
58 changes: 58 additions & 0 deletions
...r_gradients/examples/regseg_transfer_learning_example/regseg_transfer_learning_example.py
@@ -0,0 +1,58 @@
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import SuperviselyPersonsDatasetInterface
from super_gradients.training.sg_model import SgModel
from super_gradients.training.metrics import BinaryIOU
from super_gradients.training.transforms.transforms import ResizeSeg, RandomFlip, RandomRescale, CropImageAndMask, \
    PadShortToCropSize, ColorJitterSeg
from super_gradients.training.utils.callbacks import BinarySegmentationVisualizationCallback, Phase
from torchvision import transforms

# DEFINE DATA TRANSFORMATIONS
dataset_params = {
    "image_mask_transforms_aug": transforms.Compose([ColorJitterSeg(brightness=0.5, contrast=0.5, saturation=0.5),
                                                     RandomFlip(),
                                                     RandomRescale(scales=[0.25, 1.]),
                                                     PadShortToCropSize([320, 480]),
                                                     CropImageAndMask(crop_size=[320, 480],
                                                                      mode="random")]),
    "image_mask_transforms": transforms.Compose([ResizeSeg(h=480, w=320)])
}

dataset_interface = SuperviselyPersonsDatasetInterface(dataset_params)

model = SgModel("regseg48_transfer_learning_old_dice_diff_lrs_head_fixed_50_epochs")

# CONNECTING THE DATASET INTERFACE WILL SET SGMODEL'S CLASSES ATTRIBUTE ACCORDING TO SUPERVISELY
model.connect_dataset_interface(dataset_interface)

# THIS IS WHERE THE MAGIC HAPPENS: SINCE SGMODEL'S CLASSES ATTRIBUTE WAS SET TO BE DIFFERENT FROM CITYSCAPES'S, AFTER
# LOADING THE PRETRAINED REGSEG, IT WILL CALL ITS REPLACE_HEAD METHOD AND CHANGE ITS SEGMENTATION HEAD LAYER ACCORDING
# TO OUR BINARY SEGMENTATION DATASET
model.build_model("regseg48", arch_params={"pretrained_weights": "cityscapes"})

# DEFINE TRAINING PARAMS. SEE DOCS FOR THE FULL LIST.
train_params = {"max_epochs": 50,
                "lr_mode": "cosine",
                "initial_lr": 0.0064,  # for batch_size=16
                "optimizer_params": {"momentum": 0.843,
                                     "weight_decay": 0.00036,
                                     "nesterov": True},

                "cosine_final_lr_ratio": 0.1,
                "multiply_head_lr": 10,
                "optimizer": "SGD",
                "loss": "bce_dice_loss",
                "ema": True,
                "zero_weight_decay_on_bias_and_bn": True,
                "average_best_models": True,
                "mixed_precision": False,
                "metric_to_watch": "mean_IOU",
                "greater_metric_to_watch_is_better": True,
                "train_metrics_list": [BinaryIOU()],
                "valid_metrics_list": [BinaryIOU()],
                "loss_logging_items_names": ["loss"],
                "phase_callbacks": [BinarySegmentationVisualizationCallback(phase=Phase.VALIDATION_BATCH_END,
                                                                            freq=1,
                                                                            last_img_idx_in_batch=4)],
                }

model.train(train_params)
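
For context on the schedule these parameters imply: "lr_mode": "cosine" with "cosine_final_lr_ratio": 0.1 decays the learning rate from initial_lr down to 0.1 * initial_lr over the 50 epochs, and "multiply_head_lr": 10 trains the newly replaced segmentation head at roughly 10x the backbone's rate. The snippet below is only an illustrative sketch of that cosine shape, not SuperGradients' actual scheduler (which may step per iteration, add warmup, etc.).

import math

# Illustrative sketch of a cosine decay from initial_lr to cosine_final_lr_ratio * initial_lr.
# Values mirror the train_params above; this is not SuperGradients' implementation.
initial_lr, final_ratio, max_epochs = 0.0064, 0.1, 50

def cosine_lr(epoch: int) -> float:
    cos_term = 0.5 * (1 + math.cos(math.pi * epoch / max_epochs))
    return initial_lr * (final_ratio + (1 - final_ratio) * cos_term)

print(cosine_lr(0), cosine_lr(25), cosine_lr(50))  # ~0.0064, ~0.00352, ~0.00064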
41 changes: 41 additions & 0 deletions
...per_gradients/training/datasets/segmentation_datasets/supervisely_persons_segmentation.py
@@ -0,0 +1,41 @@
import csv
import os

from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet


class SuperviselyPersonsDataset(SegmentationDataSet):
    """
    SuperviselyPersonsDataset - Segmentation Data Set Class for Supervisely Persons Segmentation Data Set,
    main resolution of dataset: (600 x 800).
    This dataset is a subset of the original dataset (see below) and contains filtered samples.
    For more details about the ORIGINAL dataset see: https://app.supervise.ly/ecosystem/projects/persons
    For more details about the FILTERED dataset see:
    https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/PP-HumanSeg
    """
    CLASS_LABELS = {0: "background", 1: "person"}

    def __init__(self, root_dir: str, list_file: str, **kwargs):
        """
        :param root_dir:  root directory to dataset.
        :param list_file: list file that contains names of images to load, line format: <image_path>,<mask_path>
        :param kwargs:    Any hyper params required for the dataset, i.e. img_size, crop_size, etc...
        """
        super().__init__(root=root_dir, list_file=list_file, **kwargs)
        self.classes = ['person']

    def _generate_samples_and_targets(self):
        with open(os.path.join(self.root, self.list_file_path), 'r', encoding="utf-8") as file:
            reader = csv.reader(file)
            for row in reader:
                sample_path = os.path.join(self.root, row[0])
                target_path = os.path.join(self.root, row[1])
                if self._validate_file(sample_path) \
                        and self._validate_file(target_path) \
                        and os.path.exists(sample_path) \
                        and os.path.exists(target_path):
                    self.samples_targets_tuples_list.append((sample_path, target_path))
                else:
                    raise AssertionError(f"Sample and/or target file(s) not found or in illegal format "
                                         f"(sample path: {sample_path}, target path: {target_path})")
@@ -0,0 +1,30 @@
import torch

from super_gradients.training.losses.bce_loss import BCE
from super_gradients.training.losses.dice_loss import BinaryDiceLoss


class BCEDiceLoss(torch.nn.Module):
    """
    Binary Cross Entropy + Dice Loss
    Weighted average of BCE and Dice loss
    Attributes:
        loss_weights: list of size 2 s.t. loss_weights[0], loss_weights[1] are the weights for BCE, Dice
        respectively.
    """
    def __init__(self, loss_weights=[0.5, 0.5], logits=True):
        super(BCEDiceLoss, self).__init__()
        self.loss_weights = loss_weights
        self.bce = BCE()
        self.dice = BinaryDiceLoss(apply_sigmoid=logits)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        @param input: Network's raw output shaped (N,1,H,W)
        @param target: Ground truth shaped (N,H,W)
        """
        return self.loss_weights[0] * self.bce(input, target) + self.loss_weights[1] * self.dice(input, target)
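
A small usage sketch of the combined loss, following the tensor shapes documented in forward() above; it assumes super_gradients is installed so the BCE and BinaryDiceLoss imports resolve.

import torch

# Dummy batch with the documented shapes: logits (N, 1, H, W), binary mask target (N, H, W).
criterion = BCEDiceLoss(loss_weights=[0.5, 0.5], logits=True)
logits = torch.randn(2, 1, 8, 8)
target = torch.randint(0, 2, (2, 8, 8)).float()
loss = criterion(logits, target)  # 0.5 * BCE + 0.5 * Dice
print(loss.item())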
@@ -0,0 +1,16 @@
import torch
from torch.nn import BCEWithLogitsLoss


class BCE(BCEWithLogitsLoss):
    """
    Binary Cross Entropy Loss
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        @param input: Network's raw output shaped (N,1,*)
        @param target: Ground truth shaped (N,*)
        """
        return super(BCE, self).forward(input.squeeze(1), target.float())
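
The squeeze(1) above exists because BCEWithLogitsLoss expects the prediction and the target to have matching shapes: the single-channel logits (N, 1, *) are squeezed to (N, *) and the integer mask is cast to float. A plain-PyTorch sketch of the same shape handling:

import torch
from torch.nn import BCEWithLogitsLoss

# Single-channel logits squeezed to (N, H, W) to match a binary mask target of that shape.
logits = torch.randn(2, 1, 8, 8)
target = torch.randint(0, 2, (2, 8, 8))
loss = BCEWithLogitsLoss()(logits.squeeze(1), target.float())
print(loss.item())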