# Spaces: Paused
import gradio as gr | |
from transformers import BertForQuestionAnswering | |
from transformers import BertTokenizerFast | |
import torch | |
from nltk.tokenize import word_tokenize | |
# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# NOTE(review): 'bert-base-uncased' is the base pretrained checkpoint, so the
# question-answering head is presumably freshly initialised and its span
# predictions will not be meaningful until fine-tuned weights are loaded --
# confirm whether a QA-fine-tuned checkpoint was intended here.
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")

# The model must live on the same device as the tensors fed to it; leaving it
# on the CPU while the inputs are sent to CUDA raises a device-mismatch error.
model.to(device)
def get_prediction(context, question):
    """Extract the answer span for ``question`` from ``context``.

    The question and context are encoded together as one sequence pair. The
    model scores every token as a candidate span start and span end; the
    tokens from the best-scoring start through the best-scoring end
    (inclusive) are decoded back into a string.

    Args:
        context: Passage that is searched for the answer.
        question: Question to answer.

    Returns:
        The decoded answer span, or an empty string when the predicted end
        position falls before the predicted start position.
    """
    # Truncate only the context (the second sequence) so an over-long passage
    # cannot push the encoded pair past the tokenizer's maximum length.
    inputs = tokenizer(
        question,
        context,
        return_tensors='pt',
        truncation='only_second',
    ).to(model.device)  # follow the model so both always share a device

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # outputs[0] / outputs[1] are the per-token start and end scores.
    answer_start = int(torch.argmax(outputs[0]))
    answer_end = int(torch.argmax(outputs[1])) + 1  # +1: slice end is exclusive

    # The two argmaxes are independent, so the end can precede the start.
    if answer_end <= answer_start:
        return ''

    span_ids = inputs['input_ids'][0][answer_start:answer_end].tolist()
    span_tokens = tokenizer.convert_ids_to_tokens(span_ids)
    return tokenizer.convert_tokens_to_string(span_tokens)
def question_answer(context, question):
    """Answer ``question`` from ``context``.

    Thin wrapper that forwards both arguments to ``get_prediction`` and
    returns its result unchanged.
    """
    return get_prediction(context, question)
def split(text):
    """Return ``text`` together with its word tokens.

    An earlier version of this function divided the input into a context and
    a question at a '///' marker. That logic is disabled; the input string is
    now passed back unchanged alongside its tokens.

    Args:
        text: Raw input string.

    Returns:
        A ``(text, words)`` tuple: the original string and the result of
        ``word_tokenize(text)``.
    """
    return text, word_tokenize(text)
# def greet(texts): | |
# context, question = split(texts) | |
# answer = question_answer(context, question) | |
# return answer | |
def greet(text):
    """Gradio callback: return the second value produced by ``split(text)``.

    The call to ``question_answer`` is currently disabled, so no answer is
    predicted here; the value handed back to the interface comes straight
    from ``split``.
    """
    # The first value from split() is not needed while the QA call is off.
    _unused, tokens = split(text)
    return tokens
# Wire greet() into a text-in / text-out Gradio interface and start serving.
iface = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
iface.launch()