# AntiExplanation / app.py
# Last commit by monsoon-nlp (6e9d2dd): "push sentence length"
import string

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Fine-tuned GPT-2 checkpoint; the tokenizer and model must come from the same id.
MODEL_NAME = "monsoon-nlp/gpt-winowhy"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)

# Reuse the EOS token id as the padding id when generating
# (presumably because this GPT-2 checkpoint defines no dedicated pad token).
eos_id = tokenizer.eos_token_id
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, pad_token_id=eos_id)
def _extract_names(prompt, items):
    """Return the candidate answer names that the alternative reply swaps.

    If *items* contains a comma it is treated as a comma-separated list of
    names (blank entries are dropped). Otherwise capitalized words are taken
    from *prompt*, with surrounding punctuation stripped so that a word like
    ``"George,"`` can match ``"George"`` in generated text. When more than two
    capitalized words are found, the first is dropped on the assumption that
    it is only a sentence-initial capital.
    """
    if items and ',' in items:
        return [name for name in (part.strip() for part in items.split(',')) if name]
    names = []
    for word in prompt.split():
        word = word.strip(string.punctuation)
        if word[:1].isupper():
            names.append(word)
    if len(names) > 2:
        # Remove the first one, which is assumed to be a sentence-initial capital.
        names = names[1:]
    return names


def _swap_first_name(text, names):
    """Build the forced prefix used to steer the alternative reply.

    Finds whichever of ``names[0]`` / ``names[1]`` occurs first in *text*,
    keeps everything before it, substitutes the other name, and carries over
    the following word (two words when that word is "is"/"are"/"was"/"were").
    Matching is plain substring search, as before.

    Returns:
        The forced prefix string, or ``None`` if neither name occurs in *text*.
    """
    first, second = names[0], names[1]
    pos_first = text.find(first)
    pos_second = text.find(second)
    if pos_first >= 0 and (pos_second < 0 or pos_first < pos_second):
        found, other, pos = first, second, pos_first
    elif pos_second >= 0:
        found, other, pos = second, first, pos_second
    else:
        return None
    forced = text[:pos] + other
    tail = text[pos + len(found):]
    words = tail.split()
    if words:
        # Keep suffixes glued to the name (e.g. "'s") instead of inserting a space.
        sep = ' ' if tail[0].isspace() else ''
        if words[0] in ('is', 'are', 'was', 'were'):
            forced += sep + ' '.join(words[:2])
        else:
            forced += sep + words[0]
    return forced


def hello(prompt, items):
    """Generate an explanation and an "anti-explanation" for a WinoWhy prompt.

    The fine-tuned model first explains the prompt. The first-mentioned
    candidate name in that explanation is then swapped for the other one, and
    the model is asked to continue from the swapped prefix.

    Args:
        prompt: The WinoWhy sentence entered by the user.
        items: Optional comma-separated candidate answers (e.g. "Sam, Tom").
            Without a comma, capitalized words from *prompt* are used instead.

    Returns:
        A two-item list ``[fine_tuned_reply, alternative_reply]`` for the two
        output textboxes. When the alternative cannot be built, the second
        item is a short message saying why.
    """
    inp = (prompt or '').strip()
    if not inp:
        # An empty prompt would otherwise fail on inp[-1] below.
        return ['', 'Please enter a WinoWhy prompt.']
    if inp[-1] not in '?!.':
        inp += '.'
    # ' %' appears to be the prompt/explanation separator the fine-tuned model
    # expects (replies are read from after it).
    inp += ' %'
    input_ids = torch.tensor([tokenizer.encode(inp)])
    # NOTE(review): max_length also counts the prompt tokens, so prompts near
    # 35 tokens leave little or no room for a reply; consider max_new_tokens.
    output = model.generate(input_ids, max_length=35)
    # Decode only the newly generated tokens, so a '%' typed in the prompt
    # cannot be mistaken for the separator.
    resp1 = tokenizer.decode(
        output[0][input_ids.shape[-1]:], skip_special_tokens=True
    ).strip()

    names = _extract_names(prompt, items)
    if len(names) < 2:
        return [resp1, 'Need two names: list them under "Answers", separated by a comma']

    force_inp = _swap_first_name(resp1, names)
    if force_inp is None:
        return [resp1, 'Name not present']

    alt = inp + ' ' + force_inp
    input_ids2 = torch.tensor([tokenizer.encode(alt)])
    # NOTE(review): for decoder-only models min_length also counts the prompt,
    # so it may already be satisfied; min_new_tokens may be what was intended.
    output2 = model.generate(input_ids2, max_new_tokens=30, min_length=30, do_sample=True)
    continuation = tokenizer.decode(
        output2[0][input_ids2.shape[-1]:], skip_special_tokens=True
    )
    # The alternative reply is the forced (name-swapped) prefix plus whatever
    # the model generated after it.
    resp2 = force_inp + continuation
    return [resp1, resp2]
# Two free-text inputs: the WinoWhy sentence and the optional candidate answers.
_input_boxes = [
    gr.inputs.Textbox(label="WinoWhy Prompt"),
    gr.inputs.Textbox(label="Answers (optional)"),
]

# One textbox per element of the two-item list returned by hello().
_output_boxes = [
    gr.outputs.Textbox(label="Fine-tuned reply"),
    gr.outputs.Textbox(label="Alternative reply"),
]

# Remaining Interface settings; thumbnail='https://github.com/MonsoonNLP/gradio-gptnyc'
# is intentionally not passed.
_interface_settings = {
    "verbose": True,
    "title": 'Anti-Explanations',
    "description": 'Learn more at https://medium.com/nerd-for-tech/searching-for-anti-explanations-418d26816b44',
    "analytics_enabled": True,
}

io = gr.Interface(
    fn=hello,
    inputs=_input_boxes,
    outputs=_output_boxes,
    **_interface_settings,
)

# Start the web app.
io.launch(debug=True)