Harveenchadha
commited on
Commit
•
8b8bd52
1
Parent(s):
90f5dcf
Update app.py
Browse files
app.py
CHANGED
@@ -22,22 +22,23 @@ def read_file(wav):
|
|
22 |
|
23 |
|
24 |
def parse_transcription_with_lm(wav_file):
|
25 |
-
|
26 |
-
inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
|
27 |
|
28 |
with torch.no_grad():
|
29 |
-
logits = model(**
|
30 |
-
int_result = processor.
|
31 |
|
32 |
transcription = int_result.text
|
33 |
return transcription
|
34 |
|
35 |
|
36 |
-
def
|
37 |
filename = wav_file.split('.')[0]
|
38 |
convert(wav_file, filename + "16k.wav")
|
39 |
speech, _ = sf.read(filename + "16k.wav")
|
40 |
-
|
|
|
|
|
41 |
|
42 |
def parse(wav_file, applyLM):
|
43 |
if applyLM:
|
@@ -46,12 +47,7 @@ def parse(wav_file, applyLM):
|
|
46 |
return parse_transcription(wav_file)
|
47 |
|
48 |
def parse_transcription(wav_file):
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
#speech = read_file(wav_file)
|
53 |
-
input_values = processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
|
54 |
-
|
55 |
logits = model(input_values).logits
|
56 |
predicted_ids = torch.argmax(logits, dim=-1)
|
57 |
|
|
|
22 |
|
23 |
|
24 |
def parse_transcription_with_lm(wav_file):
|
25 |
+
input_values = read_file_and_process(wav_file)
|
|
|
26 |
|
27 |
with torch.no_grad():
|
28 |
+
logits = model(**input_values).logits
|
29 |
+
int_result = processor.decode(logits.cpu().numpy())
|
30 |
|
31 |
transcription = int_result.text
|
32 |
return transcription
|
33 |
|
34 |
|
35 |
+
def read_file_and_process(wav_file):
|
36 |
filename = wav_file.split('.')[0]
|
37 |
convert(wav_file, filename + "16k.wav")
|
38 |
speech, _ = sf.read(filename + "16k.wav")
|
39 |
+
inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
|
40 |
+
|
41 |
+
return inputs
|
42 |
|
43 |
def parse(wav_file, applyLM):
|
44 |
if applyLM:
|
|
|
47 |
return parse_transcription(wav_file)
|
48 |
|
49 |
def parse_transcription(wav_file):
|
50 |
+
input_values = read_file_and_process(wav_file)
|
|
|
|
|
|
|
|
|
|
|
51 |
logits = model(input_values).logits
|
52 |
predicted_ids = torch.argmax(logits, dim=-1)
|
53 |
|