File size: 2,173 Bytes
9616027
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
torch.set_float32_matmul_precision('high')

from flask import Flask, send_from_directory, request, Response
import os
import base64
import numpy as np
from inference import OmniInference
import io

app = Flask(__name__)

# Build the shared OmniInference instance once, at import time, so the
# request handlers below can reuse it. A failure is reported and re-raised:
# the module must not finish importing without a usable `omni`.
try:
    print("Initializing OmniInference...")
    omni = OmniInference()
except Exception as init_error:
    print(f"Error initializing OmniInference: {str(init_error)}")
    raise
else:
    print("OmniInference initialized successfully.")

@app.route('/')
def serve_html():
    """Return the demo page ``webui/omni_html_demo.html`` for the root URL."""
    demo_page = 'webui/omni_html_demo.html'
    return send_from_directory('.', demo_page)

@app.route('/chat', methods=['POST'])
def chat():
    """Run one audio request through OmniInference and return the reply audio.

    Expects a JSON body with an ``audio`` field holding base64-encoded audio,
    optionally preceded by a comma-terminated header (data-URL style). The
    decoded bytes are written to a per-request temporary ``.wav`` file, that
    path is passed to ``omni.run_AT_batch_stream``, and every chunk it yields
    is concatenated into a single ``audio/wav`` response.

    Returns:
        A ``Response`` with the concatenated audio on success;
        ``(message, 400)`` when the body has no usable ``audio`` string or the
        base64 payload is malformed; ``(message, 500)`` on any other failure.
    """
    # Function-scope import so this handler does not depend on the file's
    # top-level import block being edited.
    import tempfile

    try:
        # silent=True yields None instead of raising on a missing or
        # malformed JSON body, so bad requests become a 400 rather than a 500.
        payload = request.get_json(silent=True)
        audio_data = payload.get('audio') if isinstance(payload, dict) else None
        if not audio_data or not isinstance(audio_data, str):
            return "No audio data received", 400

        # Drop the header part ("data:...;base64,") when one is present.
        if ',' in audio_data:
            encoded_audio = audio_data.split(',')[1]
        else:
            encoded_audio = audio_data

        try:
            audio_bytes = base64.b64decode(encoded_audio)
        except ValueError:
            # binascii.Error (bad padding / alphabet) subclasses ValueError.
            return "Invalid base64 audio data", 400

        # Use a unique file per request: a fixed file name would let
        # concurrent requests overwrite or delete each other's input.
        fd, temp_audio_path = tempfile.mkstemp(suffix='.wav')
        try:
            with os.fdopen(fd, 'wb') as f:
                f.write(audio_bytes)

            # Generate response using OmniInference
            try:
                response_generator = omni.run_AT_batch_stream(temp_audio_path)

                # Concatenate all audio chunks in one pass (assumes the
                # generator yields bytes-like chunks, as the original
                # `bytes +=` accumulation did).
                all_audio = b''.join(response_generator)

                return Response(all_audio, mimetype='audio/wav')
            except Exception as inner_e:
                print(f"Error in OmniInference processing: {str(inner_e)}")
                return f"An error occurred during audio processing: {str(inner_e)}", 500
        finally:
            # Single cleanup point, reached on success, on inference failure
            # and when writing the file itself fails.
            try:
                os.remove(temp_audio_path)
            except FileNotFoundError:
                pass

    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        return f"An error occurred: {str(e)}", 500

if __name__ == '__main__':
    # Start Flask's built-in server when the file is executed directly,
    # bound to 0.0.0.0 on port 7860.
    app.run(port=7860, host='0.0.0.0')