diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py index 1d24a159e4..500f0ef413 100644 --- a/egs/librispeech/ASR/zipformer/scaling_converter.py +++ b/egs/librispeech/ASR/zipformer/scaling_converter.py @@ -72,14 +72,18 @@ def forward(self, x: torch.Tensor, chunk_size: int = -1) -> torch.Tensor: left_edge = self.chunkwise_conv_scale[0] right_edge = self.chunkwise_conv_scale[1] - # seq_len >= kernel_size in non-streaming mode, so we pad with zeros - t = seq_len - self.kernel_size - channels = left_edge.shape[0] - pad = torch.zeros( - channels, t, device=left_edge.device, dtype=left_edge.dtype - ) - left_edge = torch.cat((left_edge, pad), dim=-1) - right_edge = torch.cat((pad, right_edge), dim=-1) + + if seq_len < self.kernel_size: + left_edge = left_edge[:, :seq_len] + right_edge = right_edge[:, -seq_len:] + else: + t = seq_len - self.kernel_size + channels = left_edge.shape[0] + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) + left_edge = torch.cat((left_edge, pad), dim=-1) + right_edge = torch.cat((pad, right_edge), dim=-1) chunk_scale = 1.0 + (left_edge + right_edge) x_chunk = x_chunk * chunk_scale