Harveenchadha committed
Commit 5b6d82c
Parent: 07d59d1

Refactored Code

Files changed (1)
  1. app.py +18 -51
app.py CHANGED
@@ -2,97 +2,64 @@ import soundfile as sf
 import torch
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
 import gradio as gr
-import scipy.signal as sps
 import sox
 import subprocess
 
-def convert(inputfile, outfile):
-    sox_tfm = sox.Transformer()
-    sox_tfm.set_output_format(
-        file_type="wav", channels=1, encoding="signed-integer", rate=16000, bits=16
-    )
-    sox_tfm.build(inputfile, outfile)
-
-def read_file(wav):
-    sample_rate, signal = wav
-    signal = signal.mean(-1)
-    number_of_samples = round(len(signal) * float(16000) / sample_rate)
-    resampled_signal = sps.resample(signal, number_of_samples)
-    return resampled_signal
+def read_file_and_process(wav_file):
+    filename = wav_file.split('.')[0]
+    filename_16k = filename + "16k.wav"
+    resampler(wav_file, filename_16k)
+    speech, _ = sf.read(filename_16k)
+    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
+    return inputs
 
 def resampler(input_file_path, output_file_path):
-    #output_file_path = output_folder_path + input_file_path.split('/')[-1]
-
     command = (
         f"ffmpeg -hide_banner -loglevel panic -i {input_file_path} -ar 16000 -ac 1 -bits_per_raw_sample 16 -vn "
         f"{output_file_path}"
     )
     subprocess.call(command, shell=True)
 
-def parse_transcription_with_lm(wav_file):
-    input_values = read_file_and_process(wav_file)
-
-    # with torch.no_grad():
-    #     logits = model(**input_values).logits[0].cpu().numpy()
-    #     int_result = processor_with_LM.decode(logits=logits, output_word_offsets=False,
-    #                                           beam_width=128)
-    #     transcription = int_result.text.replace('<s>', '')
-
-    with torch.no_grad():
-        logits = model(**input_values).logits
+def parse_transcription_with_lm(logits):
     result = processor_with_LM.batch_decode(logits.cpu().numpy())
     text = result.text
     transcription = text[0].replace('<s>', '')
     return transcription
 
-def read_file_and_process(wav_file):
-    filename = wav_file.split('.')[0]
-    resampler(wav_file, filename + "16k.wav")
-    speech, _ = sf.read(filename + "16k.wav")
-    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
-    return inputs
+def parse_transcription(logits):
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
+    return transcription
 
 def parse(wav_file, applyLM):
+    input_values = read_file_and_process(wav_file)
+    with torch.no_grad():
+        logits = model(**input_values).logits
+
     if applyLM:
-        return parse_transcription_with_lm(wav_file)
+        return parse_transcription_with_lm(logits)
     else:
-        return parse_transcription(wav_file)
+        return parse_transcription(logits)
 
-def parse_transcription(wav_file):
-    input_values = read_file_and_process(wav_file)
-    with torch.no_grad():
-        logits = model(**input_values).logits
-    predicted_ids = torch.argmax(logits, dim=-1)
-    transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
-    return transcription
-
 model_id = "Harveenchadha/vakyansh-wav2vec2-hindi-him-4200"
 processor = Wav2Vec2Processor.from_pretrained(model_id)
 processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
 model = Wav2Vec2ForCTC.from_pretrained(model_id)
 
 input_ = gr.Audio(source="microphone", type="filepath")
-#input_ = gr.inputs.Audio(source="microphone", type="numpy")
 txtbox = gr.Textbox(
     label="Output from model will appear here:",
     lines=5
 )
 chkbox = gr.Checkbox(label="Apply LM", value=False)
 
+
 gr.Interface(parse, inputs=[input_, chkbox], outputs=txtbox,
              streaming=True, interactive=True,
             analytics_enabled=False, show_tips=False, enable_queue=True).launch(inline=False)
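
For context on what the refactor separates: parse now runs the forward pass once and hands the logits to either decoder. A minimal standalone sketch of the two paths, assuming the LM head's dependencies (pyctcdecode, kenlm) are installed and hindi_sample.wav is a hypothetical 16 kHz mono recording:

    import soundfile as sf
    import torch
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

    model_id = "Harveenchadha/vakyansh-wav2vec2-hindi-him-4200"
    processor = Wav2Vec2Processor.from_pretrained(model_id)
    processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
    model = Wav2Vec2ForCTC.from_pretrained(model_id)

    speech, _ = sf.read("hindi_sample.wav")  # hypothetical file, already 16 kHz mono
    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(**inputs).logits  # one forward pass feeds both decoders

    # applyLM=False path: greedy CTC decoding
    predicted_ids = torch.argmax(logits, dim=-1)
    print(processor.decode(predicted_ids[0], skip_special_tokens=True))

    # applyLM=True path: beam-search decoding with the bundled language model
    print(processor_with_lm.batch_decode(logits.cpu().numpy()).text[0].replace('<s>', ''))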
 
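
Note that resampler shells out to ffmpeg with shell=True, so paths containing spaces or shell metacharacters will break the command. A minimal sketch of the same conversion with an argument list instead (flags copied from the code above; file names are placeholders):

    import subprocess

    def resample_to_16k(input_path, output_path):
        # Same conversion as resampler(): 16 kHz, mono, 16-bit, audio stream only.
        subprocess.run(
            ["ffmpeg", "-hide_banner", "-loglevel", "panic",
             "-i", input_path, "-ar", "16000", "-ac", "1",
             "-bits_per_raw_sample", "16", "-vn", output_path],
            check=True,  # raise CalledProcessError if ffmpeg fails
        )

    # resample_to_16k("recording.webm", "recording16k.wav")  # placeholder names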