File size: 3,213 Bytes
fd6c122
2109067
0d05344
 
 
 
 
2109067
 
555a546
2109067
8a87bd9
0d05344
 
2109067
 
0d05344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bb7e4d
2a6735f
 
0d05344
 
 
 
 
 
 
 
 
 
 
 
 
 
fd6c122
 
 
 
 
 
 
 
 
 
2109067
fd6c122
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import base64
import binascii
import os
import tempfile
import threading

from flask import Flask, request, jsonify
from pydub import AudioSegment
import whisper
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# Directory shared by every model download below.
CACHE_DIR = "/app/cache"
os.environ['HF_HOME'] = CACHE_DIR
# NOTE(review): huggingface_hub reads HF_HOME when it is first imported, and
# transformers is already imported above, so this assignment probably has no
# effect on the current process (it is still inherited by child processes).
# The explicit cache_dir arguments below are what actually pin the cache
# location -- confirm before relying on HF_HOME alone.

# Load the Whisper model with a specified cache directory
whisper_model = whisper.load_model("base", download_root=CACHE_DIR)

# Load the translation model and tokenizer
TRANSLATION_MODEL_NAME = "facebook/m2m100_418M"
tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL_NAME, cache_dir=CACHE_DIR)
translation_model = M2M100ForConditionalGeneration.from_pretrained(
    TRANSLATION_MODEL_NAME, cache_dir=CACHE_DIR
)

def preprocess_audio(audio_path):
    """Convert audio to 16kHz mono WAV format.

    Args:
        audio_path: Path to the input audio file; decoded with
            ``AudioSegment.from_file``.

    Returns:
        Path of the written WAV file: ``audio_path`` with
        ``"_processed.wav"`` appended.
    """
    audio = AudioSegment.from_file(audio_path)
    audio = audio.set_frame_rate(16000).set_channels(1)  # Set to 16kHz and mono
    processed_path = f"{audio_path}_processed.wav"
    # AudioSegment.export opens the destination file and returns that file
    # object still open; close it explicitly rather than leaving it to the
    # garbage collector (which emits a ResourceWarning and, on some
    # platforms, keeps the file locked until collection).
    audio.export(processed_path, format="wav").close()
    return processed_path

def transcribe_audio(audio_path, source_language=None):
    """Run Whisper speech-to-text on an audio file.

    Args:
        audio_path: Path of the audio file to transcribe.
        source_language: Optional spoken-language hint, forwarded to
            ``whisper_model.transcribe`` as ``language``. When falsy, no
            ``language`` argument is passed at all.

    Returns:
        The transcript string taken from the result's ``'text'`` entry.
    """
    transcribe_kwargs = {}
    if source_language:
        transcribe_kwargs["language"] = source_language
    transcription = whisper_model.transcribe(audio_path, **transcribe_kwargs)
    return transcription['text']

# Guards the shared tokenizer/model. ``tokenizer.src_lang`` is global mutable
# state that is set and then read in separate steps. app.run()'s development
# server handles requests on multiple threads by default, so without this lock
# two requests with different source languages could tokenize with each
# other's language prefix.
_TRANSLATION_LOCK = threading.Lock()


def translate_text(text, source_lang="en", target_lang="hi"):
    """Translate text using Facebook's M2M100 model.

    Args:
        text: Text to translate. Empty or whitespace-only input (e.g. a
            transcript of silence) returns ``""`` without running the model.
        source_lang: M2M100 language code of ``text``.
        target_lang: M2M100 language code to translate into.

    Returns:
        The decoded translation with special tokens stripped.
    """
    # Nothing to translate; don't let the model generate from an empty prompt.
    if not text or text.isspace():
        return ""
    with _TRANSLATION_LOCK:
        tokenizer.src_lang = source_lang
        inputs = tokenizer(text, return_tensors="pt")
        translated_tokens = translation_model.generate(
            **inputs,
            # Force the first generated token to be the target-language token.
            forced_bos_token_id=tokenizer.get_lang_id(target_lang),
        )
    return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

def handle_request(audio_base64, source_lang, target_lang):
    """Transcribe and translate one base64-encoded audio clip.

    The decoded audio is written to a uniquely named file under
    ``/app/temp``, resampled, transcribed, and translated. The temporary
    files are removed afterwards, including when a step raises.

    Args:
        audio_base64: Base64-encoded audio bytes.
        source_lang: Language code used both as the Whisper language hint and
            as the M2M100 source language.
        target_lang: M2M100 target-language code.

    Returns:
        dict with ``"transcribed_text"`` and ``"translated_text"`` keys.

    Raises:
        binascii.Error: if ``audio_base64`` is not valid base64.
    """
    temp_dir = "/app/temp"
    os.makedirs(temp_dir, exist_ok=True)  # Ensure directory exists
    # Unique name per request. A fixed filename would let concurrent requests
    # (app.run serves them on multiple threads by default) overwrite and
    # delete each other's audio.
    fd, audio_file_path = tempfile.mkstemp(
        prefix="temp_audio_", suffix=".wav", dir=temp_dir
    )
    processed_audio_file_name = None
    try:
        # Decode the base64 audio
        with os.fdopen(fd, "wb") as audio_file:
            audio_file.write(base64.b64decode(audio_base64))

        # Process the audio file
        processed_audio_file_name = preprocess_audio(audio_file_path)
        spoken_text = transcribe_audio(processed_audio_file_name, source_lang)
        translated_text = translate_text(spoken_text, source_lang, target_lang)
    finally:
        # Clean up temporary files even on failure so errors don't leak disk.
        for path in (processed_audio_file_name, audio_file_path):
            if path is None:
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    return {"transcribed_text": spoken_text, "translated_text": translated_text}

# Flask for handling external POST requests
app = Flask(__name__)

# Fields that must be present, as strings, in a /translate JSON body.
_REQUIRED_FIELDS = ("audio", "source_lang", "target_lang")


@app.route('/translate', methods=['POST'])
def translate():
    """API endpoint for handling audio translation.

    Expects a JSON object with string fields ``audio`` (base64-encoded
    audio), ``source_lang`` and ``target_lang``.

    Returns:
        JSON with ``transcribed_text`` and ``translated_text`` on success.
        On a missing or non-object body, a missing or non-string field, or
        undecodable base64 audio, returns ``{"error": ...}`` with HTTP 400.
    """
    # silent=True yields None (instead of raising) for a missing or malformed
    # JSON body, so every bad payload gets the same 400 below. The isinstance
    # check also rejects JSON arrays and strings: for a string body, the old
    # ``'audio' not in data`` would have been a substring test.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not all(
        isinstance(data.get(field), str) for field in _REQUIRED_FIELDS
    ):
        return jsonify({"error": "Invalid request format"}), 400

    audio_base64 = data['audio']
    source_lang = data['source_lang']
    target_lang = data['target_lang']

    # Call the handle_request function to process the request
    try:
        response = handle_request(audio_base64, source_lang, target_lang)
    except binascii.Error:
        # Bad base64 is a client error, not a server failure.
        return jsonify({"error": "Invalid base64 audio"}), 400
    return jsonify(response)

if __name__ == "__main__":
    # Script entry point: start Flask's built-in server on every interface, port 7860.
    app.run(port=7860, host="0.0.0.0")