From 5f9a031fd2be4f727835b4be61c6b5e7b2e65222 Mon Sep 17 00:00:00 2001 From: T B <58150584+tberckmann@users.noreply.github.com> Date: Mon, 23 Dec 2019 12:00:45 -0500 Subject: [PATCH] Allow user to specify dtype --- torchsummary/torchsummary.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index cbe18e3..bb33fe1 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,7 +6,7 @@ import numpy as np -def summary(model, input_size, batch_size=-1, device="cuda"): +def summary(model, input_size, batch_size=-1, device="cuda", force_dtype=None): def register_hook(module): @@ -51,6 +51,8 @@ def hook(module, input, output): dtype = torch.cuda.FloatTensor else: dtype = torch.FloatTensor + if force_dtype: + dtype = force_dtype # multiple inputs to the network if isinstance(input_size, tuple):