File size: 50,924 Bytes
51bc847 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 |
""" Specifies the inference interfaces for Automatic speech Recognition (ASR) modules.
Authors:
* Aku Rouhe 2021
* Peter Plantinga 2021
* Loren Lugosch 2020
* Mirco Ravanelli 2020
* Titouan Parcollet 2021
* Abdel Heba 2021
* Andreas Nautsch 2022, 2023
* Pooneh Mousavi 2023
* Sylvain de Langen 2023, 2024
* Adel Moumen 2023, 2024
* Pradnya Kandarkar 2023
"""
import functools
import itertools
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple
import sentencepiece
import torch
import torchaudio
from tqdm import tqdm
import speechbrain
from speechbrain.inference.interfaces import Pretrained
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
from speechbrain.utils.fetching import fetch
from speechbrain.utils.streaming import split_fixed_chunks
class EncoderDecoderASR(Pretrained):
    """Pretrained encoder-decoder speech recognizer.

    Use :meth:`encode_batch` to obtain the encoder states only, or
    :meth:`transcribe_batch` / :meth:`transcribe_file` to run the encoder
    followed by the decoder and obtain text. The hyperparameter YAML must
    provide every entry listed in ``HPARAMS_NEEDED`` and ``MODULES_NEEDED``.

    Arguments
    ---------
    *args : tuple
    **kwargs : dict
        Arguments are forwarded to ``Pretrained`` parent class.

    Example
    -------
    >>> from speechbrain.inference.ASR import EncoderDecoderASR
    >>> tmpdir = getfixture("tmpdir")
    >>> asr_model = EncoderDecoderASR.from_hparams(
    ...     source="speechbrain/asr-crdnn-rnnlm-librispeech",
    ...     savedir=tmpdir,
    ... )  # doctest: +SKIP
    >>> asr_model.transcribe_file("tests/samples/single-mic/example2.flac")  # doctest: +SKIP
    "MY FATHER HAS REVEALED THE CULPRIT'S NAME"
    """

    HPARAMS_NEEDED = ["tokenizer"]
    MODULES_NEEDED = ["encoder", "decoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        hparams = self.hparams
        self.tokenizer = hparams.tokenizer
        # Both search flags are optional in the YAML; absent means False.
        self.transducer_beam_search = getattr(
            hparams, "transducer_beam_search", False
        )
        self.transformer_beam_search = getattr(
            hparams, "transformer_beam_search", False
        )

    def transcribe_file(self, path, **kwargs):
        """Transcribe a single audio file.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe.
        **kwargs : dict
            Arguments forwarded to ``load_audio``.

        Returns
        -------
        str
            The transcription of the file, i.e. the first (and only) entry
            of the words returned by ``transcribe_batch``.
        """
        signal = self.load_audio(path, **kwargs)
        # Wrap the lone waveform into a batch of one with full relative length.
        words, _ = self.transcribe_batch(
            signal.unsqueeze(0), torch.tensor([1.0])
        )
        return words[0]

    def encode_batch(self, wavs, wav_lens):
        """Run the encoder over a batch of waveforms.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        wavs = wavs.float().to(self.device)
        wav_lens = wav_lens.to(self.device)
        states = self.mods.encoder(wavs, wav_lens)
        if not self.transformer_beam_search:
            return states
        # Transformer models carry a second encoding stage in ``transformer``.
        return self.mods.transformer.encode(states, wav_lens)

    def transcribe_batch(self, wavs, wav_lens):
        """Transcribe a batch of waveforms (no gradients are tracked).

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            Each waveform in the batch transcribed.
        tensor
            Each predicted token id.
        """
        with torch.no_grad():
            wav_lens = wav_lens.to(self.device)
            states = self.encode_batch(wavs, wav_lens)
            # The transducer searcher is called with the encoder output only.
            decoder_args = (
                (states,)
                if self.transducer_beam_search
                else (states, wav_lens)
            )
            predicted_tokens, _, _, _ = self.mods.decoder(*decoder_args)
            predicted_words = [
                self.tokenizer.decode_ids(ids) for ids in predicted_tokens
            ]
        return predicted_words, predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs full transcription - note: no gradients through decoding"""
        return self.transcribe_batch(wavs, wav_lens)
class EncoderASR(Pretrained):
    """A ready-to-use Encoder ASR model

    The class can be used either to run only the encoder
    (:meth:`encode_batch`) or to run the encoder followed by the decoding
    function (:meth:`transcribe_batch` / :meth:`transcribe_file`) to
    transcribe speech. The given YAML must contain the fields specified in
    the *_NEEDED[] lists.

    Arguments
    ---------
    *args : tuple
    **kwargs : dict
        Arguments are forwarded to ``Pretrained`` parent class.

    Example
    -------
    >>> from speechbrain.inference.ASR import EncoderASR
    >>> tmpdir = getfixture("tmpdir")
    >>> asr_model = EncoderASR.from_hparams(
    ...     source="speechbrain/asr-wav2vec2-commonvoice-fr",
    ...     savedir=tmpdir,
    ... )  # doctest: +SKIP
    >>> asr_model.transcribe_file("samples/audio_samples/example_fr.wav")  # doctest: +SKIP
    """

    HPARAMS_NEEDED = ["tokenizer", "decoding_function"]
    MODULES_NEEDED = ["encoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = self.hparams.tokenizer
        self.set_decoding_function()

    def set_decoding_function(self):
        """Set ``self.decoding_function`` from the hyperparameter file.

        ``hparams.decoding_function`` can be either a ``functools.partial``
        (greedy decoding; used as-is) or a subclass of
        ``speechbrain.decoders.ctc.CTCBaseSearcher`` (beam search). In the
        latter case the class is instantiated with the vocabulary extracted
        from the tokenizer plus the optional ``hparams.test_beam_search``
        keyword arguments. If those contain ``kenlm_model_path``, the file is
        fetched into ``hparams.savedir`` and the entry is replaced (in place)
        by the fetched path.

        Raises
        ------
        ValueError
            If the decoding function is neither a ``functools.partial`` nor a
            subclass of ``speechbrain.decoders.ctc.CTCBaseSearcher``, or if
            the tokenizer is neither a ``CTCTextEncoder`` nor a
            ``SentencePieceProcessor``.
        """
        decoding_function = self.hparams.decoding_function

        # Greedy decoding case: the partial is directly callable.
        if isinstance(decoding_function, functools.partial):
            self.decoding_function = decoding_function
            return

        # CTC beam search case. ``issubclass`` raises TypeError on anything
        # that is not a class, so check that first to honour the documented
        # ValueError.
        if not (
            isinstance(decoding_function, type)
            and issubclass(
                decoding_function, speechbrain.decoders.ctc.CTCBaseSearcher
            )
        ):
            raise ValueError(
                "The decoding function must be an instance of speechbrain.decoders.CTCBaseSearcher"
            )

        # Retrieve the vocabulary list; how depends on the tokenizer type.
        if isinstance(
            self.tokenizer, speechbrain.dataio.encoder.CTCTextEncoder
        ):
            ind2lab = self.tokenizer.ind2lab
            vocab_list = [ind2lab[x] for x in range(len(ind2lab))]
        elif isinstance(self.tokenizer, sentencepiece.SentencePieceProcessor):
            vocab_list = [
                self.tokenizer.id_to_piece(i)
                for i in range(self.tokenizer.vocab_size())
            ]
        else:
            raise ValueError(
                "The tokenizer must be sentencepiece or CTCTextEncoder"
            )

        # Optional searcher parameters from the YAML.
        if hasattr(self.hparams, "test_beam_search"):
            opt_beam_search_params = self.hparams.test_beam_search
            # Fetch the KenLM model if one is configured, and point the
            # parameters at the fetched file.
            if "kenlm_model_path" in opt_beam_search_params:
                source, fl = split_path(
                    opt_beam_search_params["kenlm_model_path"]
                )
                kenlm_model_path = str(
                    fetch(fl, source=source, savedir=self.hparams.savedir)
                )
                opt_beam_search_params["kenlm_model_path"] = kenlm_model_path
        else:
            opt_beam_search_params = {}

        self.decoding_function = decoding_function(
            **opt_beam_search_params, vocab_list=vocab_list
        )

    def transcribe_file(self, path, **kwargs):
        """Transcribes the given audiofile into a sequence of words.

        Arguments
        ---------
        path : str
            Path to audio file which to transcribe.
        **kwargs : dict
            Arguments forwarded to ``load_audio``.

        Returns
        -------
        str
            The audiofile transcription produced by this ASR system.
        """
        waveform = self.load_audio(path, **kwargs)
        # Fake a batch:
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words, _ = self.transcribe_batch(batch, rel_length)
        return str(predicted_words[0])

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into CTC posteriors.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            Output of ``hparams.softmax`` applied to the ``output_lin``
            logits.
        """
        wavs = wavs.float()
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        # NOTE(review): this pipeline uses the ``wav2vec``, ``dec`` and
        # ``output_lin`` modules and ``hparams.softmax``, whereas
        # MODULES_NEEDED only lists ``encoder`` - confirm against the YAML.
        encoder_out = self.mods.wav2vec(wavs, wav_lens)
        x = self.mods.dec(encoder_out)
        logits = self.mods.output_lin(x)
        p_ctc = self.hparams.softmax(logits)
        return p_ctc

    def transcribe_batch(self, wavs, wav_lens):
        """Transcribes the input audio into a sequence of words

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            Each waveform in the batch transcribed.
        list
            The raw output of the decoding function (token sequences for
            greedy decoding, hypotheses for beam search).
        """
        with torch.no_grad():
            wav_lens = wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs, wav_lens)
            predictions = self.decoding_function(encoder_out, wav_lens)

            # The tokenizer is expected to be fully loaded at construction
            # time (e.g. through the pretrainer); it must not be re-read
            # from a hard-coded path on every batch.
            if isinstance(self.hparams.decoding_function, functools.partial):
                # Greedy decoding yields token id sequences.
                if isinstance(
                    self.tokenizer, speechbrain.dataio.encoder.CTCTextEncoder
                ):
                    predicted_words = [
                        "".join(self.tokenizer.decode_ndim(token_seq))
                        for token_seq in predictions
                    ]
                else:
                    predicted_words = [
                        self.tokenizer.decode_ids(token_seq)
                        for token_seq in predictions
                    ]
            else:
                # Beam search: take the text of the first hypothesis.
                predicted_words = [hyp[0].text for hyp in predictions]
        return predicted_words, predictions

    def forward(self, wavs, wav_lens):
        """Runs the encoder"""
        return self.encode_batch(wavs, wav_lens)
@dataclass
class ASRWhisperSegment:
    """A single chunk of audio for Whisper ASR streaming.

    This object is intended to be mutated as streaming progresses and passed
    across calls to the lower-level APIs such as `encode_chunk`,
    `decode_chunk`, etc.

    Attributes
    ----------
    start : float
        The start time of the audio chunk.
    end : float
        The end time of the audio chunk.
    chunk : torch.Tensor
        The audio chunk.
    lang_id : str
        The language identifier associated with the audio chunk.
    words : str
        The predicted words for the audio chunk.
    tokens : List[int]
        The predicted token ids for the audio chunk.
    prompt : List[int]
        The prompt token ids the decoder was conditioned on for this chunk.
    avg_log_probs : float
        The average log probability associated with the prediction.
    no_speech_prob : float
        The probability of no speech in the audio chunk.
    """

    start: float
    end: float
    chunk: torch.Tensor
    lang_id: Optional[str] = None
    words: Optional[str] = None
    # Token ids, not strings: filled from the decoder's predicted tokens.
    tokens: Optional[List[int]] = None
    # Token ids as well: a slice of the previously accumulated tokens.
    prompt: Optional[List[int]] = None
    avg_log_probs: Optional[float] = None
    no_speech_prob: Optional[float] = None
class WhisperASR(Pretrained):
"""A ready-to-use Whisper ASR model.
The class can be used to run the entire encoder-decoder whisper model.
The set of tasks supported are: ``transcribe``, ``translate``, and ``lang_id``.
The given YAML must contains the fields specified in the *_NEEDED[] lists.
Arguments
---------
*args : tuple
**kwargs : dict
Arguments are forwarded to ``Pretrained`` parent class.
Example
-------
>>> from speechbrain.inference.ASR import WhisperASR
>>> tmpdir = getfixture("tmpdir")
>>> asr_model = WhisperASR.from_hparams(source="speechbrain/asr-whisper-medium-commonvoice-it", savedir=tmpdir,) # doctest: +SKIP
>>> hyp = asr_model.transcribe_file("speechbrain/asr-whisper-medium-commonvoice-it/example-it.wav") # doctest: +SKIP
>>> hyp # doctest: +SKIP
buongiorno a tutti e benvenuti a bordo
>>> _, probs = asr_model.detect_language_file("speechbrain/asr-whisper-medium-commonvoice-it/example-it.wav") # doctest: +SKIP
>>> print(f"Detected language: {max(probs[0], key=probs[0].get)}") # doctest: +SKIP
Detected language: it
"""
HPARAMS_NEEDED = ["language", "sample_rate"]
MODULES_NEEDED = ["whisper", "decoder"]
TASKS = ["transcribe", "translate", "lang_id"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = self.hparams.whisper.tokenizer
@torch.no_grad()
def detect_language_file(self, path: str):
"""Detects the language of the given audiofile.
This method only works on input_file of 30 seconds or less.
Arguments
---------
path : str
Path to audio file which to transcribe.
Returns
-------
language_tokens : torch.Tensor
The detected language tokens.
language_probs : dict
The probabilities of the detected language tokens.
Raises
------
ValueError
If the model doesn't have language tokens.
"""
wavs = self.load_audio(path).float().to(self.device).unsqueeze(0)
mel = self.mods.whisper._get_mel(wavs)
language_tokens, language_probs = self.mods.whisper.detect_language(mel)
return language_tokens, language_probs
@torch.no_grad()
def detect_language_batch(self, wav: torch.Tensor):
"""Detects the language of the given wav Tensor.
This method only works on wav files of 30 seconds or less.
Arguments
---------
wav : torch.tensor
Batch of waveforms [batch, time, channels].
Returns
-------
language_tokens : torch.Tensor of shape (batch_size,)
ids of the most probable language tokens, which appears after the startoftranscript token.
language_probs : List[Dict[str, float]]
list of dictionaries containing the probability distribution over all languages.
Raises
------
ValueError
If the model doesn't have language tokens.
Example
-------
>>> from speechbrain.inference.ASR import WhisperASR
>>> import torchaudio
>>> tmpdir = getfixture("tmpdir")
>>> asr_model = WhisperASR.from_hparams(
... source="speechbrain/asr-whisper-medium-commonvoice-it",
... savedir=tmpdir,
... ) # doctest: +SKIP
>>> wav, _ = torchaudio.load("your_audio") # doctest: +SKIP
>>> language_tokens, language_probs = asr_model.detect_language(wav) # doctest: +SKIP
"""
mel = self.mods.whisper._get_mel(wav)
language_tokens, language_probs = self.mods.whisper.detect_language(mel)
return language_tokens, language_probs
@torch.no_grad()
def _detect_language(self, mel: torch.Tensor, task: str):
"""Detects the language of the given mel spectrogram.
Arguments
---------
mel : torch.tensor
Batch of mel spectrograms [batch, time, channels].
task : str
The task to perform.
Returns
-------
language_tokens : Tensor, shape = (n_audio,)
ids of the most probable language tokens, which appears after the startoftranscript token.
language_probs : List[Dict[str, float]], length = n_audio
list of dictionaries containing the probability distribution over all languages.
"""
languages = [self.mods.whisper.language] * mel.shape[0]
lang_probs = None
if self.mods.whisper.language is None or task == "lang_id":
lang_tokens, lang_probs = self.mods.whisper.detect_language(mel)
languages = [max(probs, key=probs.get) for probs in lang_probs]
self.mods.decoder.set_lang_tokens(lang_tokens)
return languages, lang_probs
def _get_audio_stream(
self, streamer: "torchaudio.io.StreamReader", frames_per_chunk: int
):
"""From a :class:`torchaudio.io.StreamReader`, identifies the audio
stream and returns an iterable stream of chunks (after resampling and
downmixing to mono).
Arguments
---------
streamer : torchaudio.io.StreamReader
The stream object. Must hold exactly one source stream of an
audio type.
frames_per_chunk : int
The number of frames per chunk. For a streaming model, this should
be determined from the DynChunkTrain configuration.
Yields
------
chunks from streamer
"""
stream_infos = [
streamer.get_src_stream_info(i)
for i in range(streamer.num_src_streams)
]
audio_stream_infos = [
(i, stream_info)
for i, stream_info in enumerate(stream_infos)
if stream_info.media_type == "audio"
]
if len(audio_stream_infos) != 1:
raise ValueError(
f"Expected stream to have only 1 stream (with any number of channels), got {len(audio_stream_infos)} (with streams: {stream_infos})"
)
# find the index of the first (and only) audio stream
audio_stream_index = audio_stream_infos[0][0]
# output stream #0
streamer.add_basic_audio_stream(
frames_per_chunk=frames_per_chunk,
stream_index=audio_stream_index,
sample_rate=self.audio_normalizer.sample_rate,
format="fltp", # torch.float32
num_channels=1,
)
for (chunk,) in streamer.stream():
chunk = chunk.squeeze(-1) # we deal with mono, remove that dim
chunk = chunk.unsqueeze(0) # create a fake batch dim
yield chunk
@torch.no_grad()
def transcribe_file_streaming(
self,
path: str,
task: Optional[str] = None,
initial_prompt: Optional[str] = None,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold=0.6,
condition_on_previous_text: bool = False,
verbose: bool = False,
use_torchaudio_streaming: bool = False,
chunk_size: Optional[int] = 30,
**kwargs,
):
"""Transcribes the given audiofile into a sequence of words.
This method supports the following tasks: ``transcribe``, ``translate``, and ``lang_id``.
It can process an input audio file longer than 30 seconds by splitting it into chunk_size-second segments.
Arguments
---------
path : str
URI/path to the audio to transcribe. When
``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
fetching from HF or a local file. When ``True``, resolves the URI
through ffmpeg, as documented in
:class:`torchaudio.io.StreamReader`.
task : Optional[str]
The task to perform. If None, the default task is the one passed in the Whisper model.
initial_prompt : Optional[str]
The initial prompt to condition the model on.
logprob_threshold : Optional[float]
The log probability threshold to continue decoding the current segment.
no_speech_threshold : float
The threshold to skip decoding segment if the no_speech_prob is higher than this value.
condition_on_previous_text : bool
If True, the model will be condition on the last 224 tokens.
verbose : bool
If True, print the transcription of each segment.
use_torchaudio_streaming : bool
Whether the audio file can be loaded in a streaming fashion. If not,
transcription is still performed through chunks of audio, but the
entire audio file is fetched and loaded at once.
This skips the usual fetching method and instead resolves the URI
using torchaudio (via ffmpeg).
chunk_size : Optional[int]
The size of the chunks to split the audio into. The default
chunk size is 30 seconds which corresponds to the maximal length
that the model can process in one go.
**kwargs : dict
Arguments forwarded to ``load_audio``
Yields
------
ASRWhisperSegment
A new ASRWhisperSegment instance initialized with the provided parameters.
"""
if task is not None:
if task in self.TASKS:
if task != "lang_id":
self.mods.decoder.set_task(task)
else:
raise ValueError(
f"Task {task} not supported. Supported tasks are {self.TASKS}"
)
# create chunks of chunk_size seconds
num_frames_per_chunk = chunk_size * self.hparams.sample_rate
if use_torchaudio_streaming:
streamer = torchaudio.io.StreamReader(path)
segments = self._get_audio_stream(streamer, num_frames_per_chunk)
else:
waveform = self.load_audio(path, **kwargs)
batch = waveform.unsqueeze(0)
segments = split_fixed_chunks(batch, num_frames_per_chunk)
rel_length = torch.tensor([1.0])
all_tokens = []
prompt_reset_since = 0
if initial_prompt is not None:
initial_prompt_tokens = self.whisper.tokenizer.encode(
" " + initial_prompt.strip()
)
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []
for i, segment in enumerate(tqdm(segments, disable=verbose)):
# move the segment on the device
segment = segment.to(self.device)
# extract mel spectrogram
mel_segment = self.mods.whisper._get_mel(segment)
start = i * chunk_size
end = (i + 1) * chunk_size
encoder_out = self.mods.whisper.forward_encoder(mel_segment)
languages, _ = self._detect_language(mel_segment, task)
if task == "lang_id":
yield ASRWhisperSegment(
start=start,
end=end,
chunk=segment,
lang_id=languages[0],
)
continue
prompt = all_tokens[prompt_reset_since:]
self.mods.decoder.set_prompt(prompt)
predicted_tokens, _, scores, _ = self.mods.decoder(
encoder_out, rel_length
)
avg_log_probs = scores.sum() / (len(predicted_tokens[0]) + 1)
if no_speech_threshold is not None:
should_skip = (
self.mods.decoder.no_speech_probs[0] > no_speech_threshold
)
if (
logprob_threshold is not None
and avg_log_probs > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
yield ASRWhisperSegment(
start=start,
end=end,
chunk=segment,
lang_id=languages[0],
words="",
tokens=[],
prompt=prompt,
avg_log_probs=avg_log_probs.item(),
no_speech_prob=self.mods.decoder.no_speech_probs[0],
)
continue
predicted_words = [
self.tokenizer.decode(t, skip_special_tokens=True).strip()
for t in predicted_tokens
]
yield ASRWhisperSegment(
start=start,
end=end,
chunk=segment,
lang_id=languages[0],
words=predicted_words[0],
tokens=predicted_tokens[0],
prompt=prompt,
avg_log_probs=avg_log_probs.item(),
no_speech_prob=self.mods.decoder.no_speech_probs[0],
)
all_tokens.extend(predicted_tokens[0])
if (
not condition_on_previous_text
or self.mods.decoder.temperature > 0.5
):
prompt_reset_since = len(all_tokens)
def transcribe_file(
self,
path: str,
task: Optional[str] = None,
initial_prompt: Optional[str] = None,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold=0.6,
condition_on_previous_text: bool = False,
verbose: bool = False,
use_torchaudio_streaming: bool = False,
chunk_size: Optional[int] = 30,
**kwargs,
) -> List[ASRWhisperSegment]:
"""Run the Whisper model using the specified task on the given audio file and return the ``ASRWhisperSegment`` objects
for each segment.
This method supports the following tasks: ``transcribe``, ``translate``, and ``lang_id``.
It can process an input audio file longer than 30 seconds by splitting it into chunk_size-second segments.
Arguments
---------
path : str
URI/path to the audio to transcribe. When
``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
fetching from HF or a local file. When ``True``, resolves the URI
through ffmpeg, as documented in
:class:`torchaudio.io.StreamReader`.
task : Optional[str]
The task to perform. If None, the default task is the one passed in the Whisper model.
It can be one of the following: ``transcribe``, ``translate``, ``lang_id``.
initial_prompt : Optional[str]
The initial prompt to condition the model on.
logprob_threshold : Optional[float]
The log probability threshold to continue decoding the current segment.
no_speech_threshold : float
The threshold to skip decoding segment if the no_speech_prob is higher than this value.
condition_on_previous_text : bool
If True, the model will be condition on the last 224 tokens.
verbose : bool
If True, print the details of each segment.
use_torchaudio_streaming : bool
Whether the audio file can be loaded in a streaming fashion. If not,
transcription is still performed through chunks of audio, but the
entire audio file is fetched and loaded at once.
This skips the usual fetching method and instead resolves the URI
using torchaudio (via ffmpeg).
chunk_size : Optional[int]
The size of the chunks to split the audio into. The default
chunk size is 30 seconds which corresponds to the maximal length
that the model can process in one go.
**kwargs : dict
Arguments forwarded to ``load_audio``
Returns
-------
results : list
A list of ``WhisperASRChunk`` objects, each containing the task result.
"""
results = []
for whisper_segment in self.transcribe_file_streaming(
path,
task=task,
initial_prompt=initial_prompt,
logprob_threshold=logprob_threshold,
no_speech_threshold=no_speech_threshold,
condition_on_previous_text=condition_on_previous_text,
verbose=verbose,
use_torchaudio_streaming=use_torchaudio_streaming,
chunk_size=chunk_size,
**kwargs,
):
results.append(whisper_segment)
if verbose:
pred = (
whisper_segment.words
if task != "lang_id"
else whisper_segment.lang_id
)
print(
f"[{whisper_segment.start}s --> {whisper_segment.end}s] {pred}"
)
return results
def encode_batch(self, wavs, wav_lens):
"""Encodes the input audio into a sequence of hidden states
The waveforms should already be in the model's desired format.
You can call:
``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
to get a correctly converted signal in most cases.
Arguments
---------
wavs : torch.tensor
Batch of waveforms [batch, time, channels].
wav_lens : torch.tensor
Lengths of the waveforms relative to the longest one in the
batch, tensor of shape [batch]. The longest one should have
relative length 1.0 and others len(waveform) / max_length.
Used for ignoring padding.
Returns
-------
torch.tensor
The encoded batch
"""
wavs = wavs.to(device=self.device, dtype=torch.float32)
mel = self.mods.whisper._get_mel(wavs)
encoder_out = self.mods.whisper.forward_encoder(mel)
return encoder_out
@torch.no_grad()
def transcribe_batch(self, wavs, wav_lens):
    """Decodes a batch of waveforms into text, with gradients disabled.

    The waveforms should already be in the model's desired format.
    You can call:
    ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
    to get a correctly converted signal in most cases.

    Arguments
    ---------
    wavs : torch.tensor
        Batch of waveforms [batch, time, channels].
    wav_lens : torch.tensor
        Lengths of the waveforms relative to the longest one in the
        batch, tensor of shape [batch]. The longest one should have
        relative length 1.0 and others len(waveform) / max_length.
        Converted to float, moved to the model device and handed to both
        ``encode_batch`` and the decoder.

    Returns
    -------
    list
        One entry per waveform in the batch. Each entry is the stripped
        decoded string, or, when ``hparams.normalized_transcripts`` is
        truthy, the list obtained by normalizing that string and splitting
        it on single spaces.
    tensor
        Each predicted token id, exactly as returned by the decoder.
    """
    rel_lens = wav_lens.float().to(self.device)
    hidden_states = self.encode_batch(wavs, rel_lens)

    # The decoder returns four values; only the token hypotheses are used.
    predicted_tokens, _, _, _ = self.mods.decoder(hidden_states, rel_lens)

    transcripts = []
    for token_ids in predicted_tokens:
        decoded = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        transcripts.append(decoded.strip())

    if not self.hparams.normalized_transcripts:
        return transcripts, predicted_tokens

    normalized = [
        self.tokenizer.normalize(transcript).split(" ")
        for transcript in transcripts
    ]
    return normalized, predicted_tokens
def forward(self, wavs, wav_lens):
    """Runs full transcription - note: no gradients through decoding.

    Thin alias that delegates to ``transcribe_batch`` with the same
    arguments and returns its result unchanged.
    """
    transcription = self.transcribe_batch(wavs, wav_lens)
    return transcription
@dataclass
class ASRStreamingContext:
    """Streaming state shared across chunk-level calls, created by
    :meth:`~StreamingASR.make_streaming_context` (see there for how each
    field is initialized).

    This object is intended to be mutated: the very same instance should be
    passed from one call to the next as streaming progresses (namely when
    using the lower-level :meth:`~StreamingASR.encode_chunk`, etc. APIs).

    It holds references to opaque, component-specific streaming contexts, so
    the context is model-agnostic to an extent."""

    config: DynChunkTrainConfig
    """Dynamic chunk training configuration the streaming context was
    initialized with. Cannot be modified on the fly."""

    fea_extractor_context: Any
    """Opaque streaming context of the feature extractor."""

    encoder_context: Any
    """Opaque streaming context of the encoder."""

    decoder_context: Any
    """Opaque streaming context of the decoder."""

    tokenizer_context: Optional[List[Any]]
    """Opaque streaming contexts of the tokenizer. Starts as `None` and is
    replaced by a list of tokenizer contexts once the batch size can be
    determined."""
class StreamingASR(Pretrained):
    """A ready-to-use, streaming-capable ASR model.

    Audio is processed chunk by chunk: each chunk goes through the streaming
    feature extractor, the streaming encoder (``enc``), the encoder
    projection (``proj_enc``), the decoding function and finally the
    streaming tokenizer decoder. All per-stream state lives in an
    :class:`ASRStreamingContext` that is passed from call to call.

    Arguments
    ---------
    *args : tuple
    **kwargs : dict
        Arguments are forwarded to ``Pretrained`` parent class.

    Example
    -------
    >>> from speechbrain.inference.ASR import StreamingASR
    >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
    >>> tmpdir = getfixture("tmpdir")
    >>> asr_model = StreamingASR.from_hparams(source="speechbrain/asr-conformer-streaming-librispeech", savedir=tmpdir,) # doctest: +SKIP
    >>> asr_model.transcribe_file("speechbrain/asr-conformer-streaming-librispeech/test-en.wav", DynChunkTrainConfig(24, 8)) # doctest: +SKIP
    """

    HPARAMS_NEEDED = [
        "fea_streaming_extractor",
        "make_decoder_streaming_context",
        "decoding_function",
        "make_tokenizer_streaming_context",
        "tokenizer_decode_streaming",
    ]
    MODULES_NEEDED = ["enc", "proj_enc"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Cached once; read by `get_chunk_size_frames` for its stride.
        self.filter_props = self.hparams.fea_streaming_extractor.properties

    def _get_audio_stream(
        self, streamer: "torchaudio.io.StreamReader", frames_per_chunk: int
    ):
        """From a :class:`torchaudio.io.StreamReader`, identifies the audio
        stream and returns an iterable stream of chunks (after resampling and
        downmixing to mono).

        Arguments
        ---------
        streamer : torchaudio.io.StreamReader
            The stream object. Must hold exactly one source stream of an
            audio type.
        frames_per_chunk : int
            The number of frames per chunk. For a streaming model, this should
            be determined from the DynChunkTrain configuration.

        Yields
        ------
        chunks from streamer
            Each chunk has its last (channel) dimension squeezed away and a
            leading batch dimension of size 1 added.

        Raises
        ------
        ValueError
            If the source does not contain exactly one audio stream.
        """
        stream_infos = [
            streamer.get_src_stream_info(i)
            for i in range(streamer.num_src_streams)
        ]

        # Keep the source index alongside each audio stream so it can be
        # selected explicitly below.
        audio_stream_infos = [
            (i, stream_info)
            for i, stream_info in enumerate(stream_infos)
            if stream_info.media_type == "audio"
        ]

        if len(audio_stream_infos) != 1:
            raise ValueError(
                f"Expected stream to have only 1 stream (with any number of channels), got {len(audio_stream_infos)} (with streams: {stream_infos})"
            )

        # find the index of the first (and only) audio stream
        audio_stream_index = audio_stream_infos[0][0]

        # output stream #0
        streamer.add_basic_audio_stream(
            frames_per_chunk=frames_per_chunk,
            stream_index=audio_stream_index,
            sample_rate=self.audio_normalizer.sample_rate,
            format="fltp",  # torch.float32
            num_channels=1,
        )

        # A single output stream was registered, hence the 1-tuple unpacking.
        for (chunk,) in streamer.stream():
            chunk = chunk.squeeze(-1)  # we deal with mono, remove that dim
            chunk = chunk.unsqueeze(0)  # create a fake batch dim
            yield chunk

    def transcribe_file_streaming(
        self,
        path,
        dynchunktrain_config: DynChunkTrainConfig,
        use_torchaudio_streaming: bool = True,
        **kwargs,
    ):
        """Transcribes the given audio file into a sequence of words, in a
        streaming fashion, meaning that text is being yielded from this
        generator, in the form of strings to concatenate.

        Arguments
        ---------
        path : str
            URI/path to the audio to transcribe. When
            ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
            fetching from HF or a local file. When ``True``, resolves the URI
            through ffmpeg, as documented in
            :class:`torchaudio.io.StreamReader`.
        dynchunktrain_config : DynChunkTrainConfig
            Streaming configuration. Sane values and how much time chunks
            actually represent is model-dependent.
        use_torchaudio_streaming : bool
            Whether the audio file can be loaded in a streaming fashion. If not,
            transcription is still performed through chunks of audio, but the
            entire audio file is fetched and loaded at once.
            This skips the usual fetching method and instead resolves the URI
            using torchaudio (via ffmpeg).
        **kwargs : dict
            Arguments forwarded to ``load_audio``. Only used when
            ``use_torchaudio_streaming`` is ``False``.

        Yields
        ------
        generator of str
            An iterator yielding transcribed chunks (strings). There is a yield
            for every chunk, even if the transcribed string for that chunk is an
            empty string.
        """
        chunk_size = self.get_chunk_size_frames(dynchunktrain_config)

        if use_torchaudio_streaming:
            streamer = torchaudio.io.StreamReader(path)
            chunks = self._get_audio_stream(streamer, chunk_size)
        else:
            waveform = self.load_audio(path, **kwargs)
            batch = waveform.unsqueeze(0)  # create batch dim
            chunks = split_fixed_chunks(batch, chunk_size)

        # Single-item batch whose chunks are always treated as full length.
        rel_length = torch.tensor([1.0])
        context = self.make_streaming_context(dynchunktrain_config)

        # Zero-valued chunks appended after the real audio; how many is
        # decided by the feature extractor (presumably so that the tail of
        # the audio is flushed through the model - confirm against the
        # extractor). The same tensor object is reused for every entry; it
        # is only read, never written.
        silence = torch.zeros((1, chunk_size), device=self.device)
        final_chunk_count = self.hparams.fea_streaming_extractor.get_recommended_final_chunk_count(
            chunk_size
        )
        final_chunks = [silence] * final_chunk_count

        for chunk in itertools.chain(chunks, final_chunks):
            predicted_words = self.transcribe_chunk(context, chunk, rel_length)
            # Batch size is 1 here, so only the first transcript exists.
            yield predicted_words[0]

    def transcribe_file(
        self,
        path,
        dynchunktrain_config: DynChunkTrainConfig,
        use_torchaudio_streaming: bool = True,
        **kwargs,
    ):
        """Transcribes the given audio file into a sequence of words.

        This drains :meth:`~StreamingASR.transcribe_file_streaming` and
        concatenates every yielded text chunk.

        Arguments
        ---------
        path : str
            URI/path to the audio to transcribe. When
            ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
            fetching from HF or a local file. When ``True``, resolves the URI
            through ffmpeg, as documented in
            :class:`torchaudio.io.StreamReader`.
        dynchunktrain_config : DynChunkTrainConfig
            Streaming configuration. Sane values and how much time chunks
            actually represent is model-dependent.
        use_torchaudio_streaming : bool
            Whether the audio file can be loaded in a streaming fashion. If not,
            transcription is still performed through chunks of audio, but the
            entire audio file is fetched and loaded at once.
            This skips the usual fetching method and instead resolves the URI
            using torchaudio (via ffmpeg).
        **kwargs : dict
            Arguments forwarded to ``transcribe_file_streaming``, which in
            turn forwards them to ``load_audio`` when
            ``use_torchaudio_streaming`` is ``False``.

        Returns
        -------
        str
            The audio file transcription produced by this ASR system.
        """
        return "".join(
            self.transcribe_file_streaming(
                path, dynchunktrain_config, use_torchaudio_streaming, **kwargs
            )
        )

    def make_streaming_context(self, dynchunktrain_config: DynChunkTrainConfig):
        """Create a blank streaming context to be passed around for chunk
        encoding/transcription.

        Arguments
        ---------
        dynchunktrain_config : DynChunkTrainConfig
            Streaming configuration. Sane values and how much time chunks
            actually represent is model-dependent.

        Returns
        -------
        ASRStreamingContext
            A fresh context with new feature extractor, encoder and decoder
            streaming contexts. The tokenizer context is left as ``None``
            and is filled in lazily by :meth:`~StreamingASR.decode_chunk`.
        """
        return ASRStreamingContext(
            config=dynchunktrain_config,
            fea_extractor_context=self.hparams.fea_streaming_extractor.make_streaming_context(),
            encoder_context=self.mods.enc.make_streaming_context(
                dynchunktrain_config
            ),
            decoder_context=self.hparams.make_decoder_streaming_context(),
            tokenizer_context=None,
        )

    def get_chunk_size_frames(
        self, dynchunktrain_config: DynChunkTrainConfig
    ) -> int:
        """Returns the chunk size in actual audio samples, i.e. the exact
        expected length along the time dimension of an input chunk tensor (as
        passed to :meth:`~StreamingASR.encode_chunk` and similar low-level
        streaming functions).

        Arguments
        ---------
        dynchunktrain_config : DynChunkTrainConfig
            The streaming configuration to determine the chunk frame count of.

        Returns
        -------
        chunk size
            Computed as ``(filter_props.stride - 1) * chunk_size``, where
            ``filter_props`` are the feature extractor properties cached at
            construction time.
        """
        return (self.filter_props.stride - 1) * dynchunktrain_config.chunk_size

    @torch.no_grad()
    def encode_chunk(
        self,
        context: ASRStreamingContext,
        chunk: torch.Tensor,
        chunk_len: Optional[torch.Tensor] = None,
    ):
        """Encoding of a batch of audio chunks into a batch of encoded
        sequences.
        For full speech-to-text offline transcription, use `transcribe_batch` or
        `transcribe_file`.
        Must be called over a given context in the correct order of chunks over
        time.

        Arguments
        ---------
        context : ASRStreamingContext
            Mutable streaming context object, which must be specified and reused
            across calls when streaming.
            You can obtain an initial context by calling
            `asr.make_streaming_context(config)`.
        chunk : torch.Tensor
            The tensor for an audio chunk of shape `[batch size, time]`.
            The time dimension is expected to match
            `asr.get_chunk_size_frames(config)`; this method only asserts
            that it does not exceed that value.
            The waveform is expected to be in the model's expected format (i.e.
            the sampling rate must be correct).
        chunk_len : torch.Tensor, optional
            The relative chunk length tensor of shape `[batch size]`. This is to
            be used when the audio in one of the chunks of the batch is ending
            within this chunk.
            If unspecified, equivalent to `torch.ones((batch_size,))`.

        Returns
        -------
        torch.Tensor
            Encoded output, of a model-dependent shape."""
        if chunk_len is None:
            chunk_len = torch.ones((chunk.size(0),))

        chunk = chunk.float()
        chunk, chunk_len = chunk.to(self.device), chunk_len.to(self.device)

        # Sanity check only: shorter chunks are let through, longer ones are
        # rejected. Being an `assert`, it is skipped when Python runs with -O.
        assert chunk.shape[-1] <= self.get_chunk_size_frames(context.config)

        x = self.hparams.fea_streaming_extractor(
            chunk, context=context.fea_extractor_context, lengths=chunk_len
        )
        x = self.mods.enc.forward_streaming(x, context.encoder_context)
        x = self.mods.proj_enc(x)
        return x

    @torch.no_grad()
    def decode_chunk(
        self, context: ASRStreamingContext, x: torch.Tensor
    ) -> Tuple[List[str], List[List[int]]]:
        """Decodes the output of the encoder into tokens and the associated
        transcription.
        Must be called over a given context in the correct order of chunks over
        time.

        Arguments
        ---------
        context : ASRStreamingContext
            Mutable streaming context object, which should be the same object
            that was passed to `encode_chunk`.
        x : torch.Tensor
            The output of `encode_chunk` for a given chunk.

        Returns
        -------
        list of str
            Decoded tokens of length `batch_size`. The decoded strings can be
            of 0-length.
        list of list of output token hypotheses
            List of length `batch_size`, each holding a list of tokens of any
            length `>=0`.
        """
        tokens = self.hparams.decoding_function(x, context.decoder_context)

        # initialize token context for real now that we know the batch size
        if context.tokenizer_context is None:
            context.tokenizer_context = [
                self.hparams.make_tokenizer_streaming_context()
                for _ in range(len(tokens))
            ]

        # Each batch item is decoded against its own tokenizer context.
        words = [
            self.hparams.tokenizer_decode_streaming(
                self.hparams.tokenizer, cur_tokens, context.tokenizer_context[i]
            )
            for i, cur_tokens in enumerate(tokens)
        ]

        return words, tokens

    def transcribe_chunk(
        self,
        context: ASRStreamingContext,
        chunk: torch.Tensor,
        chunk_len: Optional[torch.Tensor] = None,
    ):
        """Transcription of a batch of audio chunks into transcribed text.
        Must be called over a given context in the correct order of chunks over
        time.

        Arguments
        ---------
        context : ASRStreamingContext
            Mutable streaming context object, which must be specified and reused
            across calls when streaming.
            You can obtain an initial context by calling
            `asr.make_streaming_context(config)`.
        chunk : torch.Tensor
            The tensor for an audio chunk of shape `[batch size, time]`.
            The time dimension is expected to match
            `asr.get_chunk_size_frames(config)`; `encode_chunk` asserts that
            it does not exceed that value.
            The waveform is expected to be in the model's expected format (i.e.
            the sampling rate must be correct).
        chunk_len : torch.Tensor, optional
            The relative chunk length tensor of shape `[batch size]`. This is to
            be used when the audio in one of the chunks of the batch is ending
            within this chunk.
            If unspecified, equivalent to `torch.ones((batch_size,))`.

        Returns
        -------
        list of str
            One transcribed string per batch item for this chunk, as returned
            by `decode_chunk`. Each string might be of length zero.
        """
        if chunk_len is None:
            chunk_len = torch.ones((chunk.size(0),))

        chunk = chunk.float()
        chunk, chunk_len = chunk.to(self.device), chunk_len.to(self.device)

        x = self.encode_chunk(context, chunk, chunk_len)
        # The token hypotheses are discarded; only the text is returned.
        words, _ = self.decode_chunk(context, x)

        return words
|