import asyncio
import base64
import io
import logging
import threading
import time

import librosa
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
| |
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Fully permissive CORS: any origin, any method, any header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is the
# most permissive CORS combination possible. No endpoint in this file reads
# cookies or auth headers, so allow_credentials=False may be sufficient.
# Confirm with browser clients before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
| |
|
class AudioRequest(BaseModel):
    # Request body for POST /api/v1/inference.
    # Base64 text of a NumPy .npy file holding the input waveform; the endpoint
    # decodes it with base64.b64decode and np.load, then flattens it.
    audio_data: str
    # Sample rate (Hz) of that waveform. The generated reply audio is
    # resampled to this same rate before being returned.
    sample_rate: int
| |
|
class AudioResponse(BaseModel):
    # Response body for POST /api/v1/inference.
    # Base64 text of a NumPy .npy file (written with np.save) holding the
    # generated reply waveform, at the request's sample_rate.
    audio_data: str
    # Text returned by the model alongside the audio; empty string by default.
    text: str = ""
| |
|
| | |
# Mutable, module-wide record of model start-up. It is updated in place by the
# initializer and read by the health and inference endpoints.
#   model_loaded -- becomes True once the model has been constructed
#   error        -- message of the last initialization failure, else None
INITIALIZATION_STATUS = dict(model_loaded=False, error=None)
| |
|
| |
|
| | |
class Model:
    """Speech-in / speech-out wrapper around a MiniCPM checkpoint.

    Everything is loaded eagerly in ``__init__``: the weights on the GPU, the
    tokenizer, the TTS head, and a system prompt built from a reference clip.
    One warm-up inference is then run before the instance is returned.
    """

    def __init__(
        self,
        checkpoint_dir='./models/checkpoint',
        ref_audio_path='./ref_audios/female_example.wav',
        warmup_audio_path='./ref_audios/male_example.wav',
    ):
        """Load model, tokenizer and TTS head, then run one warm-up inference.

        Args:
            checkpoint_dir: Directory passed to ``from_pretrained`` for both
                the model and the tokenizer.
            ref_audio_path: Audio file passed as ``ref_audio`` to
                ``get_sys_prompt`` to build the ``audio_assistant`` prompt.
            warmup_audio_path: Audio file sent once through :meth:`inference`
                at construction time; the result is discarded.
        """
        # nn.Module.eval()/.cuda() return the module itself, so self.model is
        # the same object that from_pretrained produced.
        self.model = AutoModel.from_pretrained(
            checkpoint_dir,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation='sdpa',
        ).eval().cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(
            checkpoint_dir,
            trust_remote_code=True,
        )

        # Initialize the TTS head after the GPU move and cast it to float32,
        # while the rest of the model was loaded in bfloat16.
        self.model.init_tts()
        self.model.tts.float()

        self.model_in_sr = 16000   # rate inputs are resampled to before chat()
        self.model_out_sr = 24000  # rate of the audio chat() produces
        self.ref_audio, _ = librosa.load(ref_audio_path, sr=self.model_in_sr, mono=True)
        self.sys_prompt = self.model.get_sys_prompt(
            ref_audio=self.ref_audio, mode='audio_assistant', language='en'
        )

        # Warm-up: one inference on a sample clip, result discarded.
        warmup_audio = librosa.load(warmup_audio_path, sr=self.model_in_sr, mono=True)[0]
        self.inference(warmup_audio, self.model_in_sr)

    def inference(self, audio_np, input_audio_sr, *, max_new_tokens=128, temperature=0.3):
        """Generate a spoken reply, plus its text, for one audio utterance.

        Args:
            audio_np: Input waveform as a NumPy array sampled at
                ``input_audio_sr`` (assumed mono float samples; confirm with
                callers, since librosa.resample is applied when rates differ).
            input_audio_sr: Sample rate of ``audio_np`` in Hz. The reply is
                resampled back to this rate.
            max_new_tokens: Generation length cap forwarded to ``chat``.
            temperature: Sampling temperature forwarded to ``chat``.

        Returns:
            ``(audio, text)``: the reply waveform as a NumPy array at
            ``input_audio_sr``, and ``res["text"]`` from ``chat``.
        """
        if input_audio_sr != self.model_in_sr:
            audio_np = librosa.resample(audio_np, orig_sr=input_audio_sr, target_sr=self.model_in_sr)

        user_question = {'role': 'user', 'content': [audio_np]}
        res = self.model.chat(
            msgs=[self.sys_prompt, user_question],
            tokenizer=self.tokenizer,
            sampling=True,
            max_new_tokens=max_new_tokens,
            use_tts_template=True,
            generate_audio=True,
            temperature=temperature,
        )
        audio = res["audio_wav"].cpu().numpy()

        # Return audio at the caller's rate rather than the model's output rate.
        if self.model_out_sr != input_audio_sr:
            audio = librosa.resample(audio, orig_sr=self.model_out_sr, target_sr=input_audio_sr)

        return audio, res["text"]
