KIMOSSINO commited on
Commit
6bea046
·
verified ·
1 Parent(s): 08cc981

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -42
app.py CHANGED
@@ -1,13 +1,8 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- import uvicorn
3
- from pydantic import BaseModel
4
  import whisper
5
  from transformers import MarianMTModel, MarianTokenizer
6
  import subprocess
7
  import os
8
- from pathlib import Path
9
-
10
- app = FastAPI()
11
 
12
  # Load models
13
  def load_models():
@@ -23,66 +18,84 @@ def load_models():
23
  lang: MarianTokenizer.from_pretrained(f"Helsinki-NLP/opus-mt-{lang}-en")
24
  for lang in translation_models.keys()
25
  }
26
- return translation_models, translation_tokenizers
27
 
28
- translation_models, translation_tokenizers = load_models()
29
 
30
- # Whisper endpoint
31
- @app.post("/transcribe")
32
- async def transcribe(file: UploadFile = File(...), language: str = "en"):
33
  try:
34
- # Save the file temporarily
35
- temp_file = f"temp/{file.filename}"
36
- Path("temp").mkdir(parents=True, exist_ok=True)
37
- with open(temp_file, "wb") as f:
38
- f.write(await file.read())
39
-
40
- # Transcription using Whisper
41
- result = whisper_model.transcribe(temp_file, language=language)
42
  transcription = result["text"]
43
- os.remove(temp_file) # Clean up
44
- return {"success": True, "transcription": transcription}
45
-
46
  except Exception as e:
47
- return {"success": False, "error": str(e)}
48
 
49
- # Translation endpoint
50
- @app.post("/translate")
51
- async def translate(text: str, source_lang: str, target_lang: str):
52
  try:
53
  if source_lang not in translation_models or target_lang != "en":
54
- return {"success": False, "error": "Unsupported language."}
55
 
56
- # Tokenize and translate
57
  tokenizer = translation_tokenizers[source_lang]
58
  model = translation_models[source_lang]
59
  inputs = tokenizer(text, return_tensors="pt", padding=True)
60
  translated_tokens = model.generate(**inputs)
61
  translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
62
- return {"success": True, "translation": translated_text}
63
-
64
  except Exception as e:
65
- return {"success": False, "error": str(e)}
66
 
67
- # TTS endpoint
68
- @app.post("/tts")
69
- async def text_to_speech(text: str, speaker: str = "male", speed: str = "normal"):
70
  try:
71
  output_file = "output.wav"
72
-
73
- # Coqui TTS command
74
  tts_command = [
75
  "tts",
76
  f"--text={text}",
77
  "--model_name=tts_models/en/ljspeech/tacotron2-DCA",
78
  f"--out_path={output_file}",
79
  ]
80
- subprocess.run(tts_command)
 
 
 
81
 
82
- return {"success": True, "audio_file": output_file}
 
 
 
83
 
84
- except Exception as e:
85
- return {"success": False, "error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- if __name__ == "__main__":
88
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import gradio as gr
 
 
2
  import whisper
3
  from transformers import MarianMTModel, MarianTokenizer
4
  import subprocess
5
  import os
 
 
 
6
 
7
  # Load models
8
  def load_models():
 
18
  lang: MarianTokenizer.from_pretrained(f"Helsinki-NLP/opus-mt-{lang}-en")
19
  for lang in translation_models.keys()
20
  }
 
21
 
22
+ load_models()
23
 
24
+ # Transcribe function
25
+ def transcribe_audio(file, language="en"):
 
26
  try:
27
+ result = whisper_model.transcribe(file, language=language)
 
 
 
 
 
 
 
28
  transcription = result["text"]
29
+ return transcription
 
 
30
  except Exception as e:
31
+ return f"Error: {str(e)}"
32
 
33
+ # Translate function
34
+ def translate_text(text, source_lang, target_lang="en"):
 
35
  try:
36
  if source_lang not in translation_models or target_lang != "en":
37
+ return "Unsupported language."
38
 
 
39
  tokenizer = translation_tokenizers[source_lang]
40
  model = translation_models[source_lang]
41
  inputs = tokenizer(text, return_tensors="pt", padding=True)
42
  translated_tokens = model.generate(**inputs)
43
  translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
44
+ return translated_text
 
45
  except Exception as e:
46
+ return f"Error: {str(e)}"
47
 
48
+ # Text-to-Speech function
49
+ def text_to_speech(text, speaker="male", speed="normal"):
 
50
  try:
51
  output_file = "output.wav"
 
 
52
  tts_command = [
53
  "tts",
54
  f"--text={text}",
55
  "--model_name=tts_models/en/ljspeech/tacotron2-DCA",
56
  f"--out_path={output_file}",
57
  ]
58
+ subprocess.run(tts_command, check=True)
59
+ return output_file
60
+ except Exception as e:
61
+ return f"Error: {str(e)}"
62
 
63
+ # Gradio Interface
64
+ def tts_interface(text):
65
+ audio_file = text_to_speech(text)
66
+ return audio_file if isinstance(audio_file, str) and os.path.exists(audio_file) else None
67
 
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("### Audio Transcription, Translation, and TTS App")
70
+
71
+ # Transcription section
72
+ with gr.Row():
73
+ with gr.Column():
74
+ audio_input = gr.Audio(label="Upload Audio File", type="file")
75
+ lang_input = gr.Dropdown(["en", "es", "fr", "ar"], label="Language", value="en")
76
+ transcribe_btn = gr.Button("Transcribe")
77
+ transcription_output = gr.Textbox(label="Transcription Output")
78
+
79
+ transcribe_btn.click(transcribe_audio, inputs=[audio_input, lang_input], outputs=transcription_output)
80
+
81
+ # Translation section
82
+ with gr.Row():
83
+ with gr.Column():
84
+ text_input = gr.Textbox(label="Input Text", lines=3)
85
+ source_lang = gr.Dropdown(["en", "es", "fr", "ar"], label="Source Language", value="en")
86
+ translate_btn = gr.Button("Translate")
87
+ translation_output = gr.Textbox(label="Translation Output")
88
+
89
+ translate_btn.click(translate_text, inputs=[text_input, source_lang], outputs=translation_output)
90
+
91
+ # TTS section
92
+ with gr.Row():
93
+ with gr.Column():
94
+ tts_input = gr.Textbox(label="Text for TTS", lines=2)
95
+ tts_btn = gr.Button("Generate Audio")
96
+ tts_output = gr.Audio(label="Generated Audio")
97
+
98
+ tts_btn.click(tts_interface, inputs=tts_input, outputs=tts_output)
99
 
100
+ # Launch Gradio App
101
+ demo.launch()