Standardization of Initial Weights (SIW) Plugin implementation #527

Draft · wants to merge 11 commits into base: master
1 change: 1 addition & 0 deletions avalanche/training/plugins/__init__.py
@@ -9,4 +9,5 @@
ExperienceBalancedStoragePolicy
from .strategy_plugin import StrategyPlugin
from .synaptic_intelligence import SynapticIntelligencePlugin
from .siw import SIWPlugin
from .cope import CoPEPlugin, PPPloss
149 changes: 149 additions & 0 deletions avalanche/training/plugins/siw.py
@@ -0,0 +1,149 @@
from typing import Optional

import torch
from torch.nn import Linear
from torch.utils.data import DataLoader

from avalanche.training.plugins.strategy_plugin import StrategyPlugin
from avalanche.training.utils import get_last_fc_layer, get_layer_by_name


class SIWPlugin(StrategyPlugin):
"""
Standardization of Initial Weights (SIW) plugin.
From https://arxiv.org/pdf/2008.13710.pdf

Performs past class initial weights replay and state-level score
calibration. The callbacks `before_training_exp`, `after_backward`,
`after_training_exp`,`before_eval_exp`, and `after_eval_forward`
are implemented.

The `before_training_exp` callback is implemented in order to keep
track of the classes in each experience

The `after_backward` callback is implemented in order to freeze past
class weights in the last fully connected layer

The `after_training_exp` callback is implemented in order to extract
new class images' scores and compute the model confidence at
each incremental state.

The `before_eval_exp` callback is implemented in order to standardize
initial weights before inference

The`after_eval_forward` is implemented in order to apply state-level
calibration at the inference time

The :siw_layer_name: parameter concerns the name of the last fully
connected layer of the network

The :batch_size: and :num_workers: parameters concern the new class
scores extraction.
"""

def __init__(self, model, siw_layer_name='fc', batch_size=32,
num_workers=0):
super().__init__()
self.confidences = []
self.classes_per_experience = []
self.model = model
self.siw_layer_name = siw_layer_name
self.num_workers = num_workers
self.batch_size = batch_size

def get_siw_layer(self) -> Optional[Linear]:
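        """Return the last fully connected layer: looked up by name when
        `siw_layer_name` is set, otherwise the last `Linear` layer found
        in the model."""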
result = None
if self.siw_layer_name is None:
last_fc = get_last_fc_layer(self.model)
if last_fc is not None:
result = last_fc[1]
else:
result = get_layer_by_name(self.model, self.siw_layer_name)
return result

def before_training_exp(self, strategy, **kwargs):
"""
        Keep track of the classes encountered in each experience.
"""
self.classes_per_experience.append(
strategy.experience.classes_in_this_experience)

def after_backward(self, strategy, **kwargs):
"""
Before executing the optimization step to perform
back-propagation, we zero the gradients of past class
weights and bias. This is equivalent to freeze past
class weights and bias, to let only the feature extractor
and the new class weights and bias evolve
"""
previous_classes = len(strategy.experience.previous_classes)
last_layer = self.get_siw_layer()
if last_layer is None:
raise RuntimeError('Can\'t find this Linear layer')

last_layer.weight.grad[:previous_classes, :] = 0
last_layer.bias.grad[:previous_classes] = 0

@torch.no_grad()
def after_training_exp(self, strategy, **kwargs):
"""
Before evaluating the performance of our model,
we extract new class images' scores and compute the
model's confidence at each incremental state
"""
# extract training scores
strategy.model.eval()

dataset = strategy.experience.dataset
loader = torch.utils.data.DataLoader(
dataset, batch_size=self.batch_size,
num_workers=self.num_workers)

# compute model's confidence
max_top1_scores = []
for i, data in enumerate(loader):
inputs, targets, task_labels = data
if tc.is_available():
inputs = inputs.to(strategy.device)
logits = strategy.model(inputs)
max_score = torch.max(logits, dim=1)[0].tolist()
max_top1_scores.extend(max_score)
self.confidences.append(sum(max_top1_scores) /
len(max_top1_scores))

@torch.no_grad()
def before_eval_exp(self, strategy, **kwargs):
"""
Standardize all class weights (by subtracting their mean
and dividing by their standard deviation)
"""

# standardize last layer weights
last_layer = self.get_siw_layer()
if last_layer is None:
raise RuntimeError('Can\'t find this Linear layer')

classes_seen_so_far = len(strategy.experience.classes_seen_so_far)

for i in range(classes_seen_so_far):
mu = torch.mean(last_layer.weight[i])
std = torch.std(last_layer.weight[i])

last_layer.weight.data[i] -= mu
last_layer.weight.data[i] /= std

def after_eval_forward(self, strategy, **kwargs):
"""
Rectify past class scores by multiplying them by the model's
confidence in the current state and dividing them by the
model's confidence in the initial state in which a past
class was encountered for the first time
"""
for exp in range(len(self.confidences)):
strategy.logits[:, self.classes_per_experience[exp]] *= \
self.confidences[strategy.experience.current_experience] \
/ self.confidences[exp]
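
For intuition, here is a minimal standalone sketch (not part of the diff) of the two inference-time operations the plugin performs, weight standardization and state-level score calibration, on toy tensors. The names `confidences` and `classes_per_experience` mirror the plugin's attributes:

import torch

# toy setup: 2 experiences of 2 classes each, 4 classes in total
classes_per_experience = [[0, 1], [2, 3]]
confidences = [0.9, 0.6]      # mean top-1 score per incremental state
current_state = 1             # evaluating after the second experience

# (1) weight standardization: each class row of the last fc layer
#     is brought to zero mean and unit standard deviation
weights = torch.randn(4, 16)  # 4 classes, 16 features
mu = weights.mean(dim=1, keepdim=True)
std = weights.std(dim=1, keepdim=True)
weights = (weights - mu) / std

# (2) state-level calibration of the logits: past class scores are
#     scaled by current confidence / initial confidence
logits = torch.randn(8, 4)    # a batch of 8 predictions
for exp, classes in enumerate(classes_per_experience):
    logits[:, classes] *= confidences[current_state] / confidences[exp]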
50 changes: 49 additions & 1 deletion avalanche/training/strategies/strategy_wrappers.py
@@ -16,7 +16,7 @@
from avalanche.training import default_logger
from avalanche.training.plugins import StrategyPlugin, CWRStarPlugin, \
ReplayPlugin, GDumbPlugin, LwFPlugin, AGEMPlugin, GEMPlugin, EWCPlugin, \
EvaluationPlugin, SynapticIntelligencePlugin, CoPEPlugin
EvaluationPlugin, SynapticIntelligencePlugin, SIWPlugin, CoPEPlugin
from avalanche.training.strategies.base_strategy import BaseStrategy


@@ -450,6 +450,53 @@ def __init__(self, model: Module, optimizer: Optimizer, criterion,
)


class SIW(BaseStrategy):
def __init__(self, model: Module, optimizer: Optimizer, criterion,
siw_layer_name: str = 'fc',
batch_size: int = 32, num_workers: int = 0,
train_mb_size: int = 1, train_epochs: int = 1,
eval_mb_size: int = None, device=None,
plugins: Optional[List[StrategyPlugin]] = None,
evaluator: EvaluationPlugin = default_logger, eval_every=-1):
""" Standardization of Initial Weights (SIW) strategy.
See SIW plugin for details.
This strategy does not use task identities.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param siw_layer_name: The name of the last fully connected layer.
        :param batch_size: The batch size used to extract scores.
        :param num_workers: The number of workers used to load batches.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None.
        :param device: The device to use. Defaults to None (cpu).
        :param plugins: Plugins to be added. Defaults to None.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop.
                if -1: no evaluation during training.
                if 0: calls `eval` after the final epoch of each training
                    experience.
                if >0: calls `eval` every `eval_every` epochs and at the end
                    of all the epochs for a single experience.
        """
"""

siw = SIWPlugin(model, siw_layer_name, batch_size, num_workers)
if plugins is None:
plugins = [siw]
else:
plugins.append(siw)

super().__init__(
model, optimizer, criterion,
train_mb_size=train_mb_size, train_epochs=train_epochs,
eval_mb_size=eval_mb_size, device=device, plugins=plugins,
evaluator=evaluator, eval_every=eval_every)


class CoPE(BaseStrategy):

def __init__(self, model: Module, optimizer: Optimizer, criterion,
Expand Down Expand Up @@ -514,5 +561,6 @@ def __init__(self, model: Module, optimizer: Optimizer, criterion,
'GEM',
'EWC',
'SynapticIntelligence',
'SIW',
'CoPE'
]
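
As a quick usage sketch (not part of the diff), the new strategy can be driven like any other Avalanche strategy. The benchmark, model, and hyperparameters below are illustrative placeholders, and `SimpleMLP`'s last layer is assumed to be named `classifier`:

from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.strategies import SIW

# toy benchmark and model, for illustration only
benchmark = SplitMNIST(n_experiences=5, return_task_id=False)
model = SimpleMLP(num_classes=10)

strategy = SIW(model, SGD(model.parameters(), lr=0.01),
               CrossEntropyLoss(), siw_layer_name='classifier',
               train_mb_size=64, train_epochs=1, eval_mb_size=64)

for experience in benchmark.train_stream:
    strategy.train(experience)
    strategy.eval(benchmark.test_stream)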
155 changes: 155 additions & 0 deletions examples/siw_cifar100.py
@@ -0,0 +1,155 @@
from avalanche.benchmarks.classic import SplitCIFAR100
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from avalanche.training.strategies import Naive
from avalanche.training.plugins import SIWPlugin, \
    EvaluationPlugin, StrategyPlugin
from avalanche.logging import InteractiveLogger
from avalanche.evaluation.metrics import accuracy_metrics
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch
import argparse
from torch.optim import lr_scheduler


class LRSchedulerPlugin(StrategyPlugin):
def __init__(self, lr_scheduler):
super().__init__()
self.lr_scheduler = lr_scheduler

def after_training_epoch(self, strategy: 'BaseStrategy', **kwargs):
        self.lr_scheduler.step(strategy.loss.item())
lr = strategy.optimizer.param_groups[0]['lr']
print(f"\nlr = {lr}")


class SetIncrementalHyperParams(StrategyPlugin):
    def __init__(self, scheduler_plugin, inc_exp_epochs, inc_exp_patience,
                 first_exp_lr, lr_decay):
        super().__init__()
        self.scheduler_plugin = scheduler_plugin
        self.inc_exp_epochs = inc_exp_epochs
        self.inc_exp_patience = inc_exp_patience
        self.first_exp_lr = first_exp_lr
        self.lr_decay = lr_decay

    def before_training_exp(self, strategy: 'BaseStrategy', **kwargs):
        if strategy.experience.current_experience > 0:  # incremental update
            strategy.train_epochs = self.inc_exp_epochs
            strategy.optimizer.param_groups[0]['lr'] = \
                self.first_exp_lr / strategy.experience.current_experience
            # reset the plateau scheduler on the active LRSchedulerPlugin;
            # a plugin assigned to a plain strategy attribute would never
            # receive callbacks
            self.scheduler_plugin.lr_scheduler = \
                lr_scheduler.ReduceLROnPlateau(strategy.optimizer,
                                               patience=self.inc_exp_patience,
                                               factor=self.lr_decay)


def main(args):
# check if selected GPU is available or use CPU
assert args.cuda == -1 or args.cuda >= 0, "cuda must be -1 or >= 0."
device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available()
and args.cuda >= 0 else "cpu")
print(f'Using device: {device}')
#############################################
model = torchvision.models.resnet18(num_classes=100).to(device)

# print to stdout
interactive_logger = InteractiveLogger()

eval_plugin = EvaluationPlugin(
accuracy_metrics(minibatch=False, epoch=True, experience=True,
stream=True),
loggers=[interactive_logger]
)

optimizer = SGD(model.parameters(), lr=args.first_exp_lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
criterion = CrossEntropyLoss()
scheduler = LRSchedulerPlugin(
lr_scheduler.ReduceLROnPlateau(optimizer,
patience=args.first_exp_patience,
factor=args.lr_decay))
    incremental_params = SetIncrementalHyperParams(scheduler,
                                                   args.inc_exp_epochs,
                                                   args.inc_exp_patience,
                                                   args.first_exp_lr,
                                                   args.lr_decay)

siw = SIWPlugin(model, siw_layer_name=args.siw_layer_name,
batch_size=args.eval_batch_size,
num_workers=args.num_workers)

strategy = Naive(model, optimizer, criterion,
device=device, train_epochs=args.first_exp_epochs,
evaluator=eval_plugin,
plugins=[siw, scheduler, incremental_params],
train_mb_size=args.train_batch_size,
eval_mb_size=args.eval_batch_size)

normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409],
std=[0.2673, 0.2564, 0.2762])

train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize])

test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize])

# scenario
scenario = SplitCIFAR100(n_experiences=10, return_task_id=False,
fixed_class_order=range(0, 100),
train_transform=train_transform,
eval_transform=test_transform)
# TRAINING LOOP
print('Starting experiment...')
results = []
for i, experience in enumerate(scenario.train_stream):
print("Start of experience: ", experience.current_experience)
strategy.train(experience, num_workers=args.num_workers)
print('Training completed')
print('Computing accuracy on the test set')
res = strategy.eval(scenario.test_stream[:i + 1],
num_workers=args.num_workers)
results.append(res)

print('Results = ' + str(results))


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--first_exp_lr', type=float, default=0.1,
help='Learning rate for the first experience.')
parser.add_argument('--momentum', type=float, default=0.9,
help='Momentum')
parser.add_argument('--weight_decay', type=float, default=0.0005,
help='Weight decay')
parser.add_argument('--lr_decay', type=float, default=0.1,
help='LR decay')
parser.add_argument('--first_exp_patience', type=int, default=60,
help='Patience in the first experience')
parser.add_argument('--inc_exp_patience', type=int, default=15,
help='Patience in the incremental experiences')
parser.add_argument('--first_exp_epochs', type=int, default=300,
help='Number of epochs in the first experience.')
parser.add_argument('--inc_exp_epochs', type=int, default=70,
help='Number of epochs in each incremental experience.')
parser.add_argument('--train_batch_size', type=int, default=128,
help='Training batch size.')
parser.add_argument('--eval_batch_size', type=int, default=32,
help='Evaluation batch size.')
parser.add_argument('--num_workers', type=int, default=8,
help='Number of workers used to extract scores.')
parser.add_argument('--siw_layer_name', type=str, default='fc',
help='Name of the last fully connected layer.')
parser.add_argument('--cuda', type=int, default=1,
help='Specify GPU id to use. Use CPU if -1.')
args = parser.parse_args()

main(args)
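
With the argument defaults above, the example can be launched with, e.g., `python examples/siw_cifar100.py --cuda 0` to use the first GPU, or `--cuda -1` to run on CPU.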