Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Nov 22, 2025

📄 28% (0.28x) speedup for validate_and_format_image_pairs in src/transformers/models/lightglue/image_processing_lightglue.py

⏱️ Runtime : 2.58 milliseconds 2.02 milliseconds (best of 57 runs)

📝 Explanation and details

The optimized code achieves a 27% speedup by eliminating costly Python overhead from nested all() calls and generator expressions, which were the primary bottleneck consuming 91.4% of execution time.

Key optimizations:

  1. Short-circuit validation in _is_valid_image: The function now returns early for PIL images and invalid images, avoiding unnecessary get_image_type() calls and shape checks.

  2. Replaced nested all() with explicit loops: The original code used all(isinstance(image_pair, list) and len(image_pair) == 2 and all(_is_valid_image(image) for image in image_pair) for image_pair in images) which creates multiple generator objects and has significant function call overhead. The optimized version uses a simple for-loop with early breaking, reducing the overhead by directly checking conditions.

  3. Optimized list flattening: Replaced the list comprehension [image for image_pair in images for image in image_pair] with in-place list.extend() calls, which avoids creating intermediate lists and reduces memory allocations.

  4. Direct indexing for pairs: For the simple case of exactly 2 images, the code now directly accesses images[0] and images[1] instead of using generators with all().

Impact on workloads: Since this function is called in preprocess() and visualize_keypoint_matching() - both likely to be in hot paths for image processing pipelines - this optimization will provide meaningful performance improvements, especially for:

  • Large batches of image pairs (55.5% faster for 500 pairs)
  • Error cases with early validation failures (up to 74% faster)
  • Mixed PIL/numpy image types (25% faster)

The optimization is particularly effective for the transformer's LightGlue model preprocessing, where image validation happens before every inference or visualization operation.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 37 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import numpy as np

# imports
import pytest
from PIL import Image

# function to test
from transformers.image_utils import ImageType
from transformers.models.lightglue.image_processing_lightglue import validate_and_format_image_pairs


def is_pil_image(image):
    return isinstance(image, Image.Image)


def is_valid_image(image):
    # Accepts PIL images or 3D numpy arrays with shape (H, W, C)
    if is_pil_image(image):
        return True
    if isinstance(image, np.ndarray) and image.ndim == 3:
        return True
    return False


def get_image_type(image):
    if is_pil_image(image):
        return ImageType.PIL
    elif isinstance(image, np.ndarray):
        return ImageType.NUMPY
    else:
        return None


# unit tests

# ----------- Basic Test Cases ------------


def create_pil_image(color=(255, 0, 0)):
    # Create a simple 10x10 PIL image
    img = Image.new("RGB", (10, 10), color)
    return img


def create_np_image(shape=(10, 10, 3), value=0):
    # Create a simple 3D numpy array image
    arr = np.full(shape, value, dtype=np.uint8)
    return arr


def test_basic_pair_of_pil_images():
    # Two PIL images
    img1 = create_pil_image((255, 0, 0))
    img2 = create_pil_image((0, 255, 0))
    codeflash_output = validate_and_format_image_pairs([img1, img2])
    result = codeflash_output  # 3.43μs -> 2.44μs (40.5% faster)


def test_basic_pair_of_np_images():
    # Two 3D numpy arrays
    img1 = create_np_image()
    img2 = create_np_image(value=128)
    codeflash_output = validate_and_format_image_pairs([img1, img2])
    result = codeflash_output  # 8.96μs -> 7.87μs (13.8% faster)


def test_basic_list_of_pairs_pil_images():
    # List of pairs of PIL images
    img_pairs = [
        [create_pil_image((255, 0, 0)), create_pil_image((0, 255, 0))],
        [create_pil_image((0, 0, 255)), create_pil_image((255, 255, 0))],
    ]
    codeflash_output = validate_and_format_image_pairs(img_pairs)
    result = codeflash_output  # 6.79μs -> 5.29μs (28.4% faster)


def test_basic_list_of_pairs_np_images():
    # List of pairs of numpy images
    img_pairs = [
        [create_np_image(), create_np_image(value=128)],
        [create_np_image(value=255), create_np_image(value=64)],
    ]
    codeflash_output = validate_and_format_image_pairs(img_pairs)
    result = codeflash_output  # 12.7μs -> 10.8μs (17.2% faster)


