⚡️ Speed up method SpeechT5FeatureExtractor.zero_mean_unit_var_norm by 104%
#378
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 104% (1.04x) speedup for
SpeechT5FeatureExtractor.zero_mean_unit_var_norminsrc/transformers/models/speecht5/feature_extraction_speecht5.py⏱️ Runtime :
37.3 milliseconds→18.3 milliseconds(best of78runs)📝 Explanation and details
The optimized code achieves a 103% speedup by implementing two key optimizations in the
zero_mean_unit_var_normmethod:1. Vectorized Batch Processing for Same-Length Arrays (No Attention Mask)
When no attention mask is provided and all input arrays have the same length, the optimization uses
np.stack()to combine arrays into a single 2D array, then performs vectorized mean and variance calculations usingaxis=1. This eliminates the expensive per-array Python loop and leverages NumPy's highly optimized BLAS operations.2. Optimized Memory Allocation and Computation Order (With Attention Mask)
For attention mask cases, the optimization:
attention_mask.sum(-1)once instead of recalculating for each iterationnp.empty_like()for faster memory allocation without initialization1.0 / np.sqrt(var + 1e-7)) to avoid redundant square root operationsPerformance Impact Analysis:
test_large_batch_no_attention_maskshows 482% speedup (8.95ms → 1.54ms)test_large_batch_sizeachieves 1887% speedup (12.6ms → 633μs)Why This Matters:
Audio feature extraction typically processes batches of similar-length audio segments. The vectorized approach transforms O(n) individual operations into O(1) batch operations for the common case, making it highly effective for speech processing pipelines where consistent audio chunk sizes are standard. The optimizations are particularly beneficial when processing multiple audio files or segments in batch inference scenarios.
The slight overhead for single small vectors is acceptable given the substantial gains for batch processing, which is the primary use case in production ML workflows.
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
To edit these changes
git checkout codeflash/optimize-SpeechT5FeatureExtractor.zero_mean_unit_var_norm-mia34briand push.