import asyncio
import base64
import io
import logging
import time
from concurrent.futures import ThreadPoolExecutor

import librosa
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# CORS: accept cross-origin requests from any origin, with any HTTP method and
# any request header.
# NOTE(review): pairing a wildcard origin with allow_credentials=True is very
# permissive. No endpoint in this file reads cookies or auth headers, so
# allow_credentials=False (or an explicit origin list) may be preferable.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

class AudioRequest(BaseModel):
    """Request body for ``POST /api/v1/inference``."""

    # Base64 text of a NumPy ``.npy`` file holding the input waveform; the
    # endpoint flattens the loaded array to 1-D before inference.
    audio_data: str
    # Sample rate of that waveform in Hz. The reply audio is returned at this
    # same rate.
    sample_rate: int

class AudioResponse(BaseModel):
    """Response body for ``POST /api/v1/inference``."""

    # Base64 text of a NumPy ``.npy`` file holding the generated reply
    # waveform, resampled to the request's ``sample_rate``.
    audio_data: str
    # Text of the generated reply (empty string if not provided).
    text: str = ""

# Model initialization status: written by initialize_model(), read by the
# health and inference endpoints.
INITIALIZATION_STATUS = dict(
    model_loaded=False,  # becomes True once Model() has loaded and warmed up
    error=None,  # str() of the exception raised while loading, else None
)


class Model:
    """MiniCPM speech-to-speech wrapper shared by all requests.

    Construction performs these steps:

    * loads the checkpoint and tokenizer;
    * enables the model's TTS head;
    * builds an ``audio_assistant`` system prompt from a reference voice clip
      in ``language`` (English by default);
    * runs one warmup inference.

    A CUDA device is required because the model is moved there with
    ``.cuda()``. Every constructor argument defaults to the value that used to
    be hard-coded.
    """

    def __init__(
        self,
        checkpoint_path: str = './models/checkpoint',
        ref_audio_path: str = './ref_audios/female_example.wav',
        warmup_audio_path: str = './ref_audios/male_example.wav',
        language: str = 'en',
    ):
        """Load the model, tokenizer, TTS head and system prompt, then warm up.

        Args:
            checkpoint_path: directory passed to ``from_pretrained`` for both
                the model and the tokenizer.
            ref_audio_path: voice clip used to build the system prompt
                (presumably the voice the generated speech imitates).
            warmup_audio_path: clip run once through :meth:`inference` here.
            language: language code passed to ``get_sys_prompt``.
        """
        # nn.Module.eval() and .cuda() both return the module itself, so
        # self.model is the loaded model, in eval mode, on the GPU.
        self.model = AutoModel.from_pretrained(
            checkpoint_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation='sdpa',
        ).eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(
            checkpoint_path,
            trust_remote_code=True,
        )

        # Enable speech output. The TTS head is cast to float32 while the rest
        # of the model stays in bfloat16.
        self.model.init_tts()
        self.model.tts.float()

        self.model_in_sr = 16000   # sample rate audio is fed to the model at
        self.model_out_sr = 24000  # sample rate of the model's generated audio
        self.ref_audio, _ = librosa.load(ref_audio_path, sr=self.model_in_sr, mono=True)
        self.sys_prompt = self.model.get_sys_prompt(
            ref_audio=self.ref_audio, mode='audio_assistant', language=language
        )

        # Warmup: one throwaway inference so one-time costs are paid at
        # startup rather than by the first request.
        warmup_audio, _ = librosa.load(warmup_audio_path, sr=self.model_in_sr, mono=True)
        self.inference(warmup_audio, self.model_in_sr)

    def inference(
        self,
        audio_np,
        input_audio_sr: int,
        max_new_tokens: int = 128,
        temperature: float = 0.3,
    ):
        """Generate a spoken and textual reply to one audio clip.

        Args:
            audio_np: mono waveform array. It must be floating-point whenever
                resampling is needed, because ``librosa.resample`` requires
                float input.
            input_audio_sr: sample rate of ``audio_np`` in Hz. The reply audio
                is returned at this same rate.
            max_new_tokens: generation budget passed to ``model.chat``.
            temperature: sampling temperature passed to ``model.chat``.

        Returns:
            A tuple ``(audio, text)``. ``audio`` is the reply waveform as a
            NumPy array at ``input_audio_sr``, and ``text`` is ``res["text"]``
            from ``model.chat``.
        """
        if input_audio_sr != self.model_in_sr:
            audio_np = librosa.resample(audio_np, orig_sr=input_audio_sr, target_sr=self.model_in_sr)

        user_question = {'role': 'user', 'content': [audio_np]}
        res = self.model.chat(
            msgs=[self.sys_prompt, user_question],
            tokenizer=self.tokenizer,
            sampling=True,
            max_new_tokens=max_new_tokens,
            use_tts_template=True,
            generate_audio=True,
            temperature=temperature,
        )

        wav = res["audio_wav"]
        # NumPy has no bfloat16 dtype, so .numpy() would raise on it. Upcast
        # only that dtype and leave every other dtype untouched.
        if wav.dtype == torch.bfloat16:
            wav = wav.float()
        audio = wav.cpu().numpy()

        if self.model_out_sr != input_audio_sr:
            audio = librosa.resample(audio, orig_sr=self.model_out_sr, target_sr=input_audio_sr)

        return audio, res["text"]

def initialize_model():
    """Load the global MiniCPM ``Model`` and record the outcome.

    On success, binds the module-level ``model``, sets
    ``INITIALIZATION_STATUS["model_loaded"]`` to True and clears any earlier
    error.

    On failure, stores ``str(exception)`` in ``INITIALIZATION_STATUS["error"]``
    and logs it with the traceback. The exception is not re-raised, so the
    server still starts and can report the failure.

    Returns:
        True if the model loaded, False otherwise.
    """
    global model
    try:
        logger.info("Initializing model...")
        model = Model()
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        # logger.exception keeps the traceback, which str(e) alone discards.
        logger.exception("Failed to initialize model: %s", e)
        return False

    INITIALIZATION_STATUS["model_loaded"] = True
    INITIALIZATION_STATUS["error"] = None
    logger.info("MiniCPM model initialized successfully")
    return True

@app.on_event("startup")
async def startup_event():
    """Load the model when the application starts.

    initialize_model() catches load failures and records them in
    INITIALIZATION_STATUS instead of raising. Startup therefore always
    completes, and any failure shows up through the health endpoint.
    Its boolean result is not needed here.
    """
    _ = initialize_model()

@app.get("/api/v1/health")
def health_check():
    """Report whether the model is ready to serve inference.

    Returns:
        A JSON object with these keys:

        * ``status``: ``"healthy"`` once the model has loaded, ``"error"``
          if loading failed, otherwise ``"initializing"``.
        * ``model_loaded``: copied from INITIALIZATION_STATUS.
        * ``error``: copied from INITIALIZATION_STATUS.
    """
    loaded = INITIALIZATION_STATUS["model_loaded"]
    error = INITIALIZATION_STATUS["error"]
    if loaded:
        status = "healthy"
    elif error is not None:
        # Nothing in this file retries a failed load, so reporting
        # "initializing" here would suggest the service may still come up.
        status = "error"
    else:
        status = "initializing"
    return {
        "status": status,
        "model_loaded": loaded,
        "error": error,
    }

# One dedicated worker thread for model calls. Running inference there keeps
# the blocking GPU work off the event loop, so /health stays responsive.
# max_workers=1 still handles requests one at a time against the single shared
# model, matching the serialization the old inline call produced.
_INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="minicpm-inference")


def _decode_audio(audio_b64):
    """Decode a base64-encoded ``.npy`` payload into a flat NumPy array.

    Raises:
        ValueError: if the payload is invalid. This covers bad base64
            (``binascii.Error`` is a ``ValueError``), data that is not a plain
            ``.npy`` array (pickled data or an ``.npz`` archive), and an array
            with no samples.
        EOFError: if the decoded payload is empty.
    """
    audio_bytes = base64.b64decode(audio_b64)
    # np.load defaults to allow_pickle=False, so untrusted pickles are refused.
    loaded = np.load(io.BytesIO(audio_bytes))
    if not isinstance(loaded, np.ndarray):
        raise ValueError("expected a single .npy array, not an .npz archive")
    audio_np = loaded.flatten()
    if audio_np.size == 0:
        raise ValueError("audio array is empty")
    return audio_np


def _encode_audio(audio):
    """Serialize ``audio`` with ``np.save`` and return it as base64 text."""
    buffer = io.BytesIO()
    np.save(buffer, audio)
    return base64.b64encode(buffer.getvalue()).decode()


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Answer one spoken request with generated audio and text.

    The request audio is a base64 ``.npy`` waveform at
    ``request.sample_rate``. The reply audio is returned in the same encoding
    and at the same sample rate.

    Raises:
        HTTPException: one of the following status codes.

            * 503 if the model has not loaded.
            * 400 for a non-positive ``sample_rate`` or an undecodable
              ``audio_data``.
            * 500 if inference or response encoding fails.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    # Client errors are reported as 400 rather than surfacing later as a 500.
    if request.sample_rate <= 0:
        raise HTTPException(status_code=400, detail="sample_rate must be a positive integer")
    try:
        audio_np = _decode_audio(request.audio_data)
    except (ValueError, EOFError) as e:
        logger.warning("Rejected invalid audio_data: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid audio_data: {e}") from e

    try:
        logger.info("starting inference with audio length %s", audio_np.shape)
        start = time.perf_counter()
        loop = asyncio.get_running_loop()
        audio_response, text_response = await loop.run_in_executor(
            _INFERENCE_EXECUTOR, model.inference, audio_np, request.sample_rate
        )
        logger.info("inference took %.3f seconds", time.perf_counter() - start)

        return AudioResponse(
            audio_data=_encode_audio(audio_response),
            text=text_response
        )
    except Exception as e:
        logger.exception("Inference failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e

if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly as a script.
    import uvicorn

    # Serve on all network interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")