def test_basic_mixed_pil_and_np_images_pair():
    # Pair of PIL and numpy image (should be valid)
    img1 = create_pil_image()
    img2 = create_np_image()
    codeflash_output = validate_and_format_image_pairs([img1, img2])
    result = codeflash_output  # 6.02μs -> 5.42μs (10.9% faster)


def test_basic_mixed_list_of_pairs():
    # List of pairs, each pair is PIL and numpy image
    img_pairs = [[create_pil_image(), create_np_image()], [create_np_image(), create_pil_image()]]
    codeflash_output = validate_and_format_image_pairs(img_pairs)
    result = codeflash_output  # 10.4μs -> 8.84μs (17.7% faster)


# ----------- Edge Test Cases ------------


def test_edge_single_image():
    # Single image should raise ValueError
    img = create_pil_image()
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([img])  # 2.82μs -> 1.63μs (72.4% faster)


def test_edge_list_of_single_image_pairs():
    # List of single images (not pairs) should raise ValueError
    img_list = [[create_pil_image()]]
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(img_list)  # 2.55μs -> 1.62μs (57.3% faster)


def test_edge_pair_with_invalid_type():
    # Pair with invalid type (e.g. string)
    img1 = create_pil_image()
    img2 = "not an image"
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([img1, img2])  # 6.91μs -> 5.65μs (22.4% faster)


def test_edge_list_of_pairs_with_invalid_type():
    # List of pairs, one pair contains invalid type
    img_pairs = [[create_pil_image(), create_pil_image()], [create_np_image(), "not an image"]]
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(img_pairs)  # 11.0μs -> 9.02μs (22.4% faster)


def test_edge_pair_of_2d_np_arrays():
    # Pair of 2D numpy arrays (should be invalid)
    img1 = np.zeros((10, 10), dtype=np.uint8)
    img2 = np.ones((10, 10), dtype=np.uint8)
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([img1, img2])  # 7.14μs -> 6.18μs (15.6% faster)


def test_edge_pair_of_4d_np_arrays():
    # Pair of 4D numpy arrays (should be invalid)
    img1 = np.zeros((10, 10, 3, 1), dtype=np.uint8)
    img2 = np.ones((10, 10, 3, 1), dtype=np.uint8)
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([img1, img2])  # 6.80μs -> 5.76μs (18.0% faster)


def test_edge_list_of_pairs_with_2d_and_3d_np_arrays():
    # List of pairs, one pair contains a 2D array
    img_pairs = [[create_np_image(), create_np_image()], [np.zeros((10, 10), dtype=np.uint8), create_np_image()]]
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(img_pairs)  # 11.4μs -> 9.38μs (21.8% faster)


def test_edge_non_list_input():
    # Non-list input should raise ValueError
    img = create_pil_image()
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(img)  # 1.13μs -> 1.05μs (7.32% faster)


def test_edge_tuple_input():
    # Tuple input should raise ValueError (only lists are accepted)
    img1 = create_pil_image()
    img2 = create_pil_image()
    with pytest.raises(ValueError):
        validate_and_format_image_pairs((img1, img2))  # 1.16μs -> 1.03μs (12.8% faster)


def test_edge_list_of_pairs_with_extra_element():
    # List of pairs, one pair has three images
    img_pairs = [
        [create_pil_image(), create_pil_image()],
        [create_pil_image(), create_pil_image(), create_pil_image()],
    ]
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(img_pairs)  # 8.14μs -> 6.66μs (22.1% faster)


# ----------- Large Scale Test Cases ------------


def test_large_list_of_pairs_pil_images():
    # Large list of pairs of PIL images
    img_pairs = [[create_pil_image(), create_pil_image()] for _ in range(500)]
    codeflash_output = validate_and_format_image_pairs(img_pairs)
    result = codeflash_output  # 265μs -> 170μs (55.5% faster)
    for img in result:
        pass


def test_large_list_of_pairs_np_images():
    # Large list of pairs of numpy images
    img_pairs = [[create_np_image(), create_np_image()] for _ in range(500)]
    codeflash_output = validate_and_format_image_pairs(img_pairs)
    result = codeflash_output  # 1.13ms -> 936μs (20.6% faster)
    for img in result:
        pass


def test_large_list_of_mixed_pairs():
    # Large list of mixed pairs (PIL and numpy)
    img_pairs = []
    for i in range(500):
        if i % 2 == 0:
            img_pairs.append([create_pil_image(), create_np_image()])
        else:
            img_pairs.append([create_np_image(), create_pil_image()])
    codeflash_output = validate_and_format_image_pairs(img_pairs)
    result = codeflash_output  # 727μs -> 579μs (25.4% faster)
    for i, img in enumerate(result):
        pass


def test_large_pair_of_np_images():
    # Large pair of numpy images (each image is large, but only two images)
    img1 = create_np_image(shape=(512, 512, 3), value=10)
    img2 = create_np_image(shape=(512, 512, 3), value=20)
    codeflash_output = validate_and_format_image_pairs([img1, img2])
    result = codeflash_output  # 10.8μs -> 9.88μs (9.33% faster)


def test_large_list_of_pairs_with_one_invalid():
    # Large list of pairs, one pair is invalid
    img_pairs = [[create_pil_image(), create_pil_image()] for _ in range(499)]
    img_pairs.append(["not an image", create_pil_image()])
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(img_pairs)  # 249μs -> 161μs (54.7% faster)


# ----------- Mutation Sensitivity Test ------------


def test_mutation_sensitivity_wrong_return_flatten():
    # If function returns flattened list for pair input, should fail
    img1 = create_pil_image()
    img2 = create_pil_image()
    codeflash_output = validate_and_format_image_pairs([img1, img2])
    result = codeflash_output  # 3.38μs -> 2.52μs (34.0% faster)


def test_mutation_sensitivity_wrong_accept_tuple():
    # If function accepts tuple as valid input, should fail
    img1 = create_pil_image()
    img2 = create_pil_image()
    with pytest.raises(ValueError):
        validate_and_format_image_pairs((img1, img2))  # 1.19μs -> 1.12μs (6.43% faster)


def test_mutation_sensitivity_wrong_accept_2d_array():
    # If function accepts 2D array, should fail
    img1 = np.zeros((10, 10), dtype=np.uint8)
    img2 = np.ones((10, 10), dtype=np.uint8)
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([img1, img2])  # 8.68μs -> 7.40μs (17.3% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
# imports
import pytest

from transformers.models.lightglue.image_processing_lightglue import validate_and_format_image_pairs


# Mock PIL Image for testing
class MockPILImage:
    def __init__(self, shape=(224, 224, 3)):
        self._shape = shape

    def __array__(self):
        return self

    @property
    def shape(self):
        return self._shape


# Mock 3D array (e.g., numpy-like)
class Mock3DArray:
    def __init__(self, shape=(224, 224, 3)):
        self.shape = shape


# Simulate ImageType enum
class ImageType:
    PIL = "PIL"
    ARRAY = "ARRAY"


# is_pil_image: True if instance of MockPILImage
def is_pil_image(image):
    return isinstance(image, MockPILImage)


# is_valid_image: True if has .shape and shape is tuple of length 3
def is_valid_image(image):
    return hasattr(image, "shape") and isinstance(image.shape, tuple) and len(image.shape) == 3


# get_image_type: PIL for MockPILImage, ARRAY for Mock3DArray
def get_image_type(image):
    if isinstance(image, MockPILImage):
        return ImageType.PIL
    elif isinstance(image, Mock3DArray):
        return ImageType.ARRAY
    else:
        return None


# ------------------- UNIT TESTS -------------------

# BASIC TEST CASES


def test_single_image_raises():
    # Single image should raise
    img = MockPILImage()
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([img])  # 2.87μs -> 1.65μs (74.3% faster)


def test_pair_with_one_invalid_image_raises():
    # One valid, one invalid image
    img = MockPILImage()
    not_img = object()
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([img, not_img])  # 7.40μs -> 6.05μs (22.4% faster)


def test_pair_with_invalid_shape_raises():
    # 3D array with wrong shape (not 3D)
    arr = Mock3DArray(shape=(224, 224))  # Only 2D
    arr2 = Mock3DArray()
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([arr, arr2])  # 5.55μs -> 4.17μs (33.0% faster)


def test_nested_pair_with_invalid_image_raises():
    # List of pairs, one pair is invalid
    pairs = [[MockPILImage(), MockPILImage()], [Mock3DArray(), object()]]
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(pairs)  # 6.58μs -> 5.14μs (28.1% faster)


def test_pair_of_lists_instead_of_images_raises():
    # List of two lists (not images)
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([[1, 2], [3, 4]])  # 6.42μs -> 5.02μs (27.9% faster)


