Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions src/transformers/models/speecht5/feature_extraction_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,41 @@ def zero_mean_unit_var_norm(
"""
Every array in the list is normalized to have zero mean and unit variance
"""
if attention_mask is not None:
attention_mask = np.array(attention_mask, np.int32)
normed_input_values = []

for vector, length in zip(input_values, attention_mask.sum(-1)):
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
if length < normed_slice.shape[0]:
normed_slice[length:] = padding_value

normed_input_values.append(normed_slice)
else:
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]

# Use numpy vectorization where possible (no memoryviews are involved).
# Fast path: no attention mask
if attention_mask is None:
# Avoid list comprehension when possible using numpy array ops for sequences of same length
# If input_values contains ndarrays of same length, stack and normalize at once. If not, fallback.
lengths = [x.shape[0] for x in input_values]
all_same_length = all(l == lengths[0] for l in lengths)
if all_same_length and len(input_values) > 0:
arr = np.stack(input_values)
mean = arr.mean(axis=1, keepdims=True)
var = arr.var(axis=1, keepdims=True)
normed = (arr - mean) / np.sqrt(var + 1e-7)
return [normed[i] for i in range(normed.shape[0])]
else:
# Fallback to original list comprehension but use numpy's fast calculation
return [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]

# With attention mask
attention_mask_arr = np.array(attention_mask, np.int32)
normed_input_values = []
# Precompute sum(-1) once:
sample_lengths = attention_mask_arr.sum(-1)
# For each vector, normalize only the valid (non-padded) region and write padding_value into the rest of a newly allocated output array; the input vector is not modified.
for vector, length in zip(input_values, sample_lengths):
# Use slice for valid region
valid = vector[:length]
mean = valid.mean()
var = valid.var()
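# Allocate the output array once. NOTE(review): np.empty_like inherits vector's dtype, so a non-float input would be cast on assignment - confirm inputs are always float.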
normed_slice = np.empty_like(vector)
normed_norm = 1.0 / np.sqrt(var + 1e-7)
normed_slice[:length] = (vector[:length] - mean) * normed_norm
if length < normed_slice.shape[0]:
normed_slice[length:] = padding_value
normed_input_values.append(normed_slice)
return normed_input_values

def _extract_mel_features(
Expand Down