File size: 3,404 Bytes
bbffab5
ac00417
 
 
 
 
 
 
a602a35
b3834f9
a602a35
 
 
 
ac00417
bbffab5
ac00417
b3834f9
 
2f0576c
b3834f9
ac00417
bbffab5
 
 
 
ac00417
bbffab5
 
 
 
ac00417
bbffab5
 
 
 
 
 
 
 
b3834f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84a0d48
b3834f9
 
 
ac00417
 
 
 
 
 
 
bbffab5
ac00417
 
 
84a0d48
ac00417
 
 
 
 
 
 
 
 
 
 
 
 
 
a602a35
2425fce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#from reazonspeech.nemo.asr import load_model, transcribe, audio_from_numpy
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
import uvicorn
import numpy as np
import io
from pydub import AudioSegment
import time
import logging
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Configure the root logger to emit INFO and above; this module logs under
# its own name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# device = "mps" if torch.backends.mps.is_available() else device

# model = load_model(device)

# The processor and the model weights are loaded from the same checkpoint.
_WHISPER_CHECKPOINT = "Ivydata/whisper-small-japanese"
processor = WhisperProcessor.from_pretrained(_WHISPER_CHECKPOINT)
model = WhisperForConditionalGeneration.from_pretrained(_WHISPER_CHECKPOINT)
model = model.to(device)

# def transcribe_audio(audio_data_bytes):
#     try:
#         start_time = time.time()
#         audio_segment = AudioSegment.from_mp3(io.BytesIO(audio_data_bytes))
        
#         # Get audio data as numpy array
#         audio_data_int16 = np.array(audio_segment.get_array_of_samples())
#         # Convert to float32 normalized to [-1, 1]
#         audio_data_float32 = audio_data_int16.astype(np.float32) / 32768.0
        
#         # Process with reazonspeech
#         audio = audio_from_numpy(audio_data_float32, samplerate=audio_segment.frame_rate)
#         result = transcribe(model, audio)
#         end_time = time.time()
#         print(f"Time taken: {end_time - start_time} seconds")
#         return result
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))
    
def transcribe_whisper(audio_data_bytes):
    """Transcribe MP3-encoded audio bytes to text with the Whisper model.

    The audio is decoded with pydub, down-mixed to mono, resampled to the
    sampling rate the Whisper feature extractor is configured with, and
    normalised to float32 in [-1, 1) before being handed to the model.

    Args:
        audio_data_bytes: Raw bytes of an MP3 file.

    Returns:
        The transcription as a string (empty if the model produced no output).

    Raises:
        HTTPException: Status 500 with the underlying error message as
            ``detail`` when decoding, feature extraction or generation fails.
    """
    start_time = time.perf_counter()
    try:
        audio_segment = AudioSegment.from_mp3(io.BytesIO(audio_data_bytes))

        # The feature extractor rejects audio whose sampling rate differs from
        # the one it was configured with, and get_array_of_samples() returns
        # interleaved samples for multi-channel audio. Normalise the segment
        # to mono / target rate / 16-bit before extracting samples.
        target_rate = processor.feature_extractor.sampling_rate
        audio_segment = (
            audio_segment
            .set_channels(1)
            .set_frame_rate(target_rate)
            .set_sample_width(2)
        )

        # 16-bit PCM samples -> float32 normalised to [-1, 1).
        samples_int16 = np.array(audio_segment.get_array_of_samples())
        samples_float32 = samples_int16.astype(np.float32) / 32768.0

        # Process with whisper
        input_features = processor(
            audio=samples_float32,
            sampling_rate=target_rate,
            return_tensors="pt",
        ).input_features.to(device)

        predicted_ids = model.generate(input_features=input_features)
        transcripts = processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )
        result_text = transcripts[0] if transcripts else ""
    except Exception as e:
        # Log the traceback here: callers only see the HTTPException detail.
        logger.exception("Whisper transcription failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    logger.info("Time taken: %.3f seconds", time.perf_counter() - start_time)
    return result_text
    

app = FastAPI()


@app.post("/transcribe")
async def transcribe_endpoint(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return the recognised text.

    Args:
        file: The uploaded audio file; its full contents are read into memory
            and passed to ``transcribe_whisper``.

    Returns:
        A dict of the form ``{"result": [{"text": <str>}]}``. When
        transcription raises ``HTTPException`` the same shape is returned
        with a fixed Japanese error message instead of the transcription,
        so the client always receives a successful response body.
    """
    audio_data = await file.read()
    # NOTE(review): transcribe_whisper is a synchronous call made directly
    # from this coroutine, so the event loop is blocked while it runs.
    try:
        text = transcribe_whisper(audio_data)
    except HTTPException as e:
        logger.warning("Transcription failed: %s", e.detail)
        # "An error occurred, please try again." The first half of this
        # message was previously stored as mis-decoded (mojibake) text.
        text = "エラーが発生しました, もう一度試してください"

    return {
        "result": [
            {
                "text": text,
            }
        ]
    }
    

if __name__ == "__main__":
    # Script entry point: report the inference device chosen at import time,
    # then serve the API on all interfaces at port 7860.
    logger.info("Model loaded on %s", device)
    uvicorn.run(app, port=7860, host="0.0.0.0")