Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions src/transformers/models/lightglue/image_processing_lightglue.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,36 @@ def validate_and_format_image_pairs(images: ImageInput):

def _is_valid_image(image):
"""images is a PIL Image or a 3D array."""
return is_pil_image(image) or (
is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
)
if is_pil_image(image):
return True
if not is_valid_image(image):
return False
return get_image_type(image) != ImageType.PIL and len(image.shape) == 3

if isinstance(images, list):
if len(images) == 2 and all((_is_valid_image(image)) for image in images):
return images
if all(
isinstance(image_pair, list)
and len(image_pair) == 2
and all(_is_valid_image(image) for image in image_pair)
for image_pair in images
):
return [image for image_pair in images for image in image_pair]
n_images = len(images)
if n_images == 2:
# Evaluate _is_valid_image for both images only once, no generator/all overhead.
a, b = images[0], images[1]
if _is_valid_image(a) and _is_valid_image(b):
return images
# Determine if this is a list of pairs (and if all pairs are valid)
# Avoid nested all/generators for performance.
is_pairs = True
for image_pair in images:
if not (isinstance(image_pair, list) and len(image_pair) == 2):
is_pairs = False
break
a, b = image_pair[0], image_pair[1]
if not (_is_valid_image(a) and _is_valid_image(b)):
is_pairs = False
break
if is_pairs:
# Use in-place extension to a flat list for better performance and lower memory than list comprehension
result = []
for image_pair in images:
result.extend(image_pair)
return result
raise ValueError(error_message)


Expand Down