Skip to content

Commit

Permalink
Merge pull request #141 from Deci-AI/feature/SG-34_regseg_transfer_learning
Browse files Browse the repository at this point in the history

Feature/sg 34 regseg transfer learning
  • Loading branch information
shaydeci authored Mar 10, 2022
2 parents dbdf898 + 827af56 commit 83634f5
Show file tree
Hide file tree
Showing 21 changed files with 1,886 additions and 331 deletions.
4 changes: 2 additions & 2 deletions src/super_gradients/common/factories/transforms_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from super_gradients.common.factories.base_factory import BaseFactory
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.training.utils.segmentation_utils import RandomFlip, Rescale, RandomRescale, RandomRotate, CropImageAndMask, RandomGaussianBlur, \
PadShortToCropSize, ColorJitterSeg
from super_gradients.training.transforms.transforms import RandomFlip, Rescale, RandomRescale, RandomRotate, \
CropImageAndMask, RandomGaussianBlur, PadShortToCropSize, ColorJitterSeg

from torchvision import transforms
import inspect
Expand Down
1,194 changes: 1,194 additions & 0 deletions src/super_gradients/examples/SegmentationTransferLearning.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
import argparse
import torchvision.transforms as transforms

from super_gradients.training.utils.segmentation_utils import RandomFlip, PadShortToCropSize, CropImageAndMask, RandomRescale
from super_gradients.training.transforms.transforms import RandomFlip, RandomRescale, CropImageAndMask, \
PadShortToCropSize
from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CITYSCAPES_IGNORE_LABEL
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CityscapesDatasetInterface
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import SuperviselyPersonsDatasetInterface
from super_gradients.training.sg_model import SgModel
from super_gradients.training.metrics import BinaryIOU
from super_gradients.training.transforms.transforms import ResizeSeg, RandomFlip, RandomRescale, CropImageAndMask, \
PadShortToCropSize, ColorJitterSeg
from super_gradients.training.utils.callbacks import BinarySegmentationVisualizationCallback, Phase
from torchvision import transforms

# -----------------------------------------------------------------------------
# RegSeg transfer-learning example: fine-tune a Cityscapes-pretrained RegSeg48
# on the binary Supervisely Persons segmentation dataset.
# -----------------------------------------------------------------------------

# Data transformations: heavy augmentation for training, a plain resize for validation.
train_image_mask_transforms = transforms.Compose([
    ColorJitterSeg(brightness=0.5, contrast=0.5, saturation=0.5),
    RandomFlip(),
    RandomRescale(scales=[0.25, 1.]),
    PadShortToCropSize([320, 480]),
    CropImageAndMask(crop_size=[320, 480], mode="random"),
])
val_image_mask_transforms = transforms.Compose([ResizeSeg(h=480, w=320)])

dataset_params = {
    "image_mask_transforms_aug": train_image_mask_transforms,
    "image_mask_transforms": val_image_mask_transforms,
}

dataset_interface = SuperviselyPersonsDatasetInterface(dataset_params)

model = SgModel("regseg48_transfer_learning_old_dice_diff_lrs_head_fixed_50_epochs")

# Connecting the dataset interface sets SgModel's classes attribute from the
# Supervisely dataset.
model.connect_dataset_interface(dataset_interface)

# Because the classes attribute now differs from Cityscapes', loading the
# pretrained RegSeg triggers its replace_head method, swapping the segmentation
# head layer to match our binary segmentation dataset.
model.build_model("regseg48", arch_params={"pretrained_weights": "cityscapes"})

# Training hyper-parameters (see docs for the full list).
train_params = {
    "max_epochs": 50,
    "lr_mode": "cosine",
    "initial_lr": 0.0064,  # for batch_size=16
    "optimizer_params": {"momentum": 0.843,
                         "weight_decay": 0.00036,
                         "nesterov": True},
    "cosine_final_lr_ratio": 0.1,
    "multiply_head_lr": 10,
    "optimizer": "SGD",
    "loss": "bce_dice_loss",
    "ema": True,
    "zero_weight_decay_on_bias_and_bn": True,
    "average_best_models": True,
    "mixed_precision": False,
    "metric_to_watch": "mean_IOU",
    "greater_metric_to_watch_is_better": True,
    "train_metrics_list": [BinaryIOU()],
    "valid_metrics_list": [BinaryIOU()],
    "loss_logging_items_names": ["loss"],
    "phase_callbacks": [BinarySegmentationVisualizationCallback(phase=Phase.VALIDATION_BATCH_END,
                                                                freq=1,
                                                                last_img_idx_in_batch=4)],
}

