mizoru committed on
Commit
2b10872
1 Parent(s): f67f7fd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
3
+
4
# Supported languages mapped to the checkpoint used for them and a flag saying
# whether LM-boosted decoding may be requested for that checkpoint.
MODELS = {
    "Tatar": {"model_id": "sammy786/wav2vec2-xlsr-tatar", "has_lm": False},
    "Chuvash": {"model_id": "sammy786/wav2vec2-xlsr-chuvash", "has_lm": False},
}

# Acoustic models already loaded in this process, keyed by model id, so a
# checkpoint is only instantiated on its first request.
CACHED_MODELS_BY_ID = {}

LANGUAGES = sorted(MODELS.keys())


def _report_error(history, message):
    """Record *message* in *history* and return it for display to the user."""
    history.append({"error_message": message})
    return message


def run(input_file, language, decoding_type, history):
    """Transcribe an audio file with the model registered for *language*.

    Args:
        input_file: Path to the audio file to transcribe (may be ``None``
            when nothing was recorded).
        language: Key into ``MODELS`` selecting the checkpoint.
        decoding_type: ``"LM"`` to decode with the processor's language-model
            decoder; any other value uses plain (greedy) decoding.
        history: List that error records are appended to, or ``None`` when no
            history exists yet.

    Returns:
        The transcribed text, or a human-readable error message when the
        request cannot be served (unknown language, LM decoding unavailable,
        or no audio supplied).
    """
    # The UI state is empty on the first call, so never assume a list here.
    if history is None:
        history = []

    model = MODELS.get(language)
    if model is None:
        return _report_error(history, f"Language {language!r} is not supported :(")

    use_lm = decoding_type == "LM"
    if use_lm and not model["has_lm"]:
        # Return the message instead of falling through to an unassigned
        # transcription variable.
        return _report_error(history, f"LM not available for {language} language :(")

    if input_file is None:
        return _report_error(history, "No audio was provided, please record something first.")

    model_id = model["model_id"]
    model_instance = CACHED_MODELS_BY_ID.get(model_id)
    if model_instance is None:
        model_instance = AutoModelForCTC.from_pretrained(model_id)
        CACHED_MODELS_BY_ID[model_id] = model_instance

    # Only the processor class and the decoder differ between the two modes.
    if use_lm:
        processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
        decoder = processor.decoder
    else:
        processor = Wav2Vec2Processor.from_pretrained(model_id)
        decoder = None

    asr = pipeline(
        "automatic-speech-recognition",
        model=model_instance,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        decoder=decoder,
    )

    return asr(input_file, chunk_length_s=5, stride_length_s=1)["text"]
46
+
47
# Build and launch the demo: microphone audio, a language choice and a
# decoding-type choice go into `run`, whose text result is shown in a textbox.
gr.Interface(
    run,
    inputs=[
        gr.Audio(source="microphone", type="filepath", label="Record something..."),
        gr.Radio(label="Language", choices=LANGUAGES),
        gr.Radio(label="Decoding type", choices=["greedy", "LM"]),
        # NOTE(review): "state" is wired as an input only; some Gradio
        # versions require a matching "state" output -- confirm against the
        # installed version.
        "state",
    ],
    outputs=[
        # The component must be an instance, and the class is spelled
        # `Textbox` (there is no `gr.TextBox`).
        gr.Textbox(),
    ],
    allow_screenshot=False,
    allow_flagging="never",
    theme="grass",
).launch(enable_queue=True)