fix stride in generate
Browse files
model.py
CHANGED
@@ -1888,7 +1888,7 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
|
|
1888 |
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
1889 |
|
1890 |
# 2. set global generate variables
|
1891 |
-
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
|
1892 |
num_segment_frames = input_stride * self.config.max_source_positions
|
1893 |
batch_size, total_input_frames = self._retrieve_total_input_frames(
|
1894 |
input_features=input_features, input_stride=input_stride, kwargs=kwargs
|
|
|
1888 |
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
1889 |
|
1890 |
# 2. set global generate variables
|
1891 |
+
input_stride = self.model.encoder.get_conv_stride() #self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
|
1892 |
num_segment_frames = input_stride * self.config.max_source_positions
|
1893 |
batch_size, total_input_frames = self._retrieve_total_input_frames(
|
1894 |
input_features=input_features, input_stride=input_stride, kwargs=kwargs
|