custom-api / app.py
DataChem's picture
Update app.py
9d98b84 verified
raw
history blame
1.34 kB
from threading import Thread

import torch
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
app = FastAPI()

# Load the tokenizer and model once, at import time, so every request handled
# below reuses the same objects. Replace the checkpoint id with your desired model.
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
@app.get("/")
def read_root():
    """Answer ``GET /`` with the fixed JSON payload ``{"Hello": "World"}``."""
    payload = {"Hello": "World"}
    return payload
@app.post("/predict")
async def predict(request: Request):
    """Generate a continuation of a prompt and stream it back as plain text.

    The request body is expected to be a JSON object such as
    ``{"prompt": "Once upon a time"}``.

    Returns:
        ``{"error": "Prompt is required"}`` when the body is not valid JSON,
        is not a JSON object, or its ``"prompt"`` is missing, empty, or not a
        string. Otherwise a ``text/plain`` ``StreamingResponse`` that echoes
        the prompt followed by the sampled continuation (at most 40 new
        tokens), sent in chunks while ``model.generate`` is still running.
    """
    try:
        data = await request.json()
    except ValueError:  # malformed JSON / undecodable body
        data = None
    prompt = data.get("prompt", "") if isinstance(data, dict) else ""
    if not isinstance(prompt, str) or not prompt:
        # NOTE(review): this error is sent with HTTP 200; consider a 400 status.
        return {"error": "Prompt is required"}

    # Put the inputs on whatever device the model lives on (CPU unless the
    # model was moved elsewhere at load time).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # The streamer decodes the growing token sequence incrementally, so words
    # made of several BPE pieces come out intact (decoding one token at a time
    # and joining with spaces garbled the text). skip_prompt=False keeps the
    # prompt echo this endpoint has always returned.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=False, skip_special_tokens=True
    )
    generation_kwargs = dict(
        inputs,  # input_ids *and* attention_mask
        streamer=streamer,
        # Budget only generated tokens; max_length also counted the prompt, so
        # prompts of 40+ tokens used to produce (almost) no continuation.
        max_new_tokens=40,
        do_sample=True,
        num_return_sequences=1,  # the streamer handles a single sequence only
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
    )

    def _run_generation() -> None:
        # Runs in a background thread; generate() pushes text into `streamer`.
        try:
            model.generate(**generation_kwargs)
        except Exception:
            # Close the stream so the response below finishes instead of
            # waiting forever for text that will never arrive; re-raise so the
            # thread's excepthook still reports the failure.
            streamer.end()
            raise

    Thread(target=_run_generation, daemon=True).start()

    # Text chunks are forwarded to the client as the streamer yields them.
    return StreamingResponse(streamer, media_type="text/plain")