|
from flask import Flask, request, render_template
|
|
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
|
import torch
|
|
|
|
# Flask application instance; routes below register against it.
app = Flask(__name__)

# Load the fine-tuned T5 checkpoint and its tokenizer from the local
# directory produced by the training run.
model = T5ForConditionalGeneration.from_pretrained('./finetuned_t5')
tokenizer = T5Tokenizer.from_pretrained('./finetuned_t5')

# Inference-only service: switch off dropout / batch-norm training behavior.
model.eval()
|
|
|
|
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the question form; on POST, answer with the fine-tuned T5 model.

    GET renders the empty form. POST reads the ``question`` form field,
    runs beam-search generation, and re-renders the template with the
    decoded answer (empty string when no question was supplied).
    """
    answer = ""
    if request.method == 'POST':
        # .get() instead of ['question']: a missing field would otherwise
        # raise and surface as an HTTP 400 to the user.
        question = request.form.get('question', '').strip()
        if question:
            input_text = f"question: {question}"
            inputs = tokenizer(input_text, max_length=128, truncation=True,
                               padding=True, return_tensors="pt")
            # Inference only: no_grad() skips autograd bookkeeping, and
            # passing attention_mask keeps generation from attending to
            # padding tokens.
            with torch.no_grad():
                outputs = model.generate(
                    inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_length=64,
                    num_beams=4,
                    early_stopping=True,
                    no_repeat_ngram_size=2,
                )
            answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return render_template('index.html', answer=answer)
|
|
|
|
if __name__ == '__main__':
    # Development entry point: debug=True enables the reloader and the
    # interactive debugger (do not use in production deployments).
    app.run(debug=True)