import fastapi
from fastapi.responses import JSONResponse
from time import time
#from fastapi.middleware.cors import CORSMiddleware
#MODEL_PATH = "./qwen1_5-0_5b-chat-q4_0.gguf" #"./qwen1_5-0_5b-chat-q4_0.gguf"
import logging
import llama_cpp
import llama_cpp.llama_tokenizer
from pydantic import BaseModel
from fastapi import APIRouter

class GenModel(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    Only ``question`` reaches the model as sent by the client: the
    ``generate`` handler runs with its own fixed system prompt, temperature
    and seed instead of the values carried here, and it does not read the
    ``mirostat_*`` fields.
    """

    # Text forwarded to the model as the single user message.
    question: str
    # Default system prompt, split across lines only for readability; the
    # concatenated value is unchanged.
    system: str = (
        "You are a helpful medical AI chat assistant. Help as much as you can."
        "Also continuously ask for possible symptoms in order to atat a "
        "conclusive ailment or sickness and possible solutions."
        "Remember, response in English."
    )
    # Sampling defaults declared on the request schema.
    temperature: float = 0.8
    seed: int = 101
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1

class ChatModel(BaseModel):
    """Request body accepted by the ``/chat/`` endpoint.

    ``question`` is handed to the model as the chat message list, and
    ``temperature`` and ``seed`` are forwarded with it. The ``chat`` handler
    does not read ``system`` or the ``mirostat_*`` fields.
    """

    # Conversation so far; element shape is not validated here.
    # Presumably a list of {"role": ..., "content": ...} dicts - TODO confirm.
    question: list
    # Default system prompt, split across lines only for readability; the
    # concatenated value is unchanged.
    system: str = (
        "You are a helpful medical AI chat assistant. Help as much as you can."
        "Also continuously ask for possible symptoms in order to atat a "
        "conclusive ailment or sickness and possible solutions."
        "Remember, response in English."
    )
    # Sampling defaults declared on the request schema.
    temperature: float = 0.8
    seed: int = 101
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
# Model instance used by the /chat/ endpoint. The weights file is selected by
# the "*q4_0.gguf" pattern within the Qwen1.5-0.5B-Chat GGUF repo, and the
# tokenizer is loaded separately from the "Qwen/Qwen1.5-0.5B" repo.
llm_chat = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B"
    ),
    n_ctx=1024,  # context window, in tokens
    n_gpu_layers=0,  # no layers offloaded to a GPU
    verbose=False,
    # chat_format="llama-2"
)
# Model instance used by the /generate endpoint. Same repo, weights pattern
# and tokenizer as llm_chat, but with a 4096-token context window.
# NOTE(review): mirostat_mode / mirostat_tau / mirostat_eta are sampling
# options of the completion calls in llama-cpp-python; passed to the
# constructor they are presumably absorbed by its **kwargs and have no
# effect - confirm, and pass them at completion time if they are wanted.
llm_generate = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B"
    ),
    n_ctx=4096,  # context window, in tokens
    n_gpu_layers=0,  # no layers offloaded to a GPU
    verbose=False,
    mirostat_mode=2,
    mirostat_tau=4.0,
    mirostat_eta=1.1,
    # chat_format="llama-2"
)
# Logger setup: configure the root logger at INFO level, then create this
# module's logger, which the endpoint handlers use to report failures.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Application object. The title and description are the metadata handed to
# FastAPI; the description previously read "Excellect" (typo).
app = fastapi.FastAPI(
    title="OpenGenAI",
    description="Your Excellent AI Physician",
)

# CORS is currently disabled. The configuration below is kept for reference;
# enabling it also requires the CORSMiddleware import that is commented out
# at the top of the file.
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
# Router grouping the LLM endpoints; every route registered on it is served
# under the "/llm" prefix.
# NOTE(review): no `app.include_router(llm_router)` call is visible in this
# file - confirm the router is attached to `app` somewhere, otherwise these
# routes are never exposed.
llm_router = APIRouter(prefix="/llm")

@llm_router.get("/health", tags=["llm"])
def health():
    """Liveness probe: always answers with a fixed ``{"status": "ok"}`` body."""
    payload = {"status": "ok"}
    return payload
    
# Chat completion API
@llm_router.post("/chat/", tags=["llm"])
async def chat(chatm: ChatModel):
    """Run a chat completion over the caller-supplied message list.

    Args:
        chatm: Request body. ``question`` is passed to the model as the
            ``messages`` list, together with ``temperature`` and ``seed``.
            The ``system`` and ``mirostat_*`` fields are not used here.

    Returns:
        The completion returned by ``llm_chat.create_chat_completion`` with
        an extra ``"time"`` key holding the elapsed inference time in
        seconds, or a 500 ``JSONResponse`` with a generic message if the
        call raises.
    """
    # Local import: the file-level import only brings in ``time.time``, and a
    # monotonic clock is the right tool for measuring a duration.
    from time import perf_counter

    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # the event loop is blocked for the whole inference. That also means
        # requests reach the model one at a time; confirm that is intended
        # before moving the call to a worker thread.
        started = perf_counter()
        output = llm_chat.create_chat_completion(
            messages=chatm.question,
            temperature=chatm.temperature,
            seed=chatm.seed,
        )
        output["time"] = perf_counter() - started
        return output
    except Exception as e:
        # Top-level boundary: log with traceback, hide details from the client.
        logger.exception("Error in /chat endpoint: %s", e)
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )

# Single-turn generation API
@llm_router.post("/generate", tags=["llm"])
async def generate(gen: GenModel):
    """Answer a single question with a fixed system prompt and sampling setup.

    Args:
        gen: Request body. Only ``question`` is used; it becomes the user
            message. The system prompt, temperature and seed are pinned by
            this handler, so the corresponding request fields (and the
            ``mirostat_*`` fields) have no influence on the result.

    Returns:
        The completion returned by ``llm_generate.create_chat_completion``
        with an extra ``"time"`` key holding the elapsed inference time in
        seconds, or a 500 ``JSONResponse`` with a generic message if the
        call raises.
    """
    # Local import: the file-level import only brings in ``time.time``, and a
    # monotonic clock is the right tool for measuring a duration.
    from time import perf_counter

    # Pinned generation settings. Kept as locals so the incoming request
    # model is no longer mutated just to carry constants.
    system_prompt = "You are an helpful medical AI assistant."
    temperature = 0.5
    seed = 42

    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # the event loop is blocked for the whole inference. That also means
        # requests reach the model one at a time; confirm that is intended
        # before moving the call to a worker thread.
        started = perf_counter()
        output = llm_generate.create_chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": gen.question},
            ],
            temperature=temperature,
            seed=seed,
        )
        output["time"] = perf_counter() - started
        return output
    except Exception as e:
        # Top-level boundary: log with traceback, hide details from the client.
        logger.exception("Error in /generate endpoint: %s", e)
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )