Skip to content

Commit

Permalink
Updated a demo script for deep networks.
Browse files Browse the repository at this point in the history
Added the ability to plot eval and train values distinctively.
  • Loading branch information
MLRichter committed Mar 7, 2021
1 parent 51df03a commit 79f7fdc
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 124 deletions.
2 changes: 1 addition & 1 deletion delve/torchcallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _check_stats(self, stats: list):
'dtrc',
'embed',
]
compatible = [stat in supported_stats for stat in stats]
compatible = [stat in supported_stats if not "_" in stat else stat.split("_")[0] in stats for stat in stats]
incompatible = [i for i, x in enumerate(compatible) if not x]
assert all(compatible), "Stat {} is not supported".format(
stats[incompatible[0]])
Expand Down
2 changes: 1 addition & 1 deletion delve/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.43'
__version__ = '0.1.44'
33 changes: 21 additions & 12 deletions delve/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,20 @@ def _look_for_stats(self):
if 'embed' in key:
embed = True
if sat:
self.stats.append('lsat')
self.stats.append('lsat_train')
self.stats.append('lsat_eval')
if idim:
self.stats.append('idim')
self.stats.append('idim_train')
self.stats.append('idim_eval')
if det:
self.stats.append('det')
self.stats.append('det_train')
self.stats.append('det_eval')
if trc:
self.stats.append('trc')
self.stats.append('trc_train')
self.stats.append('trc_eval')
if dtrc:
self.stats.append('dtrc')
self.stats.append('dtrc_train')
self.stats.append('dtrc_eval')
if embed:
self.sample_stats.append('embed')

Expand Down Expand Up @@ -344,7 +349,7 @@ def close(self):
pass


def extract_layer_stat(df, epoch=19, primary_metric=None, stat='saturation') -> Tuple[pd.DataFrame, float]:
def extract_layer_stat(df, epoch=19, primary_metric=None, stat='saturation', state_mode="train") -> Tuple[pd.DataFrame, float]:
"""
Extracts a specific statistic for a single epoch from a result dataframe as produced by the CSV-writer
:param df: The dataframe produced by a CSVWriter
Expand All @@ -356,7 +361,7 @@ def extract_layer_stat(df, epoch=19, primary_metric=None, stat='saturation') ->
"""
cols = list(df.columns)
train_cols = [col for col in cols if
'train' in col and not 'accuracy' in col and stat in col]
state_mode in col and not 'accuracy' in col and stat in col]
if not np.any(epoch == df.index.values):
raise ValueError(f'Epoch {epoch} could not be recoreded, dataframe has only the following indices: {df.index.values}')
epoch_df = df[df.index.values == epoch]
Expand All @@ -369,7 +374,7 @@ def plot_stat(df, stat, pm=-1, savepath='run.png', epoch=0, primary_metric=None,
line=True, scatter=True, ylim=(0, 1.0), alpha_line=.6, alpha_scatter=1.0, color_line=None,
color_scatter=None,
primary_metric_loc=(0.7, 0.8), show_col_label_x=True, show_col_label_y=True, show_grid=True, save=True,
samples=False):
samples=False, stat_mode="train"):
"""
:param df:
Expand Down Expand Up @@ -434,7 +439,7 @@ def plot_stat(df, stat, pm=-1, savepath='run.png', epoch=0, primary_metric=None,
plt.grid()
plt.tight_layout()
if save:
final_savepath = savepath.replace('.csv', f'{stat}_epoch_{epoch}.png')
final_savepath = savepath.replace('.csv', f'_{stat}_{stat_mode}_epoch_{epoch}.png')
print(final_savepath)
plt.savefig(final_savepath)
return ax
Expand All @@ -444,24 +449,28 @@ def plot_stat_level_from_results(savepath, epoch, stat, primary_metric=None, fon
scatter=True, ylim=(0, 1.0), alpha_line=.6, alpha_scatter=1.0, color_line=None,
color_scatter=None,
primary_metric_loc=(0.7, 0.8), show_col_label_x=True, show_col_label_y=True,
show_grid=True, save=True):
show_grid=True, save=True, stat_mode="train"):
df = pd.read_csv(savepath, sep=';')
if "_" in stat:
stat, stat_mode = stat.split("_")
if epoch == -1:
epoch = df.index.values[-1]

epoch_df, pm = extract_layer_stat(df, stat=STATMAP[stat], epoch=epoch, primary_metric=primary_metric)
epoch_df, pm = extract_layer_stat(df, stat=STATMAP[stat], epoch=epoch, primary_metric=primary_metric, state_mode=stat_mode)
ax = plot_stat(df=epoch_df, pm=pm, savepath=savepath, epoch=epoch, primary_metric=primary_metric, fontsize=fontsize,
figsize=figsize, stat=stat, ylim=None if not stat is 'lsat' else (0, 1.0), line=line, scatter=scatter,
alpha_line=alpha_line, alpha_scatter=alpha_scatter, color_line=color_line,
color_scatter=color_scatter,
primary_metric_loc=primary_metric_loc, show_col_label_x=show_col_label_x,
show_col_label_y=show_col_label_y,
show_grid=show_grid, save=save)
show_grid=show_grid, save=save, stat_mode=stat_mode)
return ax


def plot_scatter_from_results(savepath, epoch, stat, df):
if len(df) > 0:
if "_" in stat:
stat = stat.split("_")[0]
ax = plot_stat(df=df, savepath=savepath, epoch=epoch, stat=stat, line=False, save=True, samples=True, ylim=None)
return ax
else:
Expand Down
2 changes: 1 addition & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def forward(self, x):
x, y, model = x.to(device), y.to(device), model.to(device)

layers = [model.linear1, model.linear2]
stats = CheckLayerSat('regression/h{}'.format(h), save_to="csv", modules=layers, device=device)
stats = CheckLayerSat('regression/h{}'.format(h), save_to="plotcsv", modules=layers, device=device, stats=["lsat", "lsat_eval"])

loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
Expand Down
184 changes: 75 additions & 109 deletions example_deep.py
Original file line number Diff line number Diff line change
@@ -1,113 +1,79 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from tqdm import tqdm, trange

from delve import CheckLayerSat
from torch.cuda import is_available
from torch.nn import CrossEntropyLoss
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose
from torch.utils.data.dataloader import DataLoader
from torch.optim import Adam
from torchvision.models.vgg import vgg11

# setup compute device
from tqdm import tqdm

from delve.writers import CSVandPlottingWriter

if __name__ == "__main__":

device = "cuda:0" if is_available() else "cpu"

# Get some data
train_data = CIFAR10(root="./tmp", train=True,
download=True, transform=Compose([ToTensor()]))
test_data = CIFAR10(root="./tmp", train=False, download=True, transform=Compose([ToTensor()]))

train_loader = DataLoader(train_data, batch_size=1024,
shuffle=True, num_workers=6,
pin_memory=True)
test_loader = DataLoader(test_data, batch_size=1024,
shuffle=False, num_workers=6,
pin_memory=True)

# instantiate model
model = vgg11(num_classes=10).to(device)

# instantiate optimizer and loss
optimizer = Adam(params=model.parameters())
criterion = CrossEntropyLoss().to(device)

# initialize delve
tracker = CheckLayerSat("my_experiment", save_to="plotcsv", stats=["lsat"], modules=model, device=device)

# begin training
for epoch in range(10):
model.train()
for (images, labels) in tqdm(train_loader):
images, labels = images.to(device), labels.to(device)
prediction = model(images)
optimizer.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast():
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

batch_size = 128

train_set = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(
test_set, batch_size=batch_size, shuffle=False, num_workers=2)


class Net(nn.Module):
    """Small LeNet-style CNN for 32x32 RGB images (e.g. CIFAR-10).

    The width of the second fully-connected layer is configurable via
    ``h2`` so that networks of different capacity can be compared.
    """

    def __init__(self, h2):
        super(Net, self).__init__()
        # Two conv + max-pool stages: 3x32x32 -> 6x14x14 -> 16x5x5.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head; 16 * 5 * 5 matches the flattened conv output.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, h2)
        self.fc3 = nn.Linear(h2, 10)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        out = out.view(-1, 16 * 5 * 5)  # flatten conv features per sample
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)


if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.manual_seed(1)
if not os.path.exists('convNet'):
os.mkdir('convNet')

epochs = 5

for h2 in [8, 32, 128]: # compare various hidden layer sizes
net = resnet18(pretrained=False, num_classes=10)#Net(h2=h2) # instantiate network with hidden layer size `h2`

net.to(device)
logging_dir = 'convNet/simpson_h2-{}'.format(h2)
stats = CheckLayerSat(savefile=logging_dir, save_to=['plot', 'csv', 'npy'], modules=net, include_conv=True, stats=['dtrc', 'trc', 'cov', 'idim', 'lsat', 'det'], max_samples=1024,
verbose=True, writer_args={}, conv_method='channelwise', device='cpu', initial_epoch=5, interpolation_downsampling=4, interpolation_strategy='nearest')

#net = nn.DataParallel(net, device_ids=['cuda:0', 'cuda:1'])
print(net)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

#stats.write( "CIFAR10 ConvNet - Changing fc2 - size {}".format(h2)) # optional

for epoch in range(epochs):
if epoch == 2:
stats.stop()
if epoch == 3:
stats.resume()
running_loss = 0.0
step = 0
loader = tqdm(
train_loader, leave=True,
position=0) # track step progress and loss - optional
for i, data in enumerate(loader):
step = epoch * len(loader) + i
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()

optimizer.step()

running_loss += loss.data
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1,
running_loss / 2000))
running_loss = 0.0

# update the training progress display
loader.set_description(desc='[%d/%d, %5d] loss: %.3f' %
(epoch + 1, epochs, i + 1, loss.data))
# display layer saturation levels

stats.add_scalar('epoch', epoch) # optional
stats.add_scalar('loss', running_loss.cpu().numpy()) # optional
stats.add_saturations()

loader.write('\n')
loader.close()
stats.close()
loss.backward()
optimizer.step()

total = 0
test_loss = 0
correct = 0
model.eval()
for (images, labels) in tqdm(test_loader):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)
correct += torch.sum((predicted == labels)).item()
test_loss += loss.item()

# add some additional metrics we want to keep track of
tracker.add_scalar("accuracy", correct / total)
tracker.add_scalar("loss", test_loss / total)

# add saturation to the mix
tracker.add_saturations()

# close the tracker to finish training
tracker.close()

0 comments on commit 79f7fdc

Please sign in to comment.