# Provenance: Hugging Face upload by ubeydkhoiri ("Upload 3 files", commit 23d1168, verified).
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
# Create a new FastAPI app instance
app = FastAPI()
# Load the pre-trained GPT-2 tokenizer and model once at import time so that
# every request (and the startup warm-up) reuses the same in-memory weights.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
class GenerateRequest(BaseModel):
    """Request body for the ``/generate`` endpoint."""

    # Prompt text the model should continue.
    text: str
    # Maximum total length (prompt tokens included) passed to ``model.generate``.
    max_length: int = 50
@app.get("/")
def home():
    """Root endpoint: return a short static description of the service."""
    greeting = "Hello, this is an example of text generation using GPT-2."
    return {"message": greeting}
@app.post("/generate")
async def generate(request: GenerateRequest):
    """Generate a GPT-2 continuation of ``request.text``.

    Returns ``{"output": <generated text>}``. Any failure during tokenization
    or generation is surfaced as an HTTP 500 whose detail is the exception
    message.
    """
    try:
        # Tokenize the prompt into a (1, seq_len) tensor of token ids.
        input_ids = tokenizer.encode(request.text, return_tensors="pt")
        # Inference only: disable autograd so no gradient state is tracked.
        # This matches the warm-up path and reduces per-request memory use.
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_length=request.max_length,
                num_return_sequences=1,
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Return the generated text in a JSON response.
        return {"output": generated_text}
    except Exception as e:
        # Boundary handler: convert any failure into a clean HTTP 500,
        # preserving the original exception as the cause for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Warm up the model to reduce the latency of the first real request.
@app.on_event("startup")
async def warm_up_model():
    """Run one tiny generation at startup so later requests start fast."""
    warm_ids = tokenizer.encode("Warm-up", return_tensors="pt")
    with torch.no_grad():
        model.generate(warm_ids, max_length=10)