Optimization_Suite.py
import optuna
import torch
import torch.nn.utils.prune
import pytorch_lightning as pl
from typing import Any, Dict, Optional
from torch.quantization import quantize_dynamic
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from pytorch_lightning.callbacks import EarlyStopping


def quantize_model(model: torch.nn.Module) -> torch.nn.Module:
    """Apply dynamic int8 quantization to the model's Linear layers."""
    return quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )
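

# Hedged usage sketch: dynamic quantization is applied after training and runs
# on CPU; `MyModel` here is a placeholder, not part of this suite.
#
#     model = MyModel()
#     model.eval()
#     q_model = quantize_model(model)   # Linear weights now stored as int8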


class ModelOptimizer:
    """Optuna-driven hyperparameter search plus optimizer/scheduler/pruning helpers."""

    def __init__(self, model_class, data_loader, config):
        self.model_class = model_class
        self.data_loader = data_loader
        self.base_config = config
        self.best_trial = None

    def optimize_hyperparameters(self, n_trials: int = 20) -> Dict[str, Any]:
        study = optuna.create_study(direction="minimize")
        study.optimize(self._objective, n_trials=n_trials)
        self.best_trial = study.best_trial
        return study.best_params

    def _objective(self, trial) -> float:
        # Hyperparameter search space
        config = self.base_config.copy()
        config.update({
            'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True),
            'batch_size': trial.suggest_int('batch_size', 16, 128, step=16),
            'optimizer': trial.suggest_categorical('optimizer', ['adam', 'sgd', 'adamw']),
            'weight_decay': trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True),
        })
        # Training with validation; the LightningModule is expected to log 'val_loss'
        model = self.model_class(config)
        trainer = pl.Trainer(
            max_epochs=5,  # Short training for hyperparameter search
            callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
            enable_progress_bar=False,
        )
        trainer.fit(model, self.data_loader)
        return trainer.callback_metrics['val_loss'].item()

    def get_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer:
        if self.best_trial is None:
            raise ValueError("Run optimize_hyperparameters first")
        opt_name = self.best_trial.params['optimizer']
        lr = self.best_trial.params['learning_rate']
        weight_decay = self.best_trial.params['weight_decay']
        optimizers = {
            'adam': torch.optim.Adam,
            'sgd': torch.optim.SGD,
            'adamw': torch.optim.AdamW,
        }
        return optimizers[opt_name](
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
        )

    def get_scheduler(self, optimizer: torch.optim.Optimizer) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
        # Note: ReduceLROnPlateau does not subclass _LRScheduler in older
        # PyTorch releases, so this annotation is approximate there.
        scheduler_type = self.base_config.get('scheduler', 'cosine')
        if scheduler_type == 'plateau':
            return ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=0.1,
                patience=5,
                verbose=True,
            )
        elif scheduler_type == 'cosine':
            return CosineAnnealingLR(
                optimizer,
                T_max=self.base_config.get('epochs', 100),
                eta_min=1e-6,
            )
        return None
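
    # Illustrative config values this method reads (a sketch, not exhaustive):
    #     {'scheduler': 'plateau'}               -> ReduceLROnPlateau; step it
    #                                               with a metric: scheduler.step(val_loss)
    #     {'scheduler': 'cosine', 'epochs': 50}  -> CosineAnnealingLR over 50 epochs;
    #                                               step it once per epoch: scheduler.step()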

    @staticmethod
    def apply_pruning(model: torch.nn.Module, amount: float = 0.3) -> torch.nn.Module:
        """L1-unstructured pruning of every Linear layer's weights.

        Note: this attaches a pruning reparametrization (weight_orig plus
        weight_mask) rather than deleting weights; call
        torch.nn.utils.prune.remove() per module to make it permanent.
        """
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                torch.nn.utils.prune.l1_unstructured(
                    module,
                    name='weight',
                    amount=amount,
                )
        return model
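

# Hedged end-to-end sketch; `LitModel` and `train_loader` are placeholders for
# the user's LightningModule (which must log 'val_loss') and its DataLoader.
#
#     suite = ModelOptimizer(LitModel, train_loader, config={'epochs': 50})
#     best_params = suite.optimize_hyperparameters(n_trials=20)
#     model = LitModel({**suite.base_config, **best_params})
#     opt = suite.get_optimizer(model)
#     sched = suite.get_scheduler(opt)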


class MemoryTracker:
    @staticmethod
    def get_model_size(model: torch.nn.Module) -> float:
        """Get model size (parameters + buffers) in MB"""
        param_size = 0
        for param in model.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()
        size_all_mb = (param_size + buffer_size) / 1024 ** 2
        return size_all_mb
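
    # Illustrative call (a sketch; `model` is any nn.Module you have built):
    #     size_mb = MemoryTracker.get_model_size(model)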

    @staticmethod
    def log_memory_stats(model: torch.nn.Module, phase: str = ""):
        """Log model size and, when CUDA is available, GPU memory statistics"""
        print(f"\n{phase} Memory Stats:")
        print(f"Model Size: {MemoryTracker.get_model_size(model):.2f}MB")
        if torch.cuda.is_available():
            print(f"Allocated: {torch.cuda.memory_allocated() / 1024 ** 2:.2f}MB")
            print(f"Cached: {torch.cuda.memory_reserved() / 1024 ** 2:.2f}MB")