diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index cbe18e3..ef06109 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -22,6 +22,17 @@ def hook(module, input, output): summary[m_key]["output_shape"] = [ [-1] + list(o.size())[1:] for o in output ] + ''' + Adding check for dictionaries in the output. Dictionary are faster + for lookup and makes for more readable code. + ''' + if isinstance(output, (dict)): + summary[m_key]["output_shape" ] = [ + [-1] + list(output[key].size())[1:] for key in + output.keys() + ] + print(summary[m_key]["output_shape"]) + else: summary[m_key]["output_shape"] = list(output.size()) summary[m_key]["output_shape"][0] = batch_size