import os
import torch
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel, TextIteratorStreamer
from fastapi.responses import StreamingResponse
import threading
app = FastAPI()
model_id = "GEB-AGI/geb-1.3b"
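# trust_remote_code=True runs the custom modeling/tokenizer code bundled with
# the model repo, presumably needed because GEB's architecture is not in core transformers.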
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ChatRequest(BaseModel):
    message: str
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    prompt = f"Responde en español de forma clara y breve como un asistente IA.\nUsuario: {request.message}\nIA:"
    # 1. Tokenize to plain tokens (no padding, no encode)
    tokens = tokenizer.tokenize(prompt)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    # 2. Manually add the model's special tokens around the sequence
    input_ids = tokenizer.build_inputs_with_special_tokens(token_ids)
    input_ids = torch.tensor([input_ids])
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=48,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
        pad_token_id=getattr(tokenizer, "eos_token_id", None),
    )
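    # Run generate() on a background thread; TextIteratorStreamer hands each
    # decoded chunk to the consumer below as soon as it is produced.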
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    def event_generator():
        # Plain (sync) generator: Starlette iterates it in a worker thread,
        # so blocking on the streamer does not stall the event loop.
        for new_text in streamer:
            yield new_text
    return StreamingResponse(event_generator(), media_type="text/plain")
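
# Minimal run/test sketch (assumes uvicorn is installed; 7860 is the usual
# Hugging Face Spaces port, adjust as needed):
#
#   uvicorn main:app --host 0.0.0.0 --port 7860
#   curl -N -X POST http://localhost:7860/chat/stream \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hola"}'
if __name__ == "__main__":
    # Local entry point sketch, assuming uvicorn is available.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)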