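"""FastAPI inference server for the MBZUAI-Paris/Nile-Chat-12B chat model.

Loads the tokenizer and model once at application startup and exposes a /health
probe plus a /generate endpoint that accepts OpenAI-style chat messages and
returns the decoded completion.
"""
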
import os
from typing import List, Literal

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model identifier and generation/context limits.
MODEL_NAME = os.getenv("MODEL_NAME", "MBZUAI-Paris/Nile-Chat-12B")

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2024"))

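# MODEL_NAME and MAX_INPUT_TOKEN_LENGTH can be overridden through environment
# variables; shell example with illustrative values:
#
#   export MODEL_NAME="MBZUAI-Paris/Nile-Chat-12B"
#   export MAX_INPUT_TOKEN_LENGTH=2024
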
app = FastAPI(title="Nile-Chat-12B FastAPI")

# Populated once at startup and shared across requests.
tokenizer = None
model = None

Role = Literal["system", "user", "assistant"]


class ChatMessage(BaseModel):
    role: Role
    content: str


class GenerateRequest(BaseModel):
    messages: List[ChatMessage] = Field(..., description="Conversation messages in OpenAI-like format")
    max_new_tokens: int = Field(DEFAULT_MAX_NEW_TOKENS, ge=1, le=MAX_MAX_NEW_TOKENS)
    do_sample: bool = True
    temperature: float = Field(0.6, ge=0.0, le=4.0)
    top_p: float = Field(0.9, ge=0.05, le=1.0)
    top_k: int = Field(50, ge=1, le=1000)
    repetition_penalty: float = Field(1.1, ge=1.0, le=2.0)


class GenerateResponse(BaseModel):
    response: str
    trimmed: bool = False
    model: str = MODEL_NAME


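# Illustrative /generate request body matching GenerateRequest (values are examples;
# every field except "messages" falls back to the defaults declared above):
#
#   {
#     "messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"}
#     ],
#     "max_new_tokens": 256,
#     "temperature": 0.6,
#     "top_p": 0.9
#   }

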
# NOTE: newer FastAPI versions prefer lifespan handlers over on_event; this still works.
@app.on_event("startup")
def startup_event():
    global tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Use bfloat16 on GPU, fall back to float32 on CPU.
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=dtype,
    )
    model.eval()

    print("Model ready")


@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_NAME}


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    global tokenizer, model

    # Guard against an empty conversation; reported in-band rather than as an HTTP error.
    if not req.messages:
        return GenerateResponse(response="Error: messages is empty", trimmed=False)

    # Convert the Pydantic models to plain dicts for the chat template.
    conversation = [m.model_dump() for m in req.messages]

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # Keep only the most recent tokens when the prompt exceeds the context budget.
    trimmed = False
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        trimmed = True

    input_ids = input_ids.to(model.device)

    # Lightweight request logging to stdout.
    last_user = next((m.content for m in reversed(req.messages) if m.role == "user"), "")
    print("\n=== Incoming Request ===")
    print("MODEL:", MODEL_NAME)
    print("LAST USER:", last_user)
    print("trimmed_input:", trimmed)
    print("input_tokens:", int(input_ids.shape[1]))

    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=req.max_new_tokens,
            do_sample=req.do_sample,
            top_p=req.top_p,
            top_k=req.top_k,
            temperature=req.temperature,
            num_beams=1,
            repetition_penalty=req.repetition_penalty,
        )

    # Decode only the newly generated tokens, excluding the prompt.
    new_tokens = out[0, input_ids.shape[-1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    print("\n=== Model Response ===")
    print(response_text)
    print("======================\n")

    return GenerateResponse(response=response_text, trimmed=trimmed, model=MODEL_NAME)


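# Optional convenience entrypoint (a sketch; assumes uvicorn is installed and port 8000
# is free). The app can equally be served with `uvicorn <module>:app`. Example request
# once the server is running:
#
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)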