diff --git a/torchsummary/tests/test_models/test_model.py b/torchsummary/tests/test_models/test_model.py
index e4bce9d..9d4e20a 100644
--- a/torchsummary/tests/test_models/test_model.py
+++ b/torchsummary/tests/test_models/test_model.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
 
 class SingleInputNet(nn.Module):
     def __init__(self):
@@ -19,6 +21,7 @@ def forward(self, x):
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
 
+
 class MultipleInputNet(nn.Module):
     def __init__(self):
         super(MultipleInputNet, self).__init__()
@@ -36,6 +39,7 @@ def forward(self, x1, x2):
         x = torch.cat((x1, x2), 0)
         return F.log_softmax(x, dim=1)
 
+
 class MultipleInputNetDifferentDtypes(nn.Module):
     def __init__(self):
         super(MultipleInputNetDifferentDtypes, self).__init__()
@@ -54,3 +58,45 @@ def forward(self, x1, x2):
         # set x2 to FloatTensor
         x = torch.cat((x1, x2), 0)
         return F.log_softmax(x, dim=1)
+
+
+class NestedNet(nn.Module):
+    def __init__(self):
+        super(NestedNet, self).__init__()
+        self.conv_block1 = ConvBlock(1, 10, 5)
+        self.conv_block2 = ConvBlock(10, 20, 5)
+        self.conv_drop = nn.Dropout2d(0.3)
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
+
+    def forward(self, x):
+        x = F.relu(self.conv_block1(x))
+        x = F.relu((self.conv_drop(self.conv_block2(x))))
+        x = x.view(-1, 320)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size):
+        super(ConvBlock, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.pool = nn.MaxPool2d(2, stride=2)
+
+    def forward(self, x):
+        x = F.relu(self.conv(x))
+        x = self.bn(x)
+        x = self.pool(x)
+        return x
+
+
+class CustomModule(nn.Module):
+    def __init__(self):
+        super(CustomModule, self).__init__()
+        weight_tensor = torch.rand(50, 50)
+        self.W = Parameter(weight_tensor, requires_grad=True)
+
+    def forward(self, x):
+        return torch.einsum("bij,jk->bik", x, self.W)
diff --git a/torchsummary/tests/unit_tests/torchsummary_test.py b/torchsummary/tests/unit_tests/torchsummary_test.py
index ec4b33e..7cdb465 100644
--- a/torchsummary/tests/unit_tests/torchsummary_test.py
+++ b/torchsummary/tests/unit_tests/torchsummary_test.py
@@ -1,11 +1,14 @@
 import unittest
 from torchsummary import summary, summary_string
-from torchsummary.tests.test_models.test_model import SingleInputNet, MultipleInputNet, MultipleInputNetDifferentDtypes
+from torchsummary.torchsummary import _build_summary_dict, _build_summary_string
+from torchsummary.tests.test_models.test_model import SingleInputNet, MultipleInputNet, \
+    MultipleInputNetDifferentDtypes, NestedNet, CustomModule
 import torch
 
 gpu_if_available = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-class torchsummaryTests(unittest.TestCase):
+
+class TorchSummaryTests(unittest.TestCase):
     def test_single_input(self):
         model = SingleInputNet()
         input = (1, 28, 28)
@@ -48,8 +51,34 @@ def test_multiple_input_types(self):
         self.assertEqual(total_params, 31120)
         self.assertEqual(trainable_params, 31120)
 
+    def test_recursive(self):
+        model = NestedNet()
+        input = (1, 28, 28)
+        summary = _build_summary_dict(model, [input], device='cpu')
+        summary_str, (total_params, trainable_params) = _build_summary_string(summary, [input])
+
+        self.assertListEqual(list(summary.keys()), ['Conv2d-1', 'BatchNorm2d-2', 'MaxPool2d-3', 'ConvBlock-4',
+                                                    'Conv2d-5', 'BatchNorm2d-6', 'MaxPool2d-7', 'ConvBlock-8',
+                                                    'Dropout2d-9', 'Linear-10', 'Linear-11', 'NestedNet-12'])
+        self.assertEqual(total_params, 21900)
+        self.assertEqual(trainable_params, 21900)
+
+        summary = _build_summary_dict(model, [input], device='cpu', recurse=False)
+        summary_str, (total_params, trainable_params) = _build_summary_string(summary, [input])
+        self.assertListEqual(list(summary.keys()), ['ConvBlock-1', 'ConvBlock-2', 'Dropout2d-3', 'Linear-4',
+                                                    'Linear-5', 'NestedNet-6'])
+        self.assertEqual(total_params, 21900)
+        self.assertEqual(trainable_params, 21900)
+
+    def test_custom_module(self):
+        model = CustomModule()
+        input = (1, 50)
+        total_params, trainable_params = summary(model, input, device='cpu')
+        self.assertEqual(total_params, 2500)
+        self.assertEqual(trainable_params, 2500)
+
 
-class torchsummarystringTests(unittest.TestCase):
+class TorchSummaryStringTests(unittest.TestCase):
     def test_single_input(self):
         model = SingleInputNet()
         input = (1, 28, 28)
diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py
index 1ed065f..ac95efe 100644
--- a/torchsummary/torchsummary.py
+++ b/torchsummary/torchsummary.py
@@ -1,25 +1,31 @@
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 
 from collections import OrderedDict
 import numpy as np
 
 
-def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
-    result, params_info = summary_string(
-        model, input_size, batch_size, device, dtypes)
+def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None, recurse=True):
+    result, params_info = summary_string(model, input_size, batch_size, device, dtypes, recurse)
     print(result)
 
     return params_info
 
 
-def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
+def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None, recurse=True):
+    # multiple inputs to the network
+    if isinstance(input_size, tuple):
+        input_size = [input_size]
+
+    summary = _build_summary_dict(
+        model, input_size, batch_size, device, dtypes, recurse)
+    return _build_summary_string(summary, input_size, batch_size)
+
+
+def _build_summary_dict(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None, recurse=True):
     if dtypes == None:
         dtypes = [torch.FloatTensor]*len(input_size)
 
-    summary_str = ''
-
     def register_hook(module):
         def hook(module, input, output):
             class_name = str(module.__class__).split(".")[-1].split("'")[0]
@@ -37,13 +43,14 @@ def hook(module, input, output):
                 summary[m_key]["output_shape"] = list(output.size())
                 summary[m_key]["output_shape"][0] = batch_size
 
-            params = 0
-            if hasattr(module, "weight") and hasattr(module.weight, "size"):
-                params += torch.prod(torch.LongTensor(list(module.weight.size())))
-                summary[m_key]["trainable"] = module.weight.requires_grad
-            if hasattr(module, "bias") and hasattr(module.bias, "size"):
-                params += torch.prod(torch.LongTensor(list(module.bias.size())))
-            summary[m_key]["nb_params"] = params
+            nb_params = 0
+            trainable_params = 0
+            for name, p in module.named_parameters():
+                params = torch.numel(p)
+                nb_params += params
+                trainable_params += params if p.requires_grad else 0
+            summary[m_key]["nb_params"] = nb_params
+            summary[m_key]["trainable"] = trainable_params
 
         if (
             not isinstance(module, nn.Sequential)
@@ -51,10 +58,6 @@ def hook(module, input, output):
         ):
             hooks.append(module.register_forward_hook(hook))
 
-    # multiple inputs to the network
-    if isinstance(input_size, tuple):
-        input_size = [input_size]
-
     # batch_size of 2 for batchnorm
     x = [torch.rand(2, *in_size).type(dtype).to(device=device)
          for in_size, dtype in zip(input_size, dtypes)]
@@ -64,7 +67,11 @@ def hook(module, input, output):
     hooks = []
 
     # register hook
-    model.apply(register_hook)
+    if recurse:
+        model.apply(register_hook)
+    else:
+        [register_hook(m) for m in model.children()]
+        register_hook(model)
 
     # make a forward pass
     # print(x.shape)
@@ -74,6 +81,12 @@ def hook(module, input, output):
     for h in hooks:
         h.remove()
 
+    return summary
+
+
+def _build_summary_string(summary, input_size, batch_size=-1):
+
+    summary_str = ''
     summary_str += "----------------------------------------------------------------" + "\n"
     line_new = "{:>20} {:>25} {:>15}".format(
         "Layer (type)", "Output Shape", "Param #")
@@ -89,12 +102,11 @@ def hook(module, input, output):
             str(summary[layer]["output_shape"]),
             "{0:,}".format(summary[layer]["nb_params"]),
         )
-        total_params += summary[layer]["nb_params"]
+        total_params = summary[layer]["nb_params"]
 
         total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer]:
-            if summary[layer]["trainable"] == True:
-                trainable_params += summary[layer]["nb_params"]
+            trainable_params = summary[layer]["trainable"]
         summary_str += line_new + "\n"
 
     # assume 4 bytes/number (float on cuda).
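For reference, a minimal usage sketch of the new recurse flag introduced by this patch (assuming the patched torchsummary is importable; NestedNet is the test model added above, and device='cpu' is passed the same way the tests do):

import torch
from torchsummary import summary
from torchsummary.tests.test_models.test_model import NestedNet

model = NestedNet()

# Default (recurse=True): every submodule is hooked, so nested layers such as
# Conv2d and BatchNorm2d inside each ConvBlock get their own rows.
total_params, trainable_params = summary(model, (1, 28, 28), device='cpu')

# recurse=False: only the model's direct children plus the model itself are
# hooked, so each ConvBlock collapses to a single row; per-row counts still
# come from module.named_parameters(), so the totals are unchanged.
total_params, trainable_params = summary(model, (1, 28, 28), device='cpu', recurse=False)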