diff --git a/src/transformers/models/lightglue/image_processing_lightglue.py b/src/transformers/models/lightglue/image_processing_lightglue.py
index 54cb70785397..47a871fcd2d7 100644
--- a/src/transformers/models/lightglue/image_processing_lightglue.py
+++ b/src/transformers/models/lightglue/image_processing_lightglue.py
@@ -361,13 +361,15 @@ def post_process_keypoint_matching(
             `list[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second
             image of the pair, the matching scores and the matching indices.
         """
-        if outputs.mask.shape[0] != len(target_sizes):
+
+        mask = outputs.mask
+        if mask.shape[0] != len(target_sizes):
             raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
         if not all(len(target_size) == 2 for target_size in target_sizes):
             raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
 
         if isinstance(target_sizes, list):
-            image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device)
+            image_pair_sizes = torch.as_tensor(target_sizes, device=mask.device)
         else:
             if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
                 raise ValueError(
@@ -375,27 +377,41 @@ def post_process_keypoint_matching(
                 )
             image_pair_sizes = target_sizes
 
-        keypoints = outputs.keypoints.clone()
-        keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
-        keypoints = keypoints.to(torch.int32)
+        keypoints = outputs.keypoints
+        # The multiplication below allocates a new tensor, so the explicit .clone() is not needed
+        scaled_keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
+        # .to(torch.int32) also returns a new tensor, so outputs.keypoints is never modified in place
+        scaled_keypoints = scaled_keypoints.to(torch.int32)
+
+        # Precompute batched slices once instead of re-slicing inside the loop
+        mask0_all = mask[:, 0] > 0
+        mask1_all = mask[:, 1] > 0
+        matches_all = outputs.matches[:, 0]
+        scores_all = outputs.matching_scores[:, 0]
+        keypoints0_all = scaled_keypoints[:, 0]
+        keypoints1_all = scaled_keypoints[:, 1]
 
         results = []
-        for mask_pair, keypoints_pair, matches, scores in zip(
-            outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0]
-        ):
-            mask0 = mask_pair[0] > 0
-            mask1 = mask_pair[1] > 0
-            keypoints0 = keypoints_pair[0][mask0]
-            keypoints1 = keypoints_pair[1][mask1]
-            matches0 = matches[mask0]
-            scores0 = scores[mask0]
+        for i in range(mask.shape[0]):
+            mask0 = mask0_all[i]
+            mask1 = mask1_all[i]
+            keypoints0 = keypoints0_all[i][mask0]
+            keypoints1 = keypoints1_all[i][mask1]
+            matches0 = matches_all[i][mask0]
+            scores0 = scores_all[i][mask0]
 
             # Filter out matches with low scores
             valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1)
 
             matched_keypoints0 = keypoints0[valid_matches]
-            matched_keypoints1 = keypoints1[matches0[valid_matches]]
-            matching_scores = scores0[valid_matches]
+            # Handle the case of no valid matches explicitly so the matched tensors keep well-defined shapes
+            matched_indices = matches0[valid_matches]
+            if matched_indices.numel() > 0:
+                matched_keypoints1 = keypoints1[matched_indices]
+                matching_scores = scores0[valid_matches]
+            else:
+                matched_keypoints1 = keypoints1.new_empty((0, keypoints1.shape[1]))
+                matching_scores = scores0.new_empty((0,))
 
             results.append(
                 {
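
For reference, a minimal usage sketch of the code path touched above, following the documented LightGlue workflow. The checkpoint name, image paths, and threshold value are illustrative assumptions and are not part of this diff:

from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import torch

# Assumed checkpoint name; any LightGlue checkpoint with a matching image processor should work the same way
processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint")

# One pair of images to match (placeholder paths)
images = [Image.open("image0.jpg"), Image.open("image1.jpg")]
inputs = processor(images, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One (height, width) per image in each pair; post_process_keypoint_matching rescales the normalized keypoints with these sizes
target_sizes = [[(image.height, image.width) for image in images]]
results = processor.post_process_keypoint_matching(outputs, target_sizes, threshold=0.2)
print(results[0]["keypoints0"].shape, results[0]["keypoints1"].shape, results[0]["matching_scores"].shape)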