from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Initialize FastAPI app
app = FastAPI()
# Load the pre-trained DistilGPT-2 model and tokenizer (native PyTorch weights,
# so from_tf=True is unnecessary and would only add a TensorFlow dependency)
model_name = "distilgpt2"  # Smaller, distilled variant of GPT-2
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()  # Inference mode: disables dropout
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pydantic model for request body
class TextRequest(BaseModel):
    text: str
# Route to generate text
@app.post("/generate/")
async def generate_text(request: TextRequest):
    # Encode the input text
    inputs = tokenizer.encode(request.text, return_tensors="pt")
    # Generate a response; do_sample=True is required for top_p/top_k to take
    # effect, and pad_token_id silences GPT-2's missing-pad-token warning
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_length=100,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode the generated response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": response}
# Optionally, you can add a root endpoint for checking server health
@app.get("/")
async def read_root():
    return {"message": "Welcome to the GPT-2 FastAPI server!"}