Tmeena commited on
Commit
0d05344
·
verified ·
1 Parent(s): 547c5a4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ from flask import Flask, request, jsonify
4
+ from pydub import AudioSegment
5
+ import whisper
6
+ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
7
+
8
+ # Load the Whisper model
9
+ whisper_model = whisper.load_model("base")
10
+
11
+ # Load the translation model and tokenizer
12
+ tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
13
+ translation_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
14
+
15
+ def preprocess_audio(audio_path):
16
+ """Convert audio to 16kHz mono WAV format."""
17
+ audio = AudioSegment.from_file(audio_path)
18
+ audio = audio.set_frame_rate(16000).set_channels(1) # Set to 16kHz and mono
19
+ processed_path = f"{audio_path}_processed.wav"
20
+ audio.export(processed_path, format="wav")
21
+ return processed_path
22
+
23
+ def transcribe_audio(audio_path, source_language=None):
24
+ """Transcribe audio using Whisper with an optional source language."""
25
+ options = {"language": source_language} if source_language else {}
26
+ result = whisper_model.transcribe(audio_path, **options)
27
+ return result['text']
28
+
29
+ def translate_text(text, source_lang="en", target_lang="hi"):
30
+ """Translate text using Facebook's M2M100 model."""
31
+ tokenizer.src_lang = source_lang
32
+ inputs = tokenizer(text, return_tensors="pt")
33
+ translated_tokens = translation_model.generate(
34
+ **inputs,
35
+ forced_bos_token_id=tokenizer.get_lang_id(target_lang)
36
+ )
37
+ return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
38
+
39
+ def handle_request(audio_base64, source_lang, target_lang):
40
+ """Handle audio translation request."""
41
+ audio_file_path = "temp_audio.wav"
42
+ # Decode the base64 audio
43
+ with open(audio_file_path, "wb") as audio_file:
44
+ audio_file.write(base64.b64decode(audio_base64))
45
+
46
+ # Process the audio file
47
+ processed_audio_file_name = preprocess_audio(audio_file_path)
48
+ spoken_text = transcribe_audio(processed_audio_file_name, source_lang)
49
+ translated_text = translate_text(spoken_text, source_lang, target_lang)
50
+
51
+ # Clean up temporary files
52
+ os.remove(processed_audio_file_name)
53
+ os.remove(audio_file_path)
54
+
55
+ return {"transcribed_text": spoken_text, "translated_text": translated_text}
56
+
57
+ # Flask for handling external POST requests
58
+ app = Flask(__name__)
59
+
60
+ @app.route('/translate', methods=['POST'])
61
+ def translate():
62
+ """API endpoint for handling audio translation."""
63
+ data = request.json
64
+ if 'audio' not in data or 'source_lang' not in data or 'target_lang' not in data:
65
+ return jsonify({"error": "Invalid request format"}), 400
66
+
67
+ audio_base64 = data['audio']
68
+ source_lang = data['source_lang']
69
+ target_lang = data['target_lang']
70
+
71
+ # Call the handle_request function to process the request
72
+ response = handle_request(audio_base64, source_lang, target_lang)
73
+ return jsonify(response)
74
+
75
+ if __name__ == "__main__":
76
+ app.run(host='0.0.0.0', port=7860)