Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Nov 22, 2025

📄 28% (0.28x) speedup for shift_tokens_right in src/transformers/models/mvp/modeling_mvp.py

⏱️ Runtime : 2.63 milliseconds 2.06 milliseconds (best of 86 runs)

📝 Explanation and details

The optimized code achieves a 27% speedup through two key performance improvements:

1. Early validation check: Moving the pad_token_id is None check to the beginning eliminates wasted computation when inputs are invalid, providing dramatic speedups (up to 4581% in error cases).

2. More efficient tensor allocation: Replacing input_ids.new_zeros() with torch.empty_like() eliminates unnecessary zero-initialization since the tensor gets completely overwritten anyway. This reduces memory bandwidth overhead.

3. Removing unnecessary .clone(): The original code used .clone() when copying input_ids[:, :-1], which creates an extra tensor copy. The optimized version removes this since we're already writing to a new tensor.

Performance impact by workload:

  • Small tensors (basic cases): 11-17% speedup from reduced allocation overhead
  • Large tensors (1000x100): Up to 48% speedup as memory bandwidth savings compound
  • Error cases: Massive 4581% speedup from early validation

Hot path significance: Based on the function references, shift_tokens_right is called during model forward passes in both base model and conditional generation scenarios. It's used to prepare decoder inputs from either input_ids or labels, making it a critical path during training and inference where these optimizations will have meaningful impact on overall model performance.

The optimizations are particularly effective for transformer workloads that process large batches with long sequences, which is typical in modern NLP applications.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 37 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
import torch  # required for tensor operations

from transformers.models.mvp.modeling_mvp import shift_tokens_right


# unit tests

# ------------------ BASIC TEST CASES ------------------