def test_list_of_pairs_with_non_list_pair_raises():
    # One pair is not a list
    pairs = [[MockPILImage(), MockPILImage()], Mock3DArray()]
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(pairs)  # 6.29μs -> 5.07μs (24.0% faster)


def test_list_of_pairs_with_pair_of_wrong_length_raises():
    # Pair has length != 2
    pairs = [[MockPILImage(), MockPILImage()], [Mock3DArray()]]
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(pairs)  # 6.43μs -> 4.83μs (33.3% faster)


def test_input_not_a_list_raises():
    # Input is not a list
    img = MockPILImage()
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(img)  # 1.31μs -> 1.05μs (25.6% faster)


def test_pair_of_3d_array_with_pil_type_fails():
    # 3D array but get_image_type returns PIL (should fail)
    class FakeArray(Mock3DArray):
        pass

    arr = FakeArray()
    # Patch get_image_type to return PIL for this instance
    orig_get_image_type = get_image_type

    def fake_get_image_type(image):
        if isinstance(image, FakeArray):
            return ImageType.PIL
        return orig_get_image_type(image)

    # Patch in the function under test
    try:
        globals()["get_image_type"] = fake_get_image_type
        with pytest.raises(ValueError):
            validate_and_format_image_pairs([arr, arr])
    finally:
        globals()["get_image_type"] = orig_get_image_type


def test_pair_of_3d_array_with_shape_length_not_3_raises():
    # 3D array with shape of length 4
    arr1 = Mock3DArray(shape=(1, 2, 3, 4))
    arr2 = Mock3DArray()
    with pytest.raises(ValueError):
        validate_and_format_image_pairs([arr1, arr2])  # 5.54μs -> 4.42μs (25.4% faster)


# LARGE SCALE TEST CASES


def test_large_list_with_one_invalid_pair_raises():
    # 999 valid pairs, 1 invalid
    pairs = [[MockPILImage(), MockPILImage()] for _ in range(999)]
    pairs.append([MockPILImage(), object()])
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(pairs)  # 8.44μs -> 6.67μs (26.6% faster)


def test_large_flat_pair_list():
    # Flat list of 1000 valid images (should not be accepted, only pairs or list of pairs)
    imgs = [MockPILImage() for _ in range(1000)]
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(imgs)  # 2.83μs -> 1.71μs (65.8% faster)


def test_large_list_of_pairs_with_pair_of_length_3_raises():
    # One pair has length 3, rest are valid
    pairs = [[MockPILImage(), MockPILImage()] for _ in range(499)]
    pairs.append([MockPILImage(), MockPILImage(), MockPILImage()])
    with pytest.raises(ValueError):
        validate_and_format_image_pairs(pairs)  # 7.50μs -> 5.88μs (27.5% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-validate_and_format_image_pairs-mia63r9f and push.

Codeflash Static Badge

The optimized code achieves a **27% speedup** by eliminating costly Python overhead from nested `all()` calls and generator expressions, which were the primary bottleneck consuming 91.4% of execution time.

**Key optimizations:**

1. **Short-circuit validation in `_is_valid_image`**: The function now returns early for PIL images and invalid images, avoiding unnecessary `get_image_type()` calls and shape checks.

2. **Replaced nested `all()` with explicit loops**: The original code used `all(isinstance(image_pair, list) and len(image_pair) == 2 and all(_is_valid_image(image) for image in image_pair) for image_pair in images)` which creates multiple generator objects and has significant function call overhead. The optimized version uses a simple for-loop with early breaking, reducing the overhead by directly checking conditions.

3. **Optimized list flattening**: Replaced the list comprehension `[image for image_pair in images for image in image_pair]` with in-place `list.extend()` calls, which avoids creating intermediate lists and reduces memory allocations.

4. **Direct indexing for pairs**: For the simple case of exactly 2 images, the code now directly accesses `images[0]` and `images[1]` instead of using generators with `all()`.

**Impact on workloads**: Since this function is called in `preprocess()` and `visualize_keypoint_matching()` - both likely to be in hot paths for image processing pipelines - this optimization will provide meaningful performance improvements, especially for:
- Large batches of image pairs (55.5% faster for 500 pairs)  
- Error cases with early validation failures (up to 74% faster)
- Mixed PIL/numpy image types (25% faster)

The optimization is particularly effective for the transformer's LightGlue model preprocessing, where image validation happens before every inference or visualization operation.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 22, 2025 10:50
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 22, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant