abnerh commited on
Commit
0cc2cbd
1 Parent(s): 79eccc5

german and spanish

Browse files
Files changed (3) hide show
  1. app.py +38 -14
  2. clean_text.py +39 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,18 +1,31 @@
1
  import os, sys, re
2
  import shutil
3
- import argparse
4
  import subprocess
5
  import soundfile
6
  from process_audio import segment_audio
7
  from write_srt import write_to_file
8
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2Tokenizer
 
9
  import torch
10
  import gradio as gr
11
 
12
 
13
- model = "facebook/wav2vec2-large-960h-lv60-self"
14
- tokenizer = Wav2Vec2Tokenizer.from_pretrained(model)
15
- asr_model = Wav2Vec2ForCTC.from_pretrained(model)#.to('cuda')
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Line count for SRT file
18
  line_count = 0
@@ -34,18 +47,23 @@ def transcribe_audio(tokenizer, asr_model, audio_file, file_handle):
34
 
35
 
36
  infered_text = tokenizer.batch_decode(prediction)[0].lower()
37
- infered_text = re.sub(r' ', ' ', infered_text)
38
- infered_text = re.sub(r'\bi\s', 'I ', infered_text)
39
- infered_text = re.sub(r'\si$', ' I', infered_text)
40
- infered_text = re.sub(r'i\'', 'I\'', infered_text)
41
-
42
- limits = audio_file.split(os.sep)[-1][:-4].split("_")[-1].split("-")
43
-
44
  if len(infered_text) > 1:
 
 
 
 
 
 
 
 
 
45
  line_count += 1
46
  write_to_file(file_handle, infered_text, line_count, limits)
 
 
 
47
 
48
- def get_subs(input_file):
49
  # Get directory for audio
50
  base_directory = os.getcwd()
51
  audio_directory = os.path.join(base_directory, "audio")
@@ -71,6 +89,11 @@ def get_subs(input_file):
71
  file_handle.seek(0)
72
  for file in sort_alphanumeric(os.listdir(audio_directory)):
73
  audio_segment_path = os.path.join(audio_directory, file)
 
 
 
 
 
74
  if audio_segment_path.split(os.sep)[-1] != audio_file.split(os.sep)[-1]:
75
  transcribe_audio(tokenizer, asr_model, audio_segment_path, file_handle)
76
 
@@ -84,7 +107,8 @@ gradio_ui = gr.Interface(
84
  fn=get_subs,
85
  title="Video to Subtitle",
86
  description="Get subtitles (SRT file) for your videos. Inference speed is about 10s/per 1min of video BUT the speed of uploading your video depends on your internet connection.",
87
- inputs=gr.inputs.Video(label="Upload Video File"),
 
88
  outputs=gr.outputs.File(label="Auto-Transcript")
89
  )
90
 
1
  import os, sys, re
2
  import shutil
 
3
  import subprocess
4
  import soundfile
5
  from process_audio import segment_audio
6
  from write_srt import write_to_file
7
+ from clean_text import clean_english, clean_german, clean_spanish
8
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
  import torch
10
  import gradio as gr
11
 
12
 
13
+ english_model = "facebook/wav2vec2-large-960h-lv60-self"
14
+ english_tokenizer = Wav2Vec2Processor.from_pretrained(english_model)
15
+ english_asr_model = Wav2Vec2ForCTC.from_pretrained(english_model)
16
+
17
+ german_model = "jonatasgrosman/wav2vec2-large-xlsr-53-german"
18
+ german_tokenizer = Wav2Vec2Processor.from_pretrained(german_model)
19
+ german_asr_model = Wav2Vec2ForCTC.from_pretrained(german_model)
20
+
21
+ spanish_model = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
22
+ spanish_tokenizer = Wav2Vec2Processor.from_pretrained(spanish_model)
23
+ spanish_asr_model = Wav2Vec2ForCTC.from_pretrained(spanish_model)
24
+
25
+ # Get German corpus and update nltk
26
+ command = ["python", "-m", "textblob.download_corpora"]
27
+ subprocess.run(command)
28
+
29
 