def test_basic_shift_right():
    # Test shifting a simple 2x5 tensor
    input_ids = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
    pad_token_id = 0
    decoder_start_token_id = 99
    expected = torch.tensor([[99, 1, 2, 3, 4], [99, 6, 7, 8, 9]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 59.0μs -> 51.7μs (14.1% faster)


def test_basic_pad_token_id_replacement():
    # Test that -100 is replaced by pad_token_id after shift
    input_ids = torch.tensor([[1, -100, 3, 4, 5]])
    pad_token_id = 0
    decoder_start_token_id = 99
    # After shift: [99, 1, -100, 3, 4] -> -100 replaced by 0
    expected = torch.tensor([[99, 1, 0, 3, 4]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 53.6μs -> 47.8μs (12.3% faster)


def test_basic_single_row():
    # Test shifting a single row tensor
    input_ids = torch.tensor([[1, 2, 3]])
    pad_token_id = 0
    decoder_start_token_id = 42
    expected = torch.tensor([[42, 1, 2]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 52.6μs -> 47.1μs (11.8% faster)


def test_basic_single_column():
    # Test shifting a single column tensor (should handle shape [N, 1])
    input_ids = torch.tensor([[5], [6]])
    pad_token_id = 0
    decoder_start_token_id = 7
    expected = torch.tensor([[7], [7]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 52.2μs -> 46.4μs (12.5% faster)


def test_basic_pad_token_id_is_nonzero():
    # Test with a nonzero pad_token_id
    input_ids = torch.tensor([[1, -100, 3]])
    pad_token_id = 123
    decoder_start_token_id = 99
    expected = torch.tensor([[99, 1, 123]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 52.3μs -> 47.0μs (11.3% faster)


# ------------------ EDGE TEST CASES ------------------


def test_edge_one_dim_zero_rows():
    # Test shifting a tensor with zero rows (shape [0, 5])
    input_ids = torch.empty((0, 5), dtype=torch.long)
    pad_token_id = 0
    decoder_start_token_id = 1
    expected = torch.empty((0, 5), dtype=torch.long)
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 63.8μs -> 57.3μs (11.4% faster)


def test_edge_pad_token_id_none():
    # Test that ValueError is raised when pad_token_id is None
    input_ids = torch.tensor([[1, 2, 3]])
    decoder_start_token_id = 42
    with pytest.raises(ValueError):
        shift_tokens_right(input_ids, None, decoder_start_token_id)  # 43.5μs -> 929ns (4581% faster)


def test_edge_negative_values():
    # Test tensor containing negative values other than -100
    input_ids = torch.tensor([[1, -5, -100, 4]])
    pad_token_id = 0
    decoder_start_token_id = 99
    # After shift: [99, 1, -5, -100] -> -100 replaced by 0
    expected = torch.tensor([[99, 1, -5, 0]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 53.6μs -> 52.1μs (2.96% faster)


def test_edge_all_minus_100():
    # Test tensor where all values are -100
    input_ids = torch.full((2, 4), -100, dtype=torch.long)
    pad_token_id = 77
    decoder_start_token_id = 88
    # After shift: first column is decoder_start_token_id, rest are -100, which become pad_token_id
    expected = torch.tensor([[88, 77, 77, 77], [88, 77, 77, 77]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 51.7μs -> 46.0μs (12.3% faster)


def test_edge_decoder_start_token_id_is_minus_100():
    # Test if decoder_start_token_id is -100, should be replaced by pad_token_id
    input_ids = torch.tensor([[1, 2, 3]])
    pad_token_id = 0
    decoder_start_token_id = -100
    # After shift: [-100, 1, 2] -> -100 replaced by 0
    expected = torch.tensor([[0, 1, 2]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 53.0μs -> 46.9μs (13.2% faster)


def test_edge_decoder_start_token_id_is_pad_token_id():
    # Test if decoder_start_token_id == pad_token_id
    input_ids = torch.tensor([[1, 2, 3]])
    pad_token_id = 5
    decoder_start_token_id = 5
    expected = torch.tensor([[5, 1, 2]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 52.6μs -> 46.9μs (12.2% faster)


def test_edge_input_ids_dtype():
    # Test that output dtype matches input dtype
    input_ids = torch.tensor([[1, 2, 3]], dtype=torch.int64)
    pad_token_id = 0
    decoder_start_token_id = 99
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 52.6μs -> 47.2μs (11.2% faster)


def test_edge_input_ids_not_2d():
    # Test that non-2D input raises an error (should fail if input is not [N, L])
    input_ids = torch.tensor([1, 2, 3])
    pad_token_id = 0
    decoder_start_token_id = 99
    # Function expects 2D, so this should raise IndexError
    with pytest.raises(IndexError):
        shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)  # 46.7μs -> 44.2μs (5.66% faster)


# ------------------ LARGE SCALE TEST CASES ------------------


def test_large_scale_shift_right():
    # Test shifting a large tensor (1000x50)
    N, L = 1000, 50
    input_ids = torch.arange(1, N * L + 1, dtype=torch.long).reshape(N, L)
    pad_token_id = 0
    decoder_start_token_id = 101
    # Build expected output
    expected = torch.zeros((N, L), dtype=torch.long)
    expected[:, 0] = decoder_start_token_id
    expected[:, 1:] = input_ids[:, :-1]
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 153μs -> 108μs (41.8% faster)


def test_large_scale_with_minus_100():
    # Test large tensor with some -100 values
    N, L = 500, 30
    input_ids = torch.randint(-100, 100, (N, L), dtype=torch.long)
    pad_token_id = 999
    decoder_start_token_id = 888
    # After shift, -100 should be replaced by pad_token_id
    shifted = torch.zeros((N, L), dtype=torch.long)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[shifted == -100] = pad_token_id
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 58.5μs -> 44.1μs (32.7% faster)


def test_large_scale_all_pad():
    # Test tensor where all values are pad_token_id
    N, L = 1000, 10
    pad_token_id = 123
    decoder_start_token_id = 456
    input_ids = torch.full((N, L), pad_token_id, dtype=torch.long)
    expected = torch.full((N, L), pad_token_id, dtype=torch.long)
    expected[:, 0] = decoder_start_token_id
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 60.4μs -> 46.7μs (29.4% faster)


def test_large_scale_randomized():
    # Test randomized input for robustness
    N, L = 999, 20
    input_ids = torch.randint(-100, 1000, (N, L), dtype=torch.long)
    pad_token_id = 0
    decoder_start_token_id = 42
    shifted = torch.zeros((N, L), dtype=torch.long)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[shifted == -100] = pad_token_id
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 70.5μs -> 52.3μs (34.9% faster)


def test_large_scale_memory_limit():
    # Test that tensor size does not exceed 100MB
    N, L = 1000, 100  # 1000*100*8 bytes = 800,000 bytes = ~0.8MB
    input_ids = torch.randint(0, 10000, (N, L), dtype=torch.long)
    pad_token_id = 1
    decoder_start_token_id = 2
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 323μs -> 217μs (48.6% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest
import torch

from transformers.models.mvp.modeling_mvp import shift_tokens_right


# unit tests

# -------------------- Basic Test Cases --------------------


def test_basic_shift():
    # Test with a simple 2x4 tensor
    input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    pad_token_id = 0
    decoder_start_token_id = 101
    # Expected: first column is decoder_start_token_id, rest are shifted right
    expected = torch.tensor([[101, 1, 2, 3], [101, 5, 6, 7]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 59.6μs -> 51.9μs (15.0% faster)


def test_basic_shift_with_pad():
    # Test with pad_token_id present in input_ids
    input_ids = torch.tensor([[0, 2, 0, 4], [5, 0, 7, 0]])
    pad_token_id = 0
    decoder_start_token_id = 99
    expected = torch.tensor([[99, 0, 2, 0], [99, 5, 0, 7]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 55.5μs -> 47.7μs (16.3% faster)


def test_basic_shift_with_negative100():
    # Test with -100 present in input_ids, should be replaced by pad_token_id
    input_ids = torch.tensor([[1, -100, 3, 4], [5, 6, -100, 8]])
    pad_token_id = 0
    decoder_start_token_id = 101
    expected = torch.tensor([[101, 1, 0, 3], [101, 5, 6, 0]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 55.3μs -> 47.4μs (16.6% faster)


def test_basic_shift_single_row():
    # Test with a single row tensor
    input_ids = torch.tensor([[10, 11, 12, 13]])
    pad_token_id = 0
    decoder_start_token_id = 99
    expected = torch.tensor([[99, 10, 11, 12]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 52.4μs -> 44.7μs (17.2% faster)


def test_basic_shift_single_column():
    # Test with a single column tensor
    input_ids = torch.tensor([[10], [20], [30]])
    pad_token_id = 0
    decoder_start_token_id = 99
    expected = torch.tensor([[99], [99], [99]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 52.2μs -> 46.3μs (12.8% faster)


# -------------------- Edge Test Cases --------------------


def test_empty_tensor():
    # Test with an empty tensor (0 rows, N columns)
    input_ids = torch.empty((0, 4), dtype=torch.long)
    pad_token_id = 0
    decoder_start_token_id = 101
    expected = torch.empty((0, 4), dtype=torch.long)
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 54.0μs -> 46.7μs (15.7% faster)


def test_pad_token_id_none():
    # Test with pad_token_id=None, should raise ValueError
    input_ids = torch.tensor([[1, 2, 3, 4]])
    decoder_start_token_id = 101
    with pytest.raises(ValueError):
        shift_tokens_right(input_ids, None, decoder_start_token_id)  # 50.5μs -> 950ns (5219% faster)


def test_negative_pad_token_id():
    # Test with negative pad_token_id, should work as expected
    input_ids = torch.tensor([[1, -100, 3, 4]])
    pad_token_id = -1
    decoder_start_token_id = 101
    expected = torch.tensor([[101, 1, -1, 3]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 56.2μs -> 56.0μs (0.318% faster)


def test_all_negative100():
    # Test with all values -100, should be replaced by pad_token_id
    input_ids = torch.full((2, 3), -100, dtype=torch.long)
    pad_token_id = 0
    decoder_start_token_id = 101
    expected = torch.full((2, 3), pad_token_id, dtype=torch.long)
    expected[:, 0] = decoder_start_token_id
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 37.3μs -> 30.6μs (22.0% faster)


def test_large_decoder_start_token_id():
    # Test with a large decoder_start_token_id
    input_ids = torch.tensor([[1, 2, 3]])
    pad_token_id = 0
    decoder_start_token_id = 999999
    expected = torch.tensor([[999999, 1, 2]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 52.8μs -> 47.2μs (11.9% faster)


def test_input_ids_dtype():
    # Test with input_ids of type int32 (should work)
    input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
    pad_token_id = 0
    decoder_start_token_id = 101
    expected = torch.tensor([[101, 1, 2, 3]], dtype=torch.int32)
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 58.9μs -> 53.1μs (10.9% faster)


def test_input_ids_device_cpu():
    # Test with input_ids on CPU
    input_ids = torch.tensor([[1, 2, 3, 4]], device="cpu")
    pad_token_id = 0
    decoder_start_token_id = 101
    expected = torch.tensor([[101, 1, 2, 3]], device="cpu")
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 52.7μs -> 47.7μs (10.4% faster)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_single_element():
    # Test with a single element tensor
    input_ids = torch.tensor([[42]])
    pad_token_id = 0
    decoder_start_token_id = 101
    expected = torch.tensor([[101]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 57.5μs -> 51.4μs (11.8% faster)


def test_input_ids_with_large_values():
    # Test with large values in input_ids
    input_ids = torch.tensor([[2**31 - 1, -100, 0]])
    pad_token_id = 0
    decoder_start_token_id = 101
    expected = torch.tensor([[101, 2**31 - 1, 0]])
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 53.2μs -> 47.5μs (12.2% faster)


# -------------------- Large Scale Test Cases --------------------


def test_large_batch_and_sequence():
    # Test with a large batch and sequence length, within 100MB limit
    batch_size = 512
    seq_len = 128
    input_ids = torch.randint(low=0, high=1000, size=(batch_size, seq_len), dtype=torch.long)
    pad_token_id = 0
    decoder_start_token_id = 101
    # Manually compute expected for first two rows for sanity check
    expected_first_two = torch.zeros((2, seq_len), dtype=torch.long)
    expected_first_two[:, 1:] = input_ids[:2, :-1]
    expected_first_two[:, 0] = decoder_start_token_id
    expected_first_two.masked_fill_(expected_first_two == -100, pad_token_id)
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 180μs -> 118μs (52.7% faster)


def test_large_sequence_with_negative100():
    # Test with a large sequence and scattered -100 values
    batch_size = 10
    seq_len = 900
    input_ids = torch.randint(low=0, high=1000, size=(batch_size, seq_len), dtype=torch.long)
    # Insert -100 at random positions for each row
    for i in range(batch_size):
        input_ids[i, i * 10 % seq_len] = -100
    pad_token_id = 0
    decoder_start_token_id = 101
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 61.4μs -> 48.3μs (27.2% faster)


def test_large_all_negative100():
    # Test with a large tensor filled with -100
    batch_size = 20
    seq_len = 500
    input_ids = torch.full((batch_size, seq_len), -100, dtype=torch.long)
    pad_token_id = 7
    decoder_start_token_id = 101
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 79.6μs -> 64.9μs (22.7% faster)


def test_large_randomized():
    # Test with random input and random pad_token_id, decoder_start_token_id
    batch_size = 50
    seq_len = 800
    input_ids = torch.randint(-100, 1000, (batch_size, seq_len), dtype=torch.long)
    pad_token_id = 123
    decoder_start_token_id = 456
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 153μs -> 115μs (33.3% faster)


def test_large_batch_single_sequence():
    # Test with large batch, single sequence element
    batch_size = 1000
    seq_len = 1
    input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)
    pad_token_id = 0
    decoder_start_token_id = 101
    codeflash_output = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    result = codeflash_output  # 55.7μs -> 48.3μs (15.3% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-shift_tokens_right-miabqwqj and push.

Codeflash Static Badge

The optimized code achieves a **27% speedup** through two key performance improvements:

**1. Early validation check**: Moving the `pad_token_id is None` check to the beginning eliminates wasted computation when inputs are invalid, providing dramatic speedups (up to 4581% in error cases).

**2. More efficient tensor allocation**: Replacing `input_ids.new_zeros()` with `torch.empty_like()` eliminates unnecessary zero-initialization since the tensor gets completely overwritten anyway. This reduces memory bandwidth overhead.

**3. Removing unnecessary `.clone()`**: The original code used `.clone()` when copying `input_ids[:, :-1]`, which creates an extra tensor copy. The optimized version removes this since we're already writing to a new tensor.

**Performance impact by workload**:
- **Small tensors (basic cases)**: 11-17% speedup from reduced allocation overhead
- **Large tensors (1000x100)**: Up to 48% speedup as memory bandwidth savings compound
- **Error cases**: Massive 4581% speedup from early validation

**Hot path significance**: Based on the function references, `shift_tokens_right` is called during model forward passes in both base model and conditional generation scenarios. It's used to prepare decoder inputs from either `input_ids` or `labels`, making it a critical path during training and inference where these optimizations will have meaningful impact on overall model performance.

The optimizations are particularly effective for transformer workloads that process large batches with long sequences, which is typical in modern NLP applications.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 22, 2025 13:28
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 22, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant