Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
# Hugging Face Hub checkpoint shared by the tokenizer and the model; keep them in sync.
MODEL_ID = "monsoon-nlp/gpt-winowhy"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
# pad_token_id is set to the EOS id, presumably because the GPT-2 tokenizer defines
# no pad token and generate() otherwise warns about a missing pad_token_id.
model = GPT2LMHeadModel.from_pretrained(MODEL_ID, pad_token_id=tokenizer.eos_token_id)
# Words after which the forced prefix keeps one more word (e.g. "was afraid").
_LINKING_VERBS = ('is', 'are', 'was', 'were')
# Punctuation trimmed from prompt words before testing them as capitalized names.
_NAME_PUNCTUATION = '.,!?;:"\'()'


def _generate_continuation(text, **generate_kwargs):
    """Run the model on *text* and return only the newly generated text.

    Decoding just the tokens after the prompt (rather than searching the full
    decoded output for the '%' marker) stays correct when the user's prompt
    itself contains a '%'.
    """
    input_ids = torch.tensor([tokenizer.encode(text)])
    output = model.generate(input_ids, **generate_kwargs)
    return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)


def _candidate_names(prompt, items):
    """Return the candidate names to swap (may hold fewer than two).

    Comma-separated *items* take precedence; otherwise capitalized words of
    *prompt* are used, dropping the first one when more than two are found
    (it is assumed to be the capitalized sentence start).
    """
    items = items or ''
    if ',' in items:
        # Skip blank entries: an empty string would "match" any reply.
        return [item.strip() for item in items.split(',') if item.strip()]
    names = []
    for word in prompt.split():
        word = word.strip(_NAME_PUNCTUATION)
        if word[:1].isupper():
            names.append(word)
    if len(names) > 2:
        # remove first one which assumedly is a capital
        names = names[1:]
    return names


def _forced_prefix(reply, first, second):
    """Build the start of the alternative reply, or return None if neither name occurs.

    The earlier-mentioned of *first*/*second* in *reply* is replaced by the
    other name. The text before it is kept, followed by the next word, or the
    next two words when that word is a linking verb.
    """
    pos_first = reply.find(first)
    pos_second = reply.find(second)
    if pos_first >= 0 and (pos_second < 0 or pos_first < pos_second):
        found, replacement, pos = first, second, pos_first
    elif pos_second >= 0:
        found, replacement, pos = second, first, pos_second
    else:
        return None
    prefix = reply[:pos] + replacement
    remainder = reply[pos + len(found):].split()
    if remainder:
        keep = 2 if remainder[0] in _LINKING_VERBS else 1
        prefix += ' ' + ' '.join(remainder[:keep])
    return prefix


def hello(prompt, items):
    """Generate the model's explanation and a name-swapped "anti-explanation".

    Args:
        prompt: A WinoWhy-style sentence.
        items: Optional comma-separated candidate names. Without a comma,
            capitalized words from *prompt* are used instead.

    Returns:
        ``[fine_tuned_reply, alternative_reply]``. The second element is a
        status message when no alternative reply can be built.
    """
    inp = (prompt or '').strip()
    if not inp:
        return ['', 'Prompt is empty']
    if inp[-1] not in '?!.':
        inp += '.'
    # ' %' marks where the reply starts; only text generated after it is shown.
    inp += ' %'
    resp1 = _generate_continuation(inp, max_length=35)

    names = _candidate_names(prompt, items)
    if len(names) < 2:
        return [resp1, 'Need at least two names']
    force_inp = _forced_prefix(resp1, names[0], names[1])
    if force_inp is None:
        return [resp1, 'Name not present']

    alt = inp + ' ' + force_inp
    # NOTE(review): min_length counts prompt tokens too, so with typical prompts it
    # is probably already satisfied; min_new_tokens may have been intended.
    continuation = _generate_continuation(
        alt, max_new_tokens=30, min_length=30, do_sample=True
    )
    # Same text the original sliced after '%': the forced prefix plus new tokens.
    resp2 = ' ' + force_inp + continuation
    return [resp1, resp2]
# Gradio UI: two text inputs (prompt, optional comma-separated answers) mapped to
# the two strings returned by hello().
# gr.Textbox replaces gr.inputs.Textbox / gr.outputs.Textbox, which were deprecated
# in Gradio 3 and removed in Gradio 4 together with the `verbose` argument.
io = gr.Interface(
    fn=hello,
    inputs=[
        gr.Textbox(label="WinoWhy Prompt"),
        gr.Textbox(label="Answers (optional)"),
    ],
    outputs=[
        gr.Textbox(label="Fine-tuned reply"),
        gr.Textbox(label="Alternative reply"),
    ],
    title='Anti-Explanations',
    description='Learn more at https://medium.com/nerd-for-tech/searching-for-anti-explanations-418d26816b44',
    analytics_enabled=True,
)
io.launch(debug=True)