import gradio as gr
import torch
from transformers import BertTokenizer, BertForMaskedLM

# Load the fine-tuned BERT masked-language model and its tokenizer from disk,
# then move the model to the GPU when one is available.
model_name = "/content/fine_tuned_bert_model"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name)
model.to("cuda" if torch.cuda.is_available() else "cpu")
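# Optional sanity check (a minimal sketch, not part of the original app; kept
# commented out so startup behaviour is unchanged): the same checkpoint can be
# exercised with the transformers fill-mask pipeline before wiring up the UI.
#
# from transformers import pipeline
# fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
# print(fill_mask("The Enron scandal occurred in [MASK]."))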
def answer_question(context, question):
    # Tokenize both inputs. Only the masked question template is fed to the
    # model below; the context is tokenized but not used in the forward pass,
    # which mirrors the original behaviour of this demo.
    context_tokens = tokenizer(context, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
    question_tokens = tokenizer(question, truncation=True, padding="max_length", max_length=16, return_tensors="pt")

    # Move the encoded inputs to the same device as the model.
    context_tokens = context_tokens.to(model.device)
    question_tokens = question_tokens.to(model.device)

    # Run masked-language-model inference and take the highest-scoring token
    # at every position.
    with torch.no_grad():
        outputs = model(**question_tokens)
        predictions = torch.argmax(outputs.logits, dim=-1)

    # Keep the original tokens everywhere except at [MASK] positions, where
    # the model's prediction is substituted.
    answer_tokens = []
    for i in range(len(question_tokens["input_ids"][0])):
        if question_tokens["input_ids"][0][i] == tokenizer.mask_token_id:
            answer_tokens.append(predictions[0][i].item())
        else:
            answer_tokens.append(question_tokens["input_ids"][0][i].item())

    # Decode the completed token sequence back into text.
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    return answer
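# The helper can also be called directly, e.g. with one of the example pairs
# defined below; the decoded answer depends entirely on the fine-tuned checkpoint.
#
# print(answer_question(
#     "Where did the Enron scandal occur?",
#     "The Enron scandal occurred in [MASK].",
# ))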
# Example inputs shown in the Gradio UI: a natural-language question paired
# with a masked prompt for the model to complete.
examples = [
    ["Where did the Enron scandal occur?", "The Enron scandal occurred in [MASK]."],
    ["What was the outcome of the Enron scandal?", "The outcome of the Enron scandal was [MASK]."],
    ["When did Enron file for bankruptcy?", "Enron filed for bankruptcy in [MASK]."],
    ["How did Enron's stock price change during the scandal?", "During the Enron scandal, Enron's stock price [MASK]."]
]
# Build the Gradio UI: two text inputs (the natural-language question and the
# masked prompt) and one text output for the reconstructed answer.
iface = gr.Interface(
    fn=answer_question,
    inputs=["text", "text"],
    outputs="text",
    title="Enron Email Analysis",
    description="Ask questions about the Enron email dataset using a fine-tuned BERT model.",
    examples=examples,
)

iface.launch(share=True)