| |
|
def initialize_model():
    """Initialize the MiniCPM model.

    Constructs the module-level ``model`` and records the outcome in
    ``INITIALIZATION_STATUS``.

    Returns:
        True if ``Model()`` was built successfully. False if construction
        raised; in that case the exception message is stored under
        ``INITIALIZATION_STATUS["error"]`` and any previously loaded model
        is left in place.
    """
    # Only `model` is rebound; INITIALIZATION_STATUS is mutated in place and
    # needs no global declaration.
    global model
    try:
        logger.info("Initializing model...")
        model = Model()
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception("Failed to initialize model: %s", e)
        return False

    INITIALIZATION_STATUS["model_loaded"] = True
    # Clear any message left behind by an earlier failed attempt.
    INITIALIZATION_STATUS["error"] = None
    logger.info("MiniCPM model initialized successfully")
    return True
| |
|
@app.on_event("startup")
async def startup_event():
    """Startup hook: build the model once when the application starts."""
    # The True/False result is not needed here; initialize_model records the
    # outcome in INITIALIZATION_STATUS for the health endpoint.
    initialize_model()
| |
|
@app.get("/api/v1/health")
def health_check():
    """Health check endpoint.

    Returns a dict with:
        status: "healthy" once the model is loaded; "failed" if initialization
            raised and no model is loaded; otherwise "initializing".
        model_loaded: the loaded flag from INITIALIZATION_STATUS.
        error: the recorded initialization error message, or None.
    """
    loaded = INITIALIZATION_STATUS["model_loaded"]
    error = INITIALIZATION_STATUS["error"]
    if loaded:
        status = "healthy"
    elif error is not None:
        # Initialization already raised. Reporting "initializing" here would
        # make a permanently failed start look like one still in progress.
        status = "failed"
    else:
        status = "initializing"
    return {"status": status, "model_loaded": loaded, "error": error}
| |
|
# Model.inference does blocking GPU work on a single shared model. Requests now
# run it on executor threads, so this lock keeps those calls one-at-a-time.
_INFERENCE_LOCK = threading.Lock()


def _decode_audio(audio_b64):
    """Decode base64-encoded ``.npy`` bytes into a flat NumPy array.

    Raises:
        ValueError: if the payload is not valid base64, is not a loadable
            ``.npy`` array (pickled/object data is refused), is an ``.npz``
            archive, or contains no samples.
    """
    try:
        raw = base64.b64decode(audio_b64)
        # Untrusted client data: never allow pickle deserialization.
        loaded = np.load(io.BytesIO(raw), allow_pickle=False)
    except (ValueError, EOFError, OSError) as err:
        raise ValueError(f"audio_data is not a base64-encoded .npy array: {err}") from err
    if not isinstance(loaded, np.ndarray):
        # np.load returns an NpzFile for .npz archives; only a bare array is accepted.
        loaded.close()
        raise ValueError("audio_data must be a single .npy array, not an .npz archive")
    samples = loaded.flatten()
    if samples.size == 0:
        raise ValueError("audio_data contains no samples")
    return samples


def _run_inference(audio_np, sample_rate):
    """Run the shared model under _INFERENCE_LOCK and log the time it took (worker thread)."""
    with _INFERENCE_LOCK:
        start = time.perf_counter()
        result = model.inference(audio_np, sample_rate)
        logger.info("inference took %.3f seconds", time.perf_counter() - start)
    return result


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with MiniCPM model.

    Decodes the request audio and runs the model off the event loop, so other
    endpoints stay responsive. Returns the reply audio as base64 ``.npy``
    plus the model's text.

    Raises:
        HTTPException: 503 if the model is not loaded yet; 400 for a
            non-positive sample_rate or malformed audio_data; 500 if inference
            or response encoding fails.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    # Client errors are reported as 400 rather than surfacing as a 500 from
    # deep inside resampling or decoding.
    if request.sample_rate <= 0:
        raise HTTPException(status_code=400, detail="sample_rate must be a positive integer")
    try:
        audio_np = _decode_audio(request.audio_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        logger.info("starting inference with audio length %s", audio_np.shape)
        loop = asyncio.get_running_loop()
        # Offload the blocking call so the event loop (e.g. /health) is not stalled.
        audio_response, text_response = await loop.run_in_executor(
            None, _run_inference, audio_np, request.sample_rate
        )

        buffer = io.BytesIO()
        np.save(buffer, audio_response)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        return AudioResponse(
            audio_data=audio_b64,
            text=text_response
        )
    except Exception as e:
        logger.exception("Inference failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e
| |
|
if __name__ == "__main__":
    import uvicorn

    # Serve the app on all interfaces, port 8000.
    host, port = "0.0.0.0", 8000
    uvicorn.run(app, host=host, port=port)
| |
|
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|