# test-api / main.py
import os
from typing import List, Literal, Optional
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# ----------------------------
# Model config (matches demo)
# ----------------------------
MODEL_NAME = os.getenv("MODEL_NAME", "MBZUAI-Paris/Nile-Chat-12B")
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2024"))
app = FastAPI(title="Nile-Chat-12B FastAPI")
tokenizer = None
model = None
# ----------------------------
# Request schemas
# ----------------------------
Role = Literal["system", "user", "assistant"]
class ChatMessage(BaseModel):
    role: Role
    content: str
class GenerateRequest(BaseModel):
# ู†ูุณ ู…ูู‡ูˆู… Gradio: history + message
# ู„ูƒู† ู‡ู†ุง ู‡ู†ูˆุญู‘ุฏู‡ุง: messages ูƒุงู…ู„ุฉุŒ ูˆุขุฎุฑ user message ู‡ูŠ ุงู„ุทู„ุจ ุงู„ุญุงู„ูŠ
messages: List[ChatMessage] = Field(..., description="Conversation messages in OpenAI-like format")
max_new_tokens: int = Field(DEFAULT_MAX_NEW_TOKENS, ge=1, le=MAX_MAX_NEW_TOKENS)
do_sample: bool = True
temperature: float = Field(0.6, ge=0.0, le=4.0)
top_p: float = Field(0.9, ge=0.05, le=1.0)
top_k: int = Field(50, ge=1, le=1000)
repetition_penalty: float = Field(1.1, ge=1.0, le=2.0)
class GenerateResponse(BaseModel):
    response: str
    trimmed: bool = False
    model: str = MODEL_NAME
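# Illustrative /generate request body (values below are example choices, not enforced defaults):
#
#   {
#     "messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"}
#     ],
#     "max_new_tokens": 256,
#     "temperature": 0.6
#   }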
# ----------------------------
# Startup
# ----------------------------
@app.on_event("startup")
def startup_event():
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Same logic as the demo: bfloat16 on GPU (float32 fallback on CPU) + device_map="auto"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=dtype,
    )
    model.eval()
    print("Model ready")  # ✅ as requested
@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_NAME}
# ----------------------------
# Core generation
# ----------------------------
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    global tokenizer, model
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages is empty")
    # The Nile-Chat demo applies apply_chat_template to the whole conversation
    conversation = [m.model_dump() for m in req.messages]
    # Build input_ids exactly like the Gradio demo
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    trimmed = False
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        trimmed = True
    input_ids = input_ids.to(model.device)
    # Logging
    last_user = next((m.content for m in reversed(req.messages) if m.role == "user"), "")
    print("\n=== Incoming Request ===")
    print("MODEL:", MODEL_NAME)
    print("LAST USER:", last_user)
    print("trimmed_input:", trimmed)
    print("input_tokens:", int(input_ids.shape[1]))
    # Generate (non-streaming API response)
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=req.max_new_tokens,
            do_sample=req.do_sample,
            top_p=req.top_p,
            top_k=req.top_k,
            temperature=req.temperature,
            num_beams=1,
            repetition_penalty=req.repetition_penalty,
        )
    # Decode only the newly generated tokens (same idea as your Qwen API)
    new_tokens = out[0, input_ids.shape[-1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print("\n=== Model Response ===")
    print(response_text)
    print("======================\n")
    return GenerateResponse(response=response_text, trimmed=trimmed, model=MODEL_NAME)
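# --- Local run / example call (illustrative; assumes uvicorn is installed and port 8000 is free) ---
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "max_new_tokens": 128}'
if __name__ == "__main__":
    # Convenience entry point for local development; host/port values are assumptions.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)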