import os
import time
import gradio as gr
from threading import Lock
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import spaces

# Global state
model = None
status_message = "Model not loaded yet"
MODEL_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
MODEL_FILE = "mistral-7b-instruct-v0.2.Q4_K_M.gguf"
MODEL_TYPE = "mistral"
MAX_NEW_TOKENS = 2048
MODEL_LOCK = Lock()  # serializes access to the model across requests


# Pydantic models
class Message(BaseModel):
    role: str
    content: str


class CompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = 2048
    stream: Optional[bool] = False  # accepted for compatibility; streaming is not implemented
    stop: Optional[List[str]] = None


class CompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]


def format_chat_prompt(messages: List[Message]) -> str:
    # Simple Mistral-instruct prompt builder: system and user turns are wrapped
    # in [INST] ... [/INST]; assistant turns are appended verbatim.
    conversation = []
    for message in messages:
        if message.role == "system" or message.role == "user":
            conversation.append(f"[INST] {message.content} [/INST]")
        elif message.role == "assistant":
            conversation.append(f"{message.content}")
    return "".join(conversation)


def load_model():
    global model, status_message
    try:
        status_message = "Loading model..."
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            model_file=MODEL_FILE,
            model_type=MODEL_TYPE,
            context_length=4096,
            threads=4
        )
        status_message = "Model loaded successfully"
        return True
    except Exception as e:
        status_message = f"Error loading model: {str(e)}"
        return False


def generate_response(prompt, temperature=0.7, top_p=0.95, max_tokens=MAX_NEW_TOKENS):
    global model, status_message
    if model is None:
        if not load_model():
            return status_message
    with MODEL_LOCK:
        try:
            result = model(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.1
            )
            return result
        except Exception as e:
            return f"Generation error: {str(e)}"


def generate_with_timing(text, temp, max_tok):
    start_time = time.time()
    prompt = f"[INST] {text} [/INST]"
    result = generate_response(prompt, temperature=temp, max_tokens=max_tok)
    end_time = time.time()
    return result, f"{end_time - start_time:.2f} seconds"


def create_gradio_interface():
    with gr.Blocks(title="Mistral API") as interface:
        gr.Markdown("# Mistral-7B API Server")
        with gr.Row():
            with gr.Column():
                status = gr.Textbox(value=lambda: status_message, label="Model status", interactive=False)
                load_button = gr.Button("Load Model")

                def load_and_report():
                    # Refresh the status textbox once loading has finished.
                    load_model()
                    return status_message

                load_button.click(load_and_report, inputs=[], outputs=[status])
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(lines=5, label="Input", placeholder="Enter your message here...")
                with gr.Row():
                    temp_slider = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
                    max_token_slider = gr.Slider(100, MAX_NEW_TOKENS, value=1024, step=100, label="Max tokens")
                submit_button = gr.Button("Generate")
            with gr.Column():
                output_text = gr.Textbox(lines=12, label="Model response")
                gen_time = gr.Textbox(label="Generation time", interactive=False)
        submit_button.click(
            generate_with_timing,
            inputs=[input_text, temp_slider, max_token_slider],
            outputs=[output_text, gen_time]
        )
        gr.Markdown("""
        ## API Endpoints
        This application exposes an OpenAI-compatible API:
        - `/v1/chat/completions` - chat completion requests
        - `/status` - check the model status
        """)
    return interface


@spaces.GPU
def get_gpu():
    # Decorated with @spaces.GPU so Hugging Face Spaces allocates GPU hardware;
    # the return value is informational only.
    return "GPU allocated successfully"


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def create_completion(request: CompletionRequest):
    try:
        prompt = format_chat_prompt(request.messages)
        max_tokens = min(request.max_tokens or MAX_NEW_TOKENS, MAX_NEW_TOKENS)
        start_time = time.time()
        completion_text = generate_response(
            prompt,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=max_tokens
        )
        end_time = time.time()
        # Rough token counts based on whitespace splitting, not the model tokenizer.
        input_tokens = len(prompt.split())
        output_tokens = len(completion_text.split())
        response = {
            "id": f"chatcmpl-{os.urandom(4).hex()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion_text,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            }
        }
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/status")
async def get_status():
    return {"status": status_message, "model": MODEL_PATH}


# Create the Gradio interface and mount it on a dedicated path to avoid static-file errors
interface = create_gradio_interface()
app = gr.mount_gradio_app(app, interface, path="/gradio")


@app.on_event("startup")
async def startup_load_model():
    get_gpu()
    load_model()


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
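
# --- Example client usage (a sketch, kept as comments; not executed by this app) ---
# Assumptions: the server is reachable at http://localhost:7860 and the client has
# the `requests` package installed. The "model" field is only echoed back in the
# response; it does not select a different model.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={
#           "model": "mistral-7b-instruct-v0.2",
#           "messages": [{"role": "user", "content": "Hello, who are you?"}],
#           "temperature": 0.7,
#           "max_tokens": 256,
#       },
#       timeout=300,
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
#
#   # Check whether the model has finished loading:
#   print(requests.get("http://localhost:7860/status", timeout=10).json())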