50 changes: 34 additions & 16 deletions src/transformers/models/lightglue/image_processing_lightglue.py
@@ -361,41 +361,59 @@ def post_process_keypoint_matching(
             `list[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
             of the pair, the matching scores and the matching indices.
         """
-        if outputs.mask.shape[0] != len(target_sizes):
+
+        mask = outputs.mask
+        if mask.shape[0] != len(target_sizes):
             raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
         if not all(len(target_size) == 2 for target_size in target_sizes):
             raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

         if isinstance(target_sizes, list):
-            image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device)
+            image_pair_sizes = torch.as_tensor(target_sizes, device=mask.device)
         else:
             if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
                 raise ValueError(
                     "Each element of target_sizes must contain the size (h, w) of each image of the batch"
                 )
             image_pair_sizes = target_sizes

-        keypoints = outputs.keypoints.clone()
-        keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
-        keypoints = keypoints.to(torch.int32)
+        keypoints = outputs.keypoints
+        # The multiplication below already returns a new tensor, so the previous explicit .clone() is unnecessary
+        scaled_keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
+        # .to(torch.int32) also creates a new tensor, so no extra .clone() is needed
+        scaled_keypoints = scaled_keypoints.to(torch.int32)

+        # Precompute slices for efficiency
+        mask0_all = mask[:, 0] > 0
+        mask1_all = mask[:, 1] > 0
+        matches_all = outputs.matches[:, 0]
+        scores_all = outputs.matching_scores[:, 0]
+        keypoints0_all = scaled_keypoints[:, 0]
+        keypoints1_all = scaled_keypoints[:, 1]
+
         results = []
-        for mask_pair, keypoints_pair, matches, scores in zip(
-            outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0]
-        ):
-            mask0 = mask_pair[0] > 0
-            mask1 = mask_pair[1] > 0
-            keypoints0 = keypoints_pair[0][mask0]
-            keypoints1 = keypoints_pair[1][mask1]
-            matches0 = matches[mask0]
-            scores0 = scores[mask0]
+        for i in range(mask.shape[0]):
+            mask0 = mask0_all[i]
+            mask1 = mask1_all[i]
+            keypoints0 = keypoints0_all[i][mask0]
+            keypoints1 = keypoints1_all[i][mask1]
+            matches0 = matches_all[i][mask0]
+            scores0 = scores_all[i][mask0]

             # Filter out matches with low scores
             valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1)

             matched_keypoints0 = keypoints0[valid_matches]
-            matched_keypoints1 = keypoints1[matches0[valid_matches]]
-            matching_scores = scores0[valid_matches]
+            # Guard against the no-valid-matches case so keypoints1 is never indexed with an empty selection
+            matched_indices = matches0[valid_matches]
+            if matched_indices.numel() > 0:
+                matched_keypoints1 = keypoints1[matched_indices]
+                matching_scores = scores0[valid_matches]
+            else:
+                matched_keypoints1 = keypoints1.new_empty((0, keypoints1.shape[1]))
+                matching_scores = scores0.new_empty((0,))

             results.append(
                 {
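For reference, a rough usage sketch of the method this diff modifies, post_process_keypoint_matching, as it is typically called on LightGlue outputs. This is not part of the PR; the checkpoint name ETH-CVG/lightglue_superpoint, the image paths, and the 0.2 threshold are assumptions used only for illustration.

# Rough usage sketch (not part of this diff); checkpoint name and image paths are placeholders.
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

image0 = Image.open("img0.jpg")
image1 = Image.open("img1.jpg")

processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint")

inputs = processor([image0, image1], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One (height, width) tuple per image in each image pair
target_sizes = [[(image0.height, image0.width), (image1.height, image1.width)]]
matches = processor.post_process_keypoint_matching(outputs, target_sizes, threshold=0.2)
# Each entry holds the matched keypoints of both images and their matching scores;
# with the guard added in this diff, the tensors are simply empty when nothing clears the threshold.
print(matches[0].keys())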