import os
import time
import json
import gradio as gr
from threading import Lock
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import spaces


# Global model state and generation settings
model = None
status_message = "Model not loaded yet"
MODEL_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
MODEL_FILE = "mistral-7b-instruct-v0.2.Q4_K_M.gguf"
MODEL_TYPE = "mistral"
MAX_NEW_TOKENS = 2048
MODEL_LOCK = Lock()


class Message(BaseModel):
    role: str
    content: str


class CompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = 2048
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None


class CompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]


def format_chat_prompt(messages: List[Message]) -> str:
    # Build a Mistral-Instruct prompt: user/system turns are wrapped in
    # [INST] ... [/INST]; assistant turns are appended and closed with </s>.
    conversation = []
    for message in messages:
        if message.role == "system" or message.role == "user":
            conversation.append(f"<s>[INST] {message.content} [/INST]")
        elif message.role == "assistant":
            conversation.append(f" {message.content}</s>")
    return "".join(conversation)


def load_model():
    global model, status_message
    try:
        status_message = "Loading model..."
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            model_file=MODEL_FILE,
            model_type=MODEL_TYPE,
            context_length=4096,
            threads=4
        )
        status_message = "Model loaded successfully"
        return True
    except Exception as e:
        status_message = f"Error while loading the model: {str(e)}"
        return False


def generate_response(prompt, temperature=0.7, top_p=0.95, max_tokens=MAX_NEW_TOKENS):
    global model, status_message
    if model is None:
        if not load_model():
            return status_message
    # Serialize generation: the ctransformers model is shared across requests.
    with MODEL_LOCK:
        try:
            result = model(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.1
            )
            return result
        except Exception as e:
            return f"Generation error: {str(e)}"


def generate_with_timing(text, temp, max_tok):
    start_time = time.time()
    prompt = f"<s>[INST] {text} [/INST]"
    result = generate_response(prompt, temperature=temp, max_tokens=max_tok)
    end_time = time.time()
    return result, f"{end_time - start_time:.2f} seconds"


def create_gradio_interface():
    with gr.Blocks(title="Mistral API") as interface:
        gr.Markdown("# Mistral-7B API Server")
        with gr.Row():
            with gr.Column():
                status = gr.Textbox(value=lambda: status_message, label="Model status", interactive=False)
                load_button = gr.Button("Load Model")
                load_button.click(load_model, inputs=[], outputs=[])
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(lines=5, label="Input", placeholder="Type your message here...")
                with gr.Row():
                    temp_slider = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
                    max_token_slider = gr.Slider(100, MAX_NEW_TOKENS, value=1024, step=100, label="Max tokens")
                submit_button = gr.Button("Generate")
            with gr.Column():
                output_text = gr.Textbox(lines=12, label="Model response")
                gen_time = gr.Textbox(label="Generation time", interactive=False)
        submit_button.click(
            generate_with_timing,
            inputs=[input_text, temp_slider, max_token_slider],
            outputs=[output_text, gen_time]
        )
        gr.Markdown("""
        ## API Endpoints
        This application exposes an OpenAI-compatible API:
        - `/v1/chat/completions` - chat completion requests
        - `/status` - check the model status
        """)
    return interface
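# Illustrative request against the OpenAI-compatible endpoint documented above
# (a sketch: assumes the server is running locally on port 7860; the "model"
# field is free-form here, since the server always uses the GGUF file above):
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "mistral-7b-instruct",
#          "messages": [{"role": "user", "content": "Hello!"}],
#          "temperature": 0.7,
#          "max_tokens": 256}'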


@spaces.GPU
def get_gpu():
    # Called at startup so Hugging Face ZeroGPU Spaces allocate a GPU slot.
    return "GPU allocated successfully"


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def create_completion(request: CompletionRequest):
    try:
        prompt = format_chat_prompt(request.messages)
        max_tokens = min(request.max_tokens or MAX_NEW_TOKENS, MAX_NEW_TOKENS)
        start_time = time.time()
        completion_text = generate_response(
            prompt,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=max_tokens
        )
        end_time = time.time()
        # Approximate token counts by whitespace splitting (no tokenizer is used here).
        input_tokens = len(prompt.split())
        output_tokens = len(completion_text.split())
        response = {
            "id": f"chatcmpl-{os.urandom(4).hex()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion_text,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            },
        }
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/status")
async def get_status():
    return {"status": status_message, "model": MODEL_PATH}


# Mount the Gradio UI under /gradio; the API routes above stay at the root path.
interface = create_gradio_interface()
app = gr.mount_gradio_app(app, interface, path="/gradio")


@app.on_event("startup")
async def startup_load_model():
    get_gpu()
    load_model()


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
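# Minimal client sketch showing the OpenAI-compatible behaviour (assumes the
# `openai` Python package v1+ is installed and this server is running locally;
# the model name and api_key values are arbitrary placeholders):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:7860/v1", api_key="not-needed")
#   resp = client.chat.completions.create(
#       model="mistral-7b-instruct",
#       messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
#       max_tokens=128,
#   )
#   print(resp.choices[0].message.content)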