model.train(train_params)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from super_gradients.common.abstractions.abstract_logger import get_logger
from torch.utils.data.distributed import DistributedSampler
from super_gradients.training.datasets import datasets_utils, DataAugmentation
Expand All @@ -25,7 +24,9 @@
from pathlib import Path
from super_gradients.training.datasets.detection_datasets.pascal_voc_detection import PASCAL_VOC_2012_CLASSES
from super_gradients.training.utils.utils import download_and_unzip_from_url

from super_gradients.training.utils import get_param
import torchvision.transforms as transforms
from super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation import SuperviselyPersonsDataset
default_dataset_params = {"batch_size": 64, "val_batch_size": 200, "test_batch_size": 200, "dataset_dir": "./data/",
"s3_link": None}
LIBRARY_DATASETS = {
Expand Down Expand Up @@ -831,3 +832,31 @@ def convert_box(size, box):
lb_path = (lbs_path / f.name).with_suffix('.txt') # new label path
f.rename(imgs_path / f.name) # move image
convert_label(path, lb_path, year, id) # convert labels to YOLO format


class SuperviselyPersonsDatasetInterface(DatasetInterface):
    """
    Dataset interface for the Supervisely Persons segmentation dataset.

    Builds a train set (from 'train.csv', with augmentation transforms) and a
    validation set (from 'val.csv', with plain transforms), then exposes the
    train set's classes.
    """

    def __init__(self, dataset_params=None, cache_labels: bool = False, cache_images: bool = False):
        """
        :param dataset_params: hyper-params dict; keys read here: "dataset_dir",
                               "image_mask_transforms_aug", "image_mask_transforms".
        :param cache_labels:   whether the underlying datasets cache labels.
        :param cache_images:   whether the underlying datasets cache images.
        """
        super().__init__(dataset_params=dataset_params)
        root_dir = get_param(dataset_params, "dataset_dir", "/data/supervisely-persons")

        # kwargs shared by both splits
        common_kwargs = dict(root_dir=root_dir,
                             dataset_hyper_params=dataset_params,
                             cache_labels=cache_labels,
                             cache_images=cache_images)

        train_aug = get_param(dataset_params, "image_mask_transforms_aug", transforms.Compose([]))
        self.trainset = SuperviselyPersonsDataset(list_file='train.csv',
                                                  image_mask_transforms_aug=train_aug,
                                                  augment=True,
                                                  **common_kwargs)

        val_transforms = get_param(dataset_params, "image_mask_transforms", transforms.Compose([]))
        self.valset = SuperviselyPersonsDataset(list_file='val.csv',
                                                image_mask_transforms=val_transforms,
                                                augment=False,
                                                **common_kwargs)

        self.classes = self.trainset.classes
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transform
from super_gradients.training.utils.segmentation_utils import RandomFlip, CropImageAndMask, PadShortToCropSize,\
RandomRescale, Rescale
from super_gradients.training.transforms.transforms import RandomFlip, Rescale, RandomRescale, CropImageAndMask, \
PadShortToCropSize
from super_gradients.training.utils.utils import get_param

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training.datasets.sg_dataset import DirectoryDataSet, ListDataset
from super_gradients.training.utils.segmentation_utils import RandomFlip, Rescale, RandomRotate, PadShortToCropSize, \
CropImageAndMask, RandomGaussianBlur, RandomRescale
from super_gradients.training.transforms.transforms import RandomFlip, Rescale, RandomRescale, RandomRotate, \
CropImageAndMask, RandomGaussianBlur, PadShortToCropSize


class SegmentationDataSet(DirectoryDataSet, ListDataset):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import csv
import os

from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet


class SuperviselyPersonsDataset(SegmentationDataSet):
    """
    SuperviselyPersonsDataset - Segmentation Data Set Class for Supervisely Persons Segmentation Data Set,
    main resolution of dataset: (600 x 800).
    This dataset is a subset of the original dataset (see below) and contains filtered samples
    For more details about the ORIGINAL dataset see: https://app.supervise.ly/ecosystem/projects/persons
    For more details about the FILTERED dataset see:
        https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/PP-HumanSeg
    """
    # Mapping from mask value to class name (binary person segmentation).
    CLASS_LABELS = {0: "background", 1: "person"}

    def __init__(self, root_dir: str, list_file: str, **kwargs):
        """
        :param root_dir:    root directory to dataset.
        :param list_file:   list file that contains names of images to load,
                            line format: <image_path>,<mask_path>
        :param kwargs:      any hyper params required for the dataset, i.e img_size, crop_size, etc...
        """
        super().__init__(root=root_dir, list_file=list_file, **kwargs)
        self.classes = ['person']

    def _generate_samples_and_targets(self):
        """Read the CSV list file and collect (sample_path, target_path) tuples.

        Raises AssertionError on the first row whose files are missing or invalid.
        """
        list_path = os.path.join(self.root, self.list_file_path)
        with open(list_path, 'r', encoding="utf-8") as file:
            for row in csv.reader(file):
                sample_path = os.path.join(self.root, row[0])
                target_path = os.path.join(self.root, row[1])
                row_is_valid = (self._validate_file(sample_path)
                                and self._validate_file(target_path)
                                and os.path.exists(sample_path)
                                and os.path.exists(target_path))
                if not row_is_valid:
                    raise AssertionError(f"Sample and/or target file(s) not found or in illegal format "
                                         f"(sample path: {sample_path}, target path: {target_path})")
                self.samples_targets_tuples_list.append((sample_path, target_path))
3 changes: 2 additions & 1 deletion src/super_gradients/training/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from super_gradients.training.losses.yolo_v3_loss import YoLoV3DetectionLoss
from super_gradients.training.losses.yolo_v5_loss import YoLoV5DetectionLoss
from super_gradients.training.losses.ssd_loss import SSDLoss
from super_gradients.training.losses.bce_dice_loss import BCEDiceLoss
from super_gradients.training.losses.all_losses import LOSSES

__all__ = ['FocalLoss', 'LabelSmoothingCrossEntropyLoss', 'ShelfNetOHEMLoss', 'ShelfNetSemanticEncodingLoss',
'YoLoV3DetectionLoss', 'YoLoV5DetectionLoss', 'RSquaredLoss', 'SSDLoss', 'LOSSES']
'YoLoV3DetectionLoss', 'YoLoV5DetectionLoss', 'RSquaredLoss', 'SSDLoss', 'LOSSES', 'BCEDiceLoss']
3 changes: 2 additions & 1 deletion src/super_gradients/training/losses/all_losses.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from torch import nn

from super_gradients.training.losses import LabelSmoothingCrossEntropyLoss, YoLoV3DetectionLoss, ShelfNetOHEMLoss, \
ShelfNetSemanticEncodingLoss, RSquaredLoss, YoLoV5DetectionLoss, SSDLoss
ShelfNetSemanticEncodingLoss, RSquaredLoss, YoLoV5DetectionLoss, SSDLoss, BCEDiceLoss
from super_gradients.training.losses.stdc_loss import STDCLoss

LOSSES = {"cross_entropy": LabelSmoothingCrossEntropyLoss,
Expand All @@ -13,4 +13,5 @@
"yolo_v5_loss": YoLoV5DetectionLoss,
"ssd_loss": SSDLoss,
"stdc_loss": STDCLoss,
"bce_dice_loss": BCEDiceLoss
}
30 changes: 30 additions & 0 deletions src/super_gradients/training/losses/bce_dice_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch

from super_gradients.training.losses.bce_loss import BCE
from super_gradients.training.losses.dice_loss import BinaryDiceLoss


class BCEDiceLoss(torch.nn.Module):
    """
    Binary Cross Entropy + Dice Loss.

    Weighted average of BCE and Dice loss.

    Attributes:
        loss_weights: 2-element sequence s.t. loss_weights[0], loss_weights[1] are the weights for BCE, Dice
         respectively.
    """
    def __init__(self, loss_weights=(0.5, 0.5), logits=True):
        """
        :param loss_weights: weights for the BCE and Dice terms, in that order.
                             Default is a tuple (not a list) to avoid the shared
                             mutable-default-argument pitfall; list arguments are
                             still accepted.
        :param logits: when True, the Dice term applies sigmoid to the raw input.
        """
        super(BCEDiceLoss, self).__init__()
        self.loss_weights = loss_weights
        self.bce = BCE()
        self.dice = BinaryDiceLoss(apply_sigmoid=logits)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        @param input: Network's raw output shaped (N,1,H,W)
        @param target: Ground truth shaped (N,H,W)
        @return: loss_weights[0] * BCE(input, target) + loss_weights[1] * Dice(input, target)
        """
        return self.loss_weights[0] * self.bce(input, target) + self.loss_weights[1] * self.dice(input, target)
16 changes: 16 additions & 0 deletions src/super_gradients/training/losses/bce_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
from torch.nn import BCEWithLogitsLoss


class BCE(BCEWithLogitsLoss):
    """
    Binary Cross Entropy Loss.

    Thin wrapper around ``BCEWithLogitsLoss`` that drops the singleton channel
    dimension from the prediction and casts the target to float before
    delegating to the parent.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        @param input: Network's raw output shaped (N,1,*)
        @param target: Ground truth shaped (N,*)
        """
        logits = input.squeeze(1)
        float_target = target.float()
        return super().forward(logits, float_target)
13 changes: 13 additions & 0 deletions src/super_gradients/training/metrics/segmentation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,16 @@ def update(self, preds, target: torch.Tensor):
preds = preds[0]
_, preds = torch.max(preds, 1)
super().update(preds=preds, target=target)


class BinaryIOU(torchmetrics.IoU):
    """
    Binary IoU metric reporting per-class and mean IoU.

    Wraps torchmetrics.IoU with num_classes=2 and reduction="none"; predictions
    are passed through sigmoid and thresholded at 0.5 by the parent metric.
    """

    def __init__(self, dist_sync_on_step=True, ignore_index=None):
        super().__init__(num_classes=2,
                         threshold=0.5,
                         reduction="none",
                         dist_sync_on_step=dist_sync_on_step,
                         ignore_index=ignore_index)
        # Names of the components returned by compute(), consumed by the trainer.
        self.component_names = ["target_IOU", "background_IOU", "mean_IOU"]

    def update(self, preds, target: torch.Tensor):
        probabilities = torch.sigmoid(preds)
        super().update(preds=probabilities, target=target.long())

    def compute(self):
        # Parent returns per-class IoUs: index 0 = background, index 1 = target.
        per_class = super(BinaryIOU, self).compute()
        return {"target_IOU": per_class[1],
                "background_IOU": per_class[0],
                "mean_IOU": per_class.mean()}
24 changes: 24 additions & 0 deletions src/super_gradients/training/models/segmentation_models/regseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,30 @@ def forward(self, x):
x = self.head(x)
return x

def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
    """
    Build two optimizer param groups: parameters whose names contain "head."
    get lr multiplied by training_params.multiply_head_lr (default 1); all
    other parameters use the base lr.

    :param lr:              base learning rate.
    :param training_params: hyper-params struct; "multiply_head_lr" is read here.
    :return: list of param-group dicts with keys "named_params", "lr", "name".
    """
    head_lr_factor = get_param(training_params, "multiply_head_lr", 1)

    head_params, backbone_params = {}, {}
    for name, param in self.named_parameters():
        bucket = head_params if "head." in name else backbone_params
        bucket[name] = param

    return [
        {"named_params": backbone_params.items(), "lr": lr, "name": "no_multiply_params"},
        {"named_params": head_params.items(), "lr": lr * head_lr_factor, "name": "multiply_lr_params"},
    ]

def update_param_groups(self, param_groups: list, lr: float, epoch: int, iter: int, training_params: HpmStruct,
                        total_batch: int) -> list:
    """
    Refresh each param group's lr for the current step, re-applying the
    "multiply_head_lr" factor to the head group created by
    initialize_param_groups.

    :return: the (mutated) param_groups list.
    """
    head_lr_factor = get_param(training_params, "multiply_head_lr", 1)
    for group in param_groups:
        new_lr = lr
        if group["name"] == "multiply_lr_params":
            new_lr = lr * head_lr_factor
        group['lr'] = new_lr
    return param_groups

def replace_head(self, new_num_classes: int, head_config: dict):
    """Swap the segmentation head for a new one predicting `new_num_classes`
    classes (used for transfer learning); input channels come from the decoder.
    """
    new_head = RegSegHead(self.decoder.out_channels, new_num_classes, head_config)
    self.head = new_head

Expand Down
Empty file.
Loading

0 comments on commit 83634f5

Please sign in to comment.