-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updated a demo script for deep networks.
Added the ability to plot eval and train values distinctly.
- Loading branch information
Showing
5 changed files
with
99 additions
and
124 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = '0.1.43' | ||
__version__ = '0.1.44' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,113 +1,79 @@ | ||
import os | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
from torchvision.models import resnet18 | ||
from tqdm import tqdm, trange | ||
|
||
from delve import CheckLayerSat | ||
from torch.cuda import is_available | ||
from torch.nn import CrossEntropyLoss | ||
from torchvision.datasets import CIFAR10 | ||
from torchvision.transforms import ToTensor, Compose | ||
from torch.utils.data.dataloader import DataLoader | ||
from torch.optim import Adam | ||
from torchvision.models.vgg import vgg11 | ||
|
||
# setup compute device | ||
from tqdm import tqdm | ||
|
||
from delve.writers import CSVandPlottingWriter | ||
|
||
if __name__ == "__main__": | ||
|
||
device = "cuda:0" if is_available() else "cpu" | ||
|
||
# Get some data | ||
train_data = CIFAR10(root="./tmp", train=True, | ||
download=True, transform=Compose([ToTensor()])) | ||
test_data = CIFAR10(root="./tmp", train=False, download=True, transform=Compose([ToTensor()])) | ||
|
||
train_loader = DataLoader(train_data, batch_size=1024, | ||
shuffle=True, num_workers=6, | ||
pin_memory=True) | ||
test_loader = DataLoader(test_data, batch_size=1024, | ||
shuffle=False, num_workers=6, | ||
pin_memory=True) | ||
|
||
# instantiate model | ||
model = vgg11(num_classes=10).to(device) | ||
|
||
# instantiate optimizer and loss | ||
optimizer = Adam(params=model.parameters()) | ||
criterion = CrossEntropyLoss().to(device) | ||
|
||
# initialize delve | ||
tracker = CheckLayerSat("my_experiment", save_to="plotcsv", stats=["lsat"], modules=model, device=device) | ||
|
||
# begin training | ||
for epoch in range(10): | ||
model.train() | ||
for (images, labels) in tqdm(train_loader): | ||
images, labels = images.to(device), labels.to(device) | ||
prediction = model(images) | ||
optimizer.zero_grad(set_to_none=True) | ||
with torch.cuda.amp.autocast(): | ||
outputs = model(images) | ||
_, predicted = torch.max(outputs.data, 1) | ||
|
||
transform = transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
]) | ||
|
||
batch_size = 128 | ||
|
||
train_set = torchvision.datasets.CIFAR10( | ||
root='./data', train=True, download=True, transform=transform) | ||
train_loader = torch.utils.data.DataLoader( | ||
train_set, batch_size=batch_size, shuffle=True, num_workers=2) | ||
|
||
test_set = torchvision.datasets.CIFAR10( | ||
root='./data', train=False, download=True, transform=transform) | ||
test_loader = torch.utils.data.DataLoader( | ||
test_set, batch_size=batch_size, shuffle=False, num_workers=2) | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self, h2): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | ||
self.fc2 = nn.Linear(120, h2) | ||
self.fc3 = nn.Linear(h2, 10) | ||
|
||
def forward(self, x): | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 5 * 5) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
x = self.fc3(x) | ||
return x | ||
|
||
|
||
if __name__ == '__main__': | ||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
torch.manual_seed(1) | ||
if not os.path.exists('convNet'): | ||
os.mkdir('convNet') | ||
|
||
epochs = 5 | ||
|
||
for h2 in [8, 32, 128]: # compare various hidden layer sizes | ||
net = resnet18(pretrained=False, num_classes=10)#Net(h2=h2) # instantiate network with hidden layer size `h2` | ||
|
||
net.to(device) | ||
logging_dir = 'convNet/simpson_h2-{}'.format(h2) | ||
stats = CheckLayerSat(savefile=logging_dir, save_to=['plot', 'csv', 'npy'], modules=net, include_conv=True, stats=['dtrc', 'trc', 'cov', 'idim', 'lsat', 'det'], max_samples=1024, | ||
verbose=True, writer_args={}, conv_method='channelwise', device='cpu', initial_epoch=5, interpolation_downsampling=4, interpolation_strategy='nearest') | ||
|
||
#net = nn.DataParallel(net, device_ids=['cuda:0', 'cuda:1']) | ||
print(net) | ||
criterion = nn.CrossEntropyLoss() | ||
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) | ||
|
||
#stats.write( "CIFAR10 ConvNet - Changing fc2 - size {}".format(h2)) # optional | ||
|
||
for epoch in range(epochs): | ||
if epoch == 2: | ||
stats.stop() | ||
if epoch == 3: | ||
stats.resume() | ||
running_loss = 0.0 | ||
step = 0 | ||
loader = tqdm( | ||
train_loader, leave=True, | ||
position=0) # track step progress and loss - optional | ||
for i, data in enumerate(loader): | ||
step = epoch * len(loader) + i | ||
inputs, labels = data | ||
inputs, labels = inputs.to(device), labels.to(device) | ||
optimizer.zero_grad() | ||
outputs = net(inputs) | ||
loss = criterion(outputs, labels) | ||
loss.backward() | ||
|
||
optimizer.step() | ||
|
||
running_loss += loss.data | ||
if i % 2000 == 1999: # print every 2000 mini-batches | ||
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, | ||
running_loss / 2000)) | ||
running_loss = 0.0 | ||
|
||
# update the training progress display | ||
loader.set_description(desc='[%d/%d, %5d] loss: %.3f' % | ||
(epoch + 1, epochs, i + 1, loss.data)) | ||
# display layer saturation levels | ||
|
||
stats.add_scalar('epoch', epoch) # optional | ||
stats.add_scalar('loss', running_loss.cpu().numpy()) # optional | ||
stats.add_saturations() | ||
|
||
loader.write('\n') | ||
loader.close() | ||
stats.close() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
total = 0 | ||
test_loss = 0 | ||
correct = 0 | ||
model.eval() | ||
for (images, labels) in tqdm(test_loader): | ||
images, labels = images.to(device), labels.to(device) | ||
outputs = model(images) | ||
loss = criterion(outputs, labels) | ||
_, predicted = torch.max(outputs.data, 1) | ||
|
||
total += labels.size(0) | ||
correct += torch.sum((predicted == labels)).item() | ||
test_loss += loss.item() | ||
|
||
# add some additional metrics we want to keep track of | ||
tracker.add_scalar("accuracy", correct / total) | ||
tracker.add_scalar("loss", test_loss / total) | ||
|
||
# add saturation to the mix | ||
tracker.add_saturations() | ||
|
||
# close the tracker to finish training | ||
tracker.close() |