jackbean's picture
Update app.py
bc3b119 verified
Raw
History Blame Contribute Delete
1.79 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
# ASGI application object; the route handlers in this module register on it.
app = FastAPI()

# Module-level slots for the model and tokenizer. Both start unset and are
# populated by load_model() on first use instead of at import time.
model, tokenizer = None, None
def load_model() -> None:
    """Load the T5 tokenizer and model into the module globals on first use.

    Does nothing when both ``model`` and ``tokenizer`` are already set.
    Otherwise the tokenizer is read from ``./tokenizer12`` and the model from
    ``./model``, and the model is moved to CUDA when available (CPU otherwise).

    The globals are assigned only after every step has succeeded, so a failure
    part-way through (for example while moving the model to the GPU) leaves
    them unset and the next call retries from scratch instead of reusing a
    half-initialised model.
    """
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loaded_tokenizer = T5Tokenizer.from_pretrained('./tokenizer12')
    loaded_model = T5ForConditionalGeneration.from_pretrained('./model')
    loaded_model.to(device)
    # The model is only used for inference, so keep it in evaluation mode.
    loaded_model.eval()

    # Publish last: callers never observe a partially prepared model.
    tokenizer = loaded_tokenizer
    model = loaded_model
class QuestionRequest(BaseModel):
    """Request body schema for the ``/generate_question`` endpoint."""

    # Passage of text the generated question should be about.
    context: str
    # Answer that the generated question should lead to.
    answer: str
from fastapi import Query
@app.post("/generate_question")
async def generate_question(request: QuestionRequest):
    """Generate candidate questions for an answer within a context passage.

    The request is rendered as ``"context: <context> answer: <answer>"``,
    tokenized (truncated to 512 tokens) and run through beam search with
    5 beams and a maximum output length of 72 tokens. The 3 returned
    sequences are decoded with special tokens removed.

    Args:
        request: Body carrying the ``context`` passage and target ``answer``.

    Returns:
        A dict ``{"generated_questions": [...]}`` with the decoded questions.

    Raises:
        HTTPException: 503 when the model or tokenizer cannot be loaded.
    """
    # Local import: this handler is the only place in the block that logs.
    import logging

    try:
        load_model()
    except Exception as exc:
        # Request boundary: keep the full traceback in the server log and give
        # the client an explicit "service unavailable" instead of a bare 500.
        logging.getLogger(__name__).exception(
            "Failed to load the question generation model"
        )
        raise HTTPException(
            status_code=503,
            detail="Question generation model is unavailable.",
        ) from exc

    # Follow the device the model actually lives on rather than re-deriving it,
    # so the inputs can never end up on a different device than the weights.
    device = model.device

    input_text = f"context: {request.context} answer: {request.answer}"
    # A single sequence needs no padding; padding it out to 512 tokens only
    # adds masked positions for the encoder to process.
    encoding = tokenizer(
        input_text,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # NOTE(review): generation runs synchronously inside this coroutine, so the
    # event loop is blocked for the duration of each request.
    model.eval()
    with torch.no_grad():
        beam_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=72,
            early_stopping=True,
            num_beams=5,
            num_return_sequences=3,
        )

    generated_questions = [
        tokenizer.decode(
            output,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for output in beam_outputs
    ]
    return {"generated_questions": generated_questions}
# Script entry point: serve the app directly when this file is executed.
if __name__ == "__main__":
    # Imported here so that merely importing this module does not need uvicorn.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",  # listen on all network interfaces
        port=7860,
    )