stt / serve.py
komyu1227's picture
fix somecode
84a0d48
#from reazonspeech.nemo.asr import load_model, transcribe, audio_from_numpy
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
import uvicorn
import numpy as np
import io
from pydub import AudioSegment
import time
import logging
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Route INFO-and-above records through the root logger; this module logs
# under its own name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Prefer CUDA when torch reports it as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Disabled alternatives kept for reference:
# device = "mps" if torch.backends.mps.is_available() else device
# model = load_model(device)

# The processor and the model are loaded from the same checkpoint, once, at
# import time; the model is then moved onto the selected device.
_WHISPER_CHECKPOINT = "Ivydata/whisper-small-japanese"
processor = WhisperProcessor.from_pretrained(_WHISPER_CHECKPOINT)
model = WhisperForConditionalGeneration.from_pretrained(_WHISPER_CHECKPOINT)
model = model.to(device)
# def transcribe_audio(audio_data_bytes):
# try:
# start_time = time.time()
# audio_segment = AudioSegment.from_mp3(io.BytesIO(audio_data_bytes))
# # Get audio data as numpy array
# audio_data_int16 = np.array(audio_segment.get_array_of_samples())
# # Convert to float32 normalized to [-1, 1]
# audio_data_float32 = audio_data_int16.astype(np.float32) / 32768.0
# # Process with reazonspeech
# audio = audio_from_numpy(audio_data_float32, samplerate=audio_segment.frame_rate)
# result = transcribe(model, audio)
# end_time = time.time()
# print(f"Time taken: {end_time - start_time} seconds")
# return result
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
def transcribe_whisper(audio_data_bytes):
    """Transcribe MP3-encoded audio bytes with the module-level Whisper model.

    The clip is decoded with pydub, converted to mono 16-bit PCM at the
    sampling rate of the processor's feature extractor, scaled to float32 in
    [-1, 1), and then run through ``processor`` and ``model``.

    Args:
        audio_data_bytes: Raw bytes of an MP3 file.

    Returns:
        The first decoded transcription string.

    Raises:
        HTTPException: Status 500, with the underlying error message as the
            detail, when decoding or inference fails for any reason.
    """
    try:
        started = time.perf_counter()
        audio_segment = AudioSegment.from_mp3(io.BytesIO(audio_data_bytes))

        # The Whisper feature extractor rejects any sampling rate other than
        # its own, and get_array_of_samples() interleaves channels, so
        # normalise the clip before turning it into an array: one channel,
        # the extractor's rate, and 16-bit samples (which the 32768 divisor
        # below relies on).
        target_rate = processor.feature_extractor.sampling_rate
        audio_segment = (
            audio_segment.set_channels(1)
            .set_frame_rate(target_rate)
            .set_sample_width(2)
        )

        # int16 PCM -> float32 normalised to [-1, 1)
        samples_int16 = np.array(audio_segment.get_array_of_samples())
        waveform = samples_int16.astype(np.float32) / 32768.0

        # NOTE(review): the Whisper feature extractor presumably pads or
        # truncates to a fixed window (30 s by default), so longer clips may
        # be cut short - confirm if long recordings must be supported.
        input_features = processor(
            audio=waveform,
            sampling_rate=audio_segment.frame_rate,
            return_tensors="pt",
        ).input_features.to(device)
        predicted_ids = model.generate(input_features=input_features)
        decoded = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        # batch_decode is expected to return a list with one string per
        # sequence; fall back to the string form of whatever came back.
        if isinstance(decoded, list) and decoded:
            text = decoded[0]
        else:
            text = str(decoded)

        logger.info("Time taken: %s seconds", time.perf_counter() - started)
        return text
    except Exception as e:
        # Keep the traceback in the server log; the caller only sees str(e).
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
app = FastAPI()


@app.post("/transcribe")
async def transcribe_endpoint(file: UploadFile = File(...)):
    """Handle ``POST /transcribe``: transcribe an uploaded audio file.

    Args:
        file: The uploaded audio file; its full contents are read into
            memory and handed to ``transcribe_whisper``.

    Returns:
        A dict of the form ``{"result": [{"text": <str>}]}``. On success the
        text is the transcription. When ``transcribe_whisper`` raises an
        ``HTTPException`` the text is a fixed Japanese error message and the
        response is still returned normally rather than as an HTTP error.
    """
    audio_data = await file.read()
    try:
        # NOTE(review): this call is synchronous, so the event loop is
        # occupied for the whole duration of the inference.
        text = transcribe_whisper(audio_data)
    except HTTPException as e:
        logger.warning("Transcription request failed: %s", e.detail)
        # "An error occurred, please try again". The first half of this
        # literal was previously stored as mojibake.
        text = "エラーが発生しました, もう一度試してください"
    return {"result": [{"text": text}]}
if __name__ == "__main__":
    # Script entry point: report which device the model was placed on, then
    # serve the app with uvicorn on every interface, port 7860.
    startup_message = f"Model loaded on {device}"
    logger.info(startup_message)
    uvicorn.run(app, port=7860, host="0.0.0.0")