from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from gradio import Interface, Textbox

model_name = "facebook/bart-base"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
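# Note: bart-base is a general-purpose pretrained model, not one fine-tuned
# for question generation; a task-specific checkpoint would likely produce
# more useful questions and answers.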


def generate_questions(email):
    """Generates questions based on the input email."""
    # The tokenizer adds BART's special tokens by default, so the ids can be
    # passed straight to generate(); truncate so long emails fit the model's
    # maximum input length.
    inputs = tokenizer(email, return_tensors="pt", truncation=True)

    generation = model.generate(
        **inputs,
        max_length=256,
        num_beams=5,
        early_stopping=True,
    )

    return tokenizer.decode(generation[0], skip_special_tokens=True)


def generate_answers(questions):
    """Generates possible answers to the input questions."""
    # model.generate() accepts no `prompt` keyword, so steer the model by
    # prepending the instruction to the encoder input instead.
    prompt = "Here are some possible answers to the questions:\n"
    inputs = tokenizer(prompt + "\n".join(questions), return_tensors="pt", truncation=True)

    generation = model.generate(
        input_ids=inputs["input_ids"],
        max_length=512,
        num_beams=3,
        early_stopping=True,
    )

    # Skip the first decoded line, which echoes the prompt; zip pairs each
    # question with its corresponding answer line.
    answers = tokenizer.decode(generation[0], skip_special_tokens=True).split("\n")
    return zip(questions, answers[1:])


def gradio_app(email):
    """Gradio interface function."""
    questions = generate_questions(email)
    answers = generate_answers(questions.split("\n"))
    # Join the answers into one string so each return value maps to a single
    # text output component.
    return questions, "\n".join(answer for _, answer in answers)


# One textbox in, two text outputs (questions and answers).
interface = Interface(
    fn=gradio_app,
    # Interface takes no elem_id argument; attach it to the input component.
    inputs=Textbox(elem_id="email-input"),
    outputs=["text", "text"],
    title="AI Email Assistant",
    description="Enter a long email and get questions and possible answers generated by an AI model.",
)

interface.launch()
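# Optional: interface.launch(share=True) also serves a temporary public URL.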