# Author: p208p2002 — last commit ce686de ("Update"), file size 1.36 kB.
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI,Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from transformer_qa_decode import TransformerQADecode
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from pydantic import BaseModel
# Single source of truth for the checkpoint, so the tokenizer and the model
# can never be loaded from mismatched repositories.
MODEL_NAME = "deepset/roberta-base-squad2"

# Loaded once at module import time; the resulting decoder (`qahl`) is a
# module-level global reused by every request handled by this process.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
# NOTE(review): is_squad_v2=True presumably enables SQuAD 2.0 "no answer"
# handling in the decoder — confirm against TransformerQADecode.
qahl = TransformerQADecode(model=model, tokenizer=tokenizer, is_squad_v2=True)
# Origins permitted to call this API cross-origin: every origin.
origins = ["*"]

app = FastAPI()

# NOTE(review): a "*" origin list together with allow_credentials=True is not
# allowed by the CORS spec as-is; Starlette works around it by echoing the
# request's Origin header, which effectively lets any site make credentialed
# requests. Nothing in this file uses cookies or auth, so
# allow_credentials=False may be sufficient — confirm no client relies on it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the React build's static directory under /static. The relative path
# is resolved against the process's current working directory.
app.mount(
    "/static",
    StaticFiles(directory="react-qa/build/static"),
    name="static",
)
class QAItem(BaseModel):
    """JSON request body accepted by the ``/question-answer`` endpoint."""

    # The question to be answered.
    question: str
    # Text handed to the QA decoder alongside the question.
    context: str
# https://hf.space/embed/{user}/{space}
@app.get("/")
def read_root():
    """Serve the React app's entry page (``react-qa/build/index.html``).

    The file is re-read on every request, so a rebuilt ``index.html`` is
    served without restarting the process. The relative path is resolved
    against the process's current working directory.
    """
    # The context manager closes the handle deterministically (the original
    # left it to the garbage collector), and the explicit encoding avoids
    # depending on the platform's default text encoding.
    with open('react-qa/build/index.html', 'r', encoding='utf-8') as html_file:
        html_content = html_file.read()
    return HTMLResponse(content=html_content, status_code=200)
@app.post("/question-answer")
def read_item(item: QAItem):
    """Run the QA decoder on the posted question/context and return its answers.

    ``qahl`` returns a nested sequence whose leaves expose ``_asdict()``
    (presumably namedtuples — confirm against ``TransformerQADecode``).
    FastAPI's JSON encoder treats tuples as plain arrays and would drop the
    field names, so every leaf is converted to a dict before returning.

    Returns:
        A list containing, for each group in the decoder's result, a list of
        dicts built from that group's items.
    """
    result = qahl(item.question, item.context)
    # Build a fresh nested list instead of overwriting the decoder's output in
    # place: this works for any iterable of iterables (tuples, generators),
    # not only mutable lists, and leaves the decoder's object untouched.
    return [[answer._asdict() for answer in group] for group in result]