diff --git a/examples/README.md b/examples/README.md index 178020f8..786d6503 100644 --- a/examples/README.md +++ b/examples/README.md @@ -6,7 +6,7 @@ The provided example training scripts have additional package requirements. ### ImageNet and Cifar-10 ``` -$ pip install torch-summary tqdm torchvision +$ pip install torchinfo==1.5.2 tqdm torchvision ``` ### Language Modeling diff --git a/examples/horovod_cifar10_resnet.py b/examples/horovod_cifar10_resnet.py index bb7e2df5..4b2e446e 100644 --- a/examples/horovod_cifar10_resnet.py +++ b/examples/horovod_cifar10_resnet.py @@ -12,7 +12,7 @@ import cnn_utils.engine as engine import cnn_utils.optimizers as optimizers -from torchsummary import summary +from torchinfo import summary from torch.utils.tensorboard import SummaryWriter from utils import save_checkpoint @@ -121,11 +121,11 @@ def main(): train_sampler, train_loader, _, val_loader = datasets.get_cifar(args) model = models.get_model(args.model) - if args.verbose: - summary(model, (3, 32, 32)) - device = 'cpu' if not args.cuda else 'cuda' model.to(device) + + if args.verbose: + summary(model, (args.batch_size, 3, 32, 32), device=device) os.makedirs(args.log_dir, exist_ok=True) args.checkpoint_format = os.path.join(args.log_dir, args.checkpoint_format) diff --git a/examples/horovod_imagenet_resnet.py b/examples/horovod_imagenet_resnet.py index cb7bd862..bed4fdf6 100644 --- a/examples/horovod_imagenet_resnet.py +++ b/examples/horovod_imagenet_resnet.py @@ -13,7 +13,6 @@ import cnn_utils.engine as engine import cnn_utils.optimizers as optimizers -from torchsummary import summary from torch.utils.tensorboard import SummaryWriter from utils import LabelSmoothLoss, save_checkpoint diff --git a/examples/torch_cifar10_resnet.py b/examples/torch_cifar10_resnet.py index 01060d81..937f6b28 100644 --- a/examples/torch_cifar10_resnet.py +++ b/examples/torch_cifar10_resnet.py @@ -12,7 +12,7 @@ import cnn_utils.engine as engine import cnn_utils.optimizers as optimizers -from torchsummary import summary +from torchinfo import summary from torch.utils.tensorboard import SummaryWriter from utils import save_checkpoint @@ -131,12 +131,12 @@ def main(): train_sampler, train_loader, _, val_loader = datasets.get_cifar(args) model = models.get_model(args.model) - if args.verbose: - summary(model, (3, 32, 32)) - device = 'cpu' if not args.cuda else 'cuda' model.to(device) + if args.verbose: + summary(model, (args.batch_size, 3, 32, 32), device=device) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])