From 69c9598edc7008859063617a3a81b32c52ab8562 Mon Sep 17 00:00:00 2001 From: Bilal Alsallakh Date: Mon, 15 May 2023 16:38:46 -0700 Subject: [PATCH] unsure the model is in eval() mode Avoids arbitrary weight changes / dropouts, which leads to discrepancy when executing twice. --- inference.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/inference.py b/inference.py index f018fe3..80721b6 100644 --- a/inference.py +++ b/inference.py @@ -27,14 +27,13 @@ def inference(img_path: Path, img_size: tuple[int, int], # Prepare model vit_pose = ViTPose(model_cfg) - ckpt = torch.load(ckpt_path) if 'state_dict' in ckpt: vit_pose.load_state_dict(ckpt['state_dict']) else: vit_pose.load_state_dict(ckpt) - vit_pose.to(device) + vit_pose.to(device).eval() print(f">>> Model loaded: {ckpt_path}") # Prepare input data @@ -92,4 +91,4 @@ def inference(img_path: Path, img_size: tuple[int, int], print(img_path) keypoints = inference(img_path=img_path, img_size=img_size, model_cfg=model_cfg, ckpt_path=CKPT_PATH, device=torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu'), - save_result=True) \ No newline at end of file + save_result=True)