GLM-4.6-FP8-API / app.py
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import uvicorn
import os
app = FastAPI(
    title="GLM-4.6-FP8 API",
    description="Functional REST API for GLM-4.6-FP8 with multi-language support",
    version="1.0.0"
)
# Cached model objects, populated at startup
model = None
tokenizer = None
device = "cuda" if torch.cuda.is_available() else "cpu"
class ChatRequest(BaseModel):
    message: str
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
class ChatResponse(BaseModel):
    response: str
    model: str = "GLM-4.6-FP8"
    device: str = device
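# Example /chat request body matching ChatRequest above (illustrative values,
# not taken from the original file):
#
#     {"message": "Explain FP8 quantization in one sentence.",
#      "max_tokens": 256, "temperature": 0.7, "top_p": 0.95}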
@app.on_event("startup")
async def startup_event():
    global model, tokenizer
    try:
        print("Loading GLM-4.6-FP8 model...")
        tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.6-FP8")
        model = AutoModelForCausalLM.from_pretrained(
            "zai-org/GLM-4.6-FP8",
            device_map="auto",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            trust_remote_code=True
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
@app.get("/")
async def root():
return {
"message": "GLM-4.6-FP8 API",
"version": "1.0.0",
"device": device,
"model_loaded": model is not None,
"endpoints": {
"chat": "/chat",
"generate": "/generate",
"health": "/health"
}
}
@app.get("/health")
async def health():
return {
"status": "ok",
"model_loaded": model is not None,
"device": device
}
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
global model, tokenizer
if model is None or tokenizer is None:
raise HTTPException(status_code=503, detail="Modelo não está carregado")
try:
# Tokenizar entrada
inputs = tokenizer(request.message, return_tensors="pt").to(device)
# Gerar resposta
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
do_sample=True
)
# Decodificar resposta
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return ChatResponse(response=response_text)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erro na geração: {str(e)}")
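# Note: model.generate() returns the prompt tokens followed by the new tokens, so the
# decoded response above echoes the user's message. A minimal sketch of returning only
# the newly generated text, reusing the inputs/outputs names from chat() above:
#
#     prompt_len = inputs["input_ids"].shape[1]
#     response_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)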
@app.post("/generate", response_model=ChatResponse)
async def generate(request: ChatRequest):
"""Alias para /chat com formato alternativo"""
return await chat(request)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
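# Example client call, assuming the server is reachable on localhost:7860 (the port
# used above); the endpoint and fields match the ChatRequest model in this file:
#
#     curl -X POST http://localhost:7860/chat \
#          -H "Content-Type: application/json" \
#          -d '{"message": "Hello, GLM!", "max_tokens": 128}'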