[Bugfix] Fix Var Length Batched Padding in Granite Speech #31906
Purpose
Fixes a bug in Granite Speech padding. The audio features are variable length, so we pad the tensors to
[bsz, longest_feature, 160]; however, when the multimodal inputs are batched, they are provided as a list of 2D tensors of shape [feat_len, 160], which breaks the pad call that expects a 3D tensor. This PR unsqueezes the features if they're 2D so the padding always operates on a 3D tensor.
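A minimal sketch of the kind of guard this change adds (the function and argument names here are illustrative, not the exact vLLM code):

```python
import torch
import torch.nn.functional as F

def pad_audio_features(features: list[torch.Tensor], longest_feature: int) -> torch.Tensor:
    """Illustrative only: pad variable-length features to [bsz, longest_feature, 160]."""
    padded = []
    for feat in features:
        # When inputs are batched, each feature may arrive as a 2D [feat_len, 160]
        # tensor; add a leading batch dim so the pad below always sees a 3D tensor.
        if feat.ndim == 2:
            feat = feat.unsqueeze(0)
        # Pad the feature-length (second-to-last) dimension up to the longest feature.
        padded.append(F.pad(feat, (0, 0, 0, longest_feature - feat.shape[1])))
    return torch.cat(padded, dim=0)
```

With the unsqueeze in place, both the single-request path (already 3D) and the batched path (list of 2D tensors) take the same padding code path.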
Test Plan
The bug can be reproduced with a script like the following:
```python
import io
import wave

import librosa
import numpy as np
import soundfile as sf
from numpy import ndarray as NDArray
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM
from vllm.lora.request import LoRARequest

TEST_FILES = [
    "1.wav",
    "2.wav",
    "3.wav",
    "4.wav",
    "5.wav",
    "6.wav",
]


def audio_to_wav(audio_frame: bytes) -> bytes:
    """Convert mono PCM 16000-sample-rate audio bytes to WAV content returned as bytes.

    Args:
        audio_frame (bytes): audio bytes

    Returns:
        bytes: WAV bytes
    """
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(16000)
        wf.writeframes(audio_frame)
    wav_buffer.seek(0)
    return wav_buffer.read()


def float2int(sound: NDArray) -> NDArray:
    """Convert sound from floats in [-1.0, 1.0] to int16 in [-32768, 32767].

    Args:
        sound (NDArray): array of float32 sound values

    Returns:
        NDArray: sound as int16 values in [-32768, 32767]
    """
    return (sound * 32768.0).astype(np.int16)


def read_float32_wav_to_wav(filepath: str) -> bytes | None:
    """Read a 32-bit floating point PCM WAV file and return its raw audio data as bytes.

    Args:
        filepath (str): The path to the WAV file.

    Returns:
        bytes: The raw audio data as a bytes object.
    """
    try:
        # Read the audio data and sample rate
        data, samplerate = sf.read(filepath, dtype="float32")
        data = float2int(data)
        audio_bytes = data.tobytes()
        return audio_to_wav(audio_bytes)
    except Exception as e:
        print(f"Error reading WAV file: {e}")
        return None


MODEL = "ibm-granite/granite-speech-3.3-8b"
tokenizer = AutoTokenizer.from_pretrained(MODEL)


def get_prompt(question: str, has_audio: bool):
    """Build the input prompt to send to vLLM."""
    if has_audio:
        question = f"<|audio|>{question}"
    chat = [{"role": "user", "content": question}]
    return tokenizer.apply_chat_template(chat, tokenize=False)


question = "can you transcribe the speech into a written format?"

model = LLM(
    model=MODEL,
    enable_lora=True,
    max_lora_rank=64,
    limit_mm_per_prompt={"audio": 1},
    enforce_eager=True,
)

prompt_with_audio = "<|start_of_role|>system<|end_of_role|> Knowledge Cutoff Date: April 2024.\n Today's Date: November 21, 2025. You are Granite, developed by IBM. You are a helpful AI assistant.<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><|audio|>\ncan you transcribe the speech into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"

batch = []
for test_file in tqdm(TEST_FILES):
    audio_array, sampling_rate = librosa.load(test_file, sr=None)
    inputs = {
        "prompt": prompt_with_audio,
        "multi_modal_data": {
            "audio": audio_array,
        },
    }
    batch.append(inputs)

outputs = model.generate(
    batch,
    lora_request=LoRARequest("speech", 1, MODEL),
    use_tqdm=False,
)

for output in outputs:
    print(output.outputs[0].text)
```

Test Result
It pads correctly and doesn't crash.
@DarkLight1337 could you please take a look?