Skip to content

Commit

Permalink
Update torch-summary to torchinfo
Browse files Browse the repository at this point in the history
torch-summary changed the package name to torchinfo so updating where relevant.
Should also reduce confusion with the older package named torchsummary.


Former-commit-id: 86a8a82
  • Loading branch information
Greg Pauloski committed Jul 19, 2021
1 parent c2cf215 commit 87aba67
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The provided example training scripts have additional package requirements.

### ImageNet and Cifar-10
```
$ pip install torch-summary tqdm torchvision
$ pip install torchinfo==1.5.2 tqdm torchvision
```

### Language Modeling
Expand Down
8 changes: 4 additions & 4 deletions examples/horovod_cifar10_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import cnn_utils.engine as engine
import cnn_utils.optimizers as optimizers

from torchsummary import summary
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from utils import save_checkpoint

Expand Down Expand Up @@ -121,11 +121,11 @@ def main():
train_sampler, train_loader, _, val_loader = datasets.get_cifar(args)
model = models.get_model(args.model)

if args.verbose:
summary(model, (3, 32, 32))

device = 'cpu' if not args.cuda else 'cuda'
model.to(device)

if args.verbose:
summary(model, (args.batch_size, 3, 32, 32), device=device)

os.makedirs(args.log_dir, exist_ok=True)
args.checkpoint_format = os.path.join(args.log_dir, args.checkpoint_format)
Expand Down
1 change: 0 additions & 1 deletion examples/horovod_imagenet_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import cnn_utils.engine as engine
import cnn_utils.optimizers as optimizers

from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
from utils import LabelSmoothLoss, save_checkpoint

Expand Down
8 changes: 4 additions & 4 deletions examples/torch_cifar10_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import cnn_utils.engine as engine
import cnn_utils.optimizers as optimizers

from torchsummary import summary
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from utils import save_checkpoint

Expand Down Expand Up @@ -131,12 +131,12 @@ def main():
train_sampler, train_loader, _, val_loader = datasets.get_cifar(args)
model = models.get_model(args.model)

if args.verbose:
summary(model, (3, 32, 32))

device = 'cpu' if not args.cuda else 'cuda'
model.to(device)

if args.verbose:
summary(model, (args.batch_size, 3, 32, 32), device=device)

model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[args.local_rank])

Expand Down

0 comments on commit 87aba67

Please sign in to comment.