# Mistral-7B-API / app.py
import os
import time
import json
import gradio as gr
from threading import Lock
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import spaces
# Global variables
model = None
status_message = "Modello non ancora caricato"
MODEL_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
MODEL_FILE = "mistral-7b-instruct-v0.2.Q4_K_M.gguf"
MODEL_TYPE = "mistral"
MAX_NEW_TOKENS = 2048
MODEL_LOCK = Lock()
# Pydantic models
class Message(BaseModel):
    role: str
    content: str
class CompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = 2048
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
class CompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]
def format_chat_prompt(messages: List[Message]) -> str:
    # Build a prompt following the Mistral Instruct template:
    # <s>[INST] user [/INST] assistant</s>[INST] user [/INST] ...
    # (</s> closes an assistant turn; it must not follow [/INST] directly,
    # otherwise the model sees an end-of-sequence before generating)
    conversation = []
    for message in messages:
        if message.role in ("system", "user"):
            conversation.append(f"[INST] {message.content} [/INST]")
        elif message.role == "assistant":
            conversation.append(f" {message.content}</s>")
    return "<s>" + "".join(conversation)
def load_model():
    global model, status_message
    try:
        status_message = "Caricamento modello in corso..."
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            model_file=MODEL_FILE,
            model_type=MODEL_TYPE,
            context_length=4096,
            threads=4
        )
        status_message = "Modello caricato con successo"
        return True
    except Exception as e:
        status_message = f"Errore nel caricamento del modello: {str(e)}"
        return False
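# Note: with the arguments above ctransformers downloads the GGUF file from the
# Hub and runs inference on CPU. To offload layers to a GPU you would typically
# pass something like gpu_layers=50 (assumption: requires a GPU-enabled
# ctransformers build).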
def generate_response(prompt, temperature=0.7, top_p=0.95, max_tokens=MAX_NEW_TOKENS):
    global model, status_message
    if model is None:
        if not load_model():
            return status_message
    with MODEL_LOCK:
        try:
            result = model(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.1
            )
            return result
        except Exception as e:
            return f"Errore nella generazione: {str(e)}"
def generate_with_timing(text, temp, max_tok):
    start_time = time.time()
    prompt = f"<s>[INST] {text} [/INST]"
    result = generate_response(prompt, temperature=temp, max_tokens=max_tok)
    end_time = time.time()
    return result, f"{end_time - start_time:.2f} secondi"
def create_gradio_interface():
    with gr.Blocks(title="Mistral API") as interface:
        gr.Markdown("# Mistral-7B API Server")
        with gr.Row():
            with gr.Column():
                status = gr.Textbox(value=lambda: status_message, label="Stato del modello", interactive=False)
                load_button = gr.Button("Carica Modello")
                load_button.click(load_model, inputs=[], outputs=[])
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(lines=5, label="Input", placeholder="Inserisci il tuo messaggio qui...")
                with gr.Row():
                    temp_slider = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperatura")
                    max_token_slider = gr.Slider(100, MAX_NEW_TOKENS, value=1024, step=100, label="Max Token")
                submit_button = gr.Button("Genera")
            with gr.Column():
                output_text = gr.Textbox(lines=12, label="Risposta del modello")
                gen_time = gr.Textbox(label="Tempo di generazione", interactive=False)
        submit_button.click(
            generate_with_timing,
            inputs=[input_text, temp_slider, max_token_slider],
            outputs=[output_text, gen_time]
        )
        gr.Markdown("""
## API Endpoint
Questa applicazione espone un endpoint API compatibile con OpenAI:
- `/v1/chat/completions` - Per richieste di completamento chat
- `/status` - Per verificare lo stato del modello
""")
    return interface
@spaces.GPU
def get_gpu():
    return "GPU allocata con successo"
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def create_completion(request: CompletionRequest):
    try:
        prompt = format_chat_prompt(request.messages)
        max_tokens = min(request.max_tokens or MAX_NEW_TOKENS, MAX_NEW_TOKENS)
        start_time = time.time()
        completion_text = generate_response(
            prompt,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=max_tokens
        )
        end_time = time.time()
        # Rough usage estimate based on whitespace-separated words, not real token counts
        input_tokens = len(prompt.split())
        output_tokens = len(completion_text.split())
        response = {
            "id": f"chatcmpl-{os.urandom(4).hex()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion_text,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            }
        }
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
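# Example request (assuming the server runs locally on the default port 7860; the
# "model" field is only echoed back in the response, it does not select a model):
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "mistral-7b-instruct-v0.2", "messages": [{"role": "user", "content": "Ciao!"}]}'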
@app.get("/status")
async def get_status():
    return {"status": status_message, "model": MODEL_PATH}
# Create the Gradio interface and mount it on a dedicated path to avoid static-file errors
interface = create_gradio_interface()
app = gr.mount_gradio_app(app, interface, path="/gradio")
@app.on_event("startup")
async def startup_load_model():
    get_gpu()
    load_model()
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
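# A minimal client sketch (not part of the app): because the endpoint mimics the
# OpenAI chat-completions schema, the official `openai` Python client (v1+) can be
# pointed at it. This assumes the server is reachable at localhost:7860.
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:7860/v1", api_key="unused")
#   reply = client.chat.completions.create(
#       model="mistral-7b-instruct-v0.2",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )
#   print(reply.choices[0].message.content)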