30
  # Line count for SRT file
31
  line_count = 0
47
 
48
 
49
  infered_text = tokenizer.batch_decode(prediction)[0].lower()
 
 
 
 
 
 
 
50
  if len(infered_text) > 1:
51
+ if lang == 'english':
52
+ infered_text = clean_english(infered_text)
53
+ elif lang == 'german':
54
+ infered_text = clean_german(infered_text)
55
+ elif lang == 'spanish':
56
+ infered_text = clean_spanish(infered_text)
57
+
58
+ print(infered_text)
59
+ limits = audio_file.split(os.sep)[-1][:-4].split("_")[-1].split("-")
60
  line_count += 1
61
  write_to_file(file_handle, infered_text, line_count, limits)
62
+ else:
63
+ infered_text = ''
64
+
65
 
66
+ def get_subs(input_file, language):
67
  # Get directory for audio
68
  base_directory = os.getcwd()
69
  audio_directory = os.path.join(base_directory, "audio")
89
  file_handle.seek(0)
90
  for file in sort_alphanumeric(os.listdir(audio_directory)):
91
  audio_segment_path = os.path.join(audio_directory, file)
92
+ global lang
93
+ lang = language.lower()
94
+ tokenizer = globals()[lang+'_tokenizer']
95
+ asr_model = globals()[lang+'_asr_model']
96
+
97
  if audio_segment_path.split(os.sep)[-1] != audio_file.split(os.sep)[-1]:
98
  transcribe_audio(tokenizer, asr_model, audio_segment_path, file_handle)
99
 
107
  fn=get_subs,
108
  title="Video to Subtitle",
109
  description="Get subtitles (SRT file) for your videos. Inference speed is about 10s/per 1min of video BUT the speed of uploading your video depends on your internet connection.",
110
+ inputs=[gr.inputs.Video(label="Upload Video File"),
111
+ gr.inputs.Radio(label="Choose Language", choices=['English', 'German', 'Spanish'])],
112
  outputs=gr.outputs.File(label="Auto-Transcript")
113
  )
114
 
clean_text.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, string
2
+ import subprocess
3
+ from textblob_de import TextBlobDE as TextBlob
4
+
5
+
6
+ def clean_english(text):
7
+ clean_text = re.sub(r' ', ' ', text)
8
+ clean_text = re.sub(r'\bi\s', 'I ', clean_text)
9
+ clean_text = re.sub(r'\si$', ' I', clean_text)
10
+ clean_text = re.sub(r'i\'', 'I\'', clean_text)
11
+
12
+ return clean_text
13
+
14
+ def clean_german(text):
15
+ text = text.translate(str.maketrans('', '', string.punctuation))
16
+
17
+ # Tokenize German text
18
+ blob = TextBlob(text)
19
+ pos = blob.tags
20
+
21
+ # Get nouns and capitalize
22
+ nouns = {}
23
+ for idx in pos:
24
+ if idx[1] == 'NN' and len(idx[0]) > 1:
25
+ nouns[idx[0]] = idx[0].capitalize()
26
+
27
+ if len(nouns) != 0:
28
+ pattern = re.compile("|".join(nouns.keys()))
29
+ text = pattern.sub(lambda m: nouns[re.escape(m.group(0))], text)
30
+
31
+ return text
32
+
33
+
34
+ def clean_spanish(text):
35
+ clean_text = text.translate(str.maketrans('', '', string.punctuation))
36
+ clean_text = re.sub(r' ', ' ', clean_text)
37
+
38
+ return clean_text
39
+
requirements.txt CHANGED
@@ -3,3 +3,4 @@ transformers
3
  torch
4
  gradio
5
  auditok
 
3
  torch
4
  gradio
5
  auditok
6
+ textblob_de