Elrmnd's picture
Update app.py
f259830 verified
raw
history blame contribute delete
No virus
2.34 kB
import gradio as gr
import torch
from transformers import BertTokenizer, BertForMaskedLM
# Load the fine-tuned BERT masked-language model and its tokenizer from
# `model_name` (presumably a directory written by `save_pretrained`; confirm),
# then place the model on CUDA when a GPU is visible, otherwise on the CPU.
model_name = "fine_tuned_bert_model"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
def answer_question(context, question):
    """Fill the ``[MASK]`` tokens in *question* with the fine-tuned BERT MLM.

    When *context* is non-blank it is encoded together with the question as a
    BERT sentence pair (``[CLS] context [SEP] question [SEP]``) so the model's
    predictions can attend to it. If the pair is longer than the model's
    position limit, only the context is truncated, so the ``[MASK]`` tokens in
    the question are never cut off. A blank context falls back to encoding the
    question alone.

    Args:
        context: Free text to condition the prediction on; may be empty/None.
        question: A statement containing one or more ``[MASK]`` tokens.

    Returns:
        The question with every ``[MASK]`` replaced by the model's
        highest-scoring token, decoded with special tokens removed. A question
        without ``[MASK]`` comes back as the tokenizer's round-trip of it.
    """
    context = (context or "").strip()
    question = question or ""
    # BERT cannot embed positions beyond its learned position-embedding table.
    max_length = model.config.max_position_embeddings

    if context:
        encoded = tokenizer(
            context,
            question,
            truncation="only_first",  # shorten the context, never the question
            max_length=max_length,
            return_tensors="pt",
        )
    else:
        encoded = tokenizer(
            question, truncation=True, max_length=max_length, return_tensors="pt"
        )
    encoded = encoded.to(model.device)

    with torch.no_grad():
        logits = model(**encoded).logits

    input_ids = encoded["input_ids"][0]
    predicted_ids = logits[0].argmax(dim=-1)
    # Keep every original token except [MASK], which takes the top prediction.
    filled_ids = torch.where(
        input_ids == tokenizer.mask_token_id, predicted_ids, input_ids
    )

    if context:
        # token_type_ids is 1 for the second (question) segment; drop the
        # context so only the completed question is returned.
        filled_ids = filled_ids[encoded["token_type_ids"][0] == 1]

    return tokenizer.decode(filled_ids.tolist(), skip_special_tokens=True)
# Example rows for the Gradio UI. Each row is [context, question] in the same
# positional order as answer_question's parameters; the second element holds
# the [MASK] placeholder that the model fills in.
examples = [
    [context, masked_question]
    for context, masked_question in (
        ("Where did the Enron scandal occur?", "The Enron scandal occurred in [MASK]."),
        ("What was the outcome of the Enron scandal?", "The outcome of the Enron scandal was [MASK]."),
        ("When did Enron file for bankruptcy?", "Enron filed for bankruptcy in [MASK]."),
        ("How did Enron's stock price change during the scandal?", "During the Enron scandal, Enron's stock price [MASK]."),
    )
]
# Gradio interface: two labelled text boxes are passed positionally to
# answer_question(context, question), and its returned string is shown in a
# single output text box. Labels matter because the two inputs play different
# roles and would otherwise be indistinguishable in the UI.
iface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(
            label="Context",
            placeholder="Optional background text or question",
        ),
        gr.Textbox(
            label="Question",
            placeholder="A statement containing [MASK], e.g. Enron filed for bankruptcy in [MASK].",
        ),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Enron Email Analysis",
    description="Ask questions about the Enron email dataset using a fine-tuned BERT model.",
    examples=examples,
)

# Launch only when executed as a script, so importing this module does not
# start a web server as a side effect.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link when run locally
    # (exposing the app to the internet); Gradio does not support share links
    # on Hugging Face Spaces. Confirm it is still wanted.
    iface.launch(share=True)