vineelpratap commited on
Commit
4090e0d
·
1 Parent(s): 5cc287f

Update asr.py

Browse files
Files changed (1) hide show
  1. asr.py +34 -35
asr.py CHANGED
@@ -30,9 +30,40 @@ lm_decoding_configfile = hf_hub_download(
30
  with open(lm_decoding_configfile) as f:
31
  lm_decoding_config = json.loads(f.read())
32
 
33
- # allow language model decoding for specific languages
34
- lm_decode_isos = ["eng"]
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def transcribe(
38
  audio_source=None, microphone=None, file_upload=None, lang="eng (English)"
@@ -75,42 +106,10 @@ def transcribe(
75
  with torch.no_grad():
76
  outputs = model(**inputs).logits
77
 
78
- if lang_code not in lm_decoding_config or lang_code not in lm_decode_isos:
79
  ids = torch.argmax(outputs, dim=-1)[0]
80
  transcription = processor.decode(ids)
81
  else:
82
- decoding_config = lm_decoding_config[lang_code]
83
-
84
- lm_file = hf_hub_download(
85
- repo_id="facebook/mms-cclms",
86
- filename=decoding_config["lmfile"].rsplit("/", 1)[1],
87
- subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
88
- )
89
- token_file = hf_hub_download(
90
- repo_id="facebook/mms-cclms",
91
- filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
92
- subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
93
- )
94
- lexicon_file = None
95
- if decoding_config["lexiconfile"] is not None:
96
- lexicon_file = hf_hub_download(
97
- repo_id="facebook/mms-cclms",
98
- filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
99
- subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
100
- )
101
-
102
- beam_search_decoder = ctc_decoder(
103
- lexicon=lexicon_file,
104
- tokens=token_file,
105
- lm=lm_file,
106
- nbest=1,
107
- beam_size=500,
108
- beam_size_token=50,
109
- lm_weight=float(decoding_config["lmweight"]),
110
- word_score=float(decoding_config["wordscore"]),
111
- sil_score=float(decoding_config["silweight"]),
112
- blank_token="<s>",
113
- )
114
  beam_search_result = beam_search_decoder(outputs.to("cpu"))
115
  transcription = " ".join(beam_search_result[0][0].words).strip()
116
 
 
30
  with open(lm_decoding_configfile) as f:
31
  lm_decoding_config = json.loads(f.read())
32
 
33
+ # allow language model decoding for "eng"
 
34
 
35
+ decoding_config = lm_decoding_config["eng"]
36
+
37
+ lm_file = hf_hub_download(
38
+ repo_id="facebook/mms-cclms",
39
+ filename=decoding_config["lmfile"].rsplit("/", 1)[1],
40
+ subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
41
+ )
42
+ token_file = hf_hub_download(
43
+ repo_id="facebook/mms-cclms",
44
+ filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
45
+ subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
46
+ )
47
+ lexicon_file = None
48
+ if decoding_config["lexiconfile"] is not None:
49
+ lexicon_file = hf_hub_download(
50
+ repo_id="facebook/mms-cclms",
51
+ filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
52
+ subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
53
+ )
54
+
55
+ beam_search_decoder = ctc_decoder(
56
+ lexicon=lexicon_file,
57
+ tokens=token_file,
58
+ lm=lm_file,
59
+ nbest=1,
60
+ beam_size=500,
61
+ beam_size_token=50,
62
+ lm_weight=float(decoding_config["lmweight"]),
63
+ word_score=float(decoding_config["wordscore"]),
64
+ sil_score=float(decoding_config["silweight"]),
65
+ blank_token="<s>",
66
+ )
67
 
68
  def transcribe(
69
  audio_source=None, microphone=None, file_upload=None, lang="eng (English)"
 
106
  with torch.no_grad():
107
  outputs = model(**inputs).logits
108
 
109
+ if lang_code != "eng":
110
  ids = torch.argmax(outputs, dim=-1)[0]
111
  transcription = processor.decode(ids)
112
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  beam_search_result = beam_search_decoder(outputs.to("cpu"))
114
  transcription = " ".join(beam_search_result[0][0].words).strip()
115