Spaces:
Sleeping
Sleeping
File size: 2,531 Bytes
051ee94 178c299 051ee94 3e79d11 051ee94 178c299 d98d0a8 3e79d11 d98d0a8 3e79d11 d98d0a8 6e9d2dd 684df16 051ee94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Fine-tuned GPT-2 checkpoint shared by every request; both objects are loaded
# once, at module import time.
_CHECKPOINT = "monsoon-nlp/gpt-winowhy"
tokenizer = GPT2Tokenizer.from_pretrained(_CHECKPOINT)
# pad_token_id is pointed at the EOS token, presumably because GPT-2 ships
# without a dedicated pad token (this keeps generate() from warning) - confirm.
model = GPT2LMHeadModel.from_pretrained(
    _CHECKPOINT,
    pad_token_id=tokenizer.eos_token_id,
)
def _after_marker(text, marker='%'):
    """Return the text after the first *marker*, or *text* unchanged if absent."""
    _, found, tail = text.partition(marker)
    return tail if found else text


def _candidate_names(prompt, items):
    """Return the candidate entities, in order, whose roles get swapped.

    A comma in *items* means the user listed the answers explicitly (empty
    entries are ignored). Otherwise capitalized words of *prompt* are used,
    with surrounding punctuation stripped so that e.g. ``Ray.`` at the end of
    a sentence still matches ``Ray`` in the model output.
    """
    if items and ',' in items:
        return [piece.strip() for piece in items.split(',') if piece.strip()]
    names = []
    # split() with no argument never yields empty strings (double spaces),
    # so indexing word[0] below is safe once punctuation-only words are skipped.
    for word in prompt.split():
        word = word.strip('.,;:!?"\'()[]')
        # isupper(): the old `c == c.upper()` test also accepted digits/punctuation.
        if word and word[0].isupper():
            names.append(word)
    if len(names) > 2:
        # Remove the first one, which presumably is the sentence-initial capital.
        names = names[1:]
    return names


def hello(prompt, items):
    """Produce the model's explanation and an "anti-explanation" for a prompt.

    The prompt is given terminal punctuation if it lacks one and is followed by
    ' %'; the model's reply is read back from after that marker. The first of
    two candidate names found in that reply is swapped for the other, a short
    prefix of the reply is kept, and the model samples a continuation from it.

    Args:
        prompt: The WinoWhy sentence typed by the user.
        items: Optional comma-separated candidate answers. Without a comma,
            candidates are guessed from capitalized words of ``prompt``.

    Returns:
        A two-element list ``[reply, alternative]`` (one string per output box).
        ``alternative`` is ``'Name not present'`` when fewer than two candidates
        are known or neither occurs in ``reply``; for an empty prompt the list is
        ``['', 'Please enter a prompt']``.
    """
    inp = (prompt or '').strip()
    if not inp:
        # Previously inp[-1] raised IndexError here.
        return ['', 'Please enter a prompt']
    if inp[-1] not in ('?', '!', '.'):
        inp += '.'
    inp += ' %'

    input_ids = torch.tensor([tokenizer.encode(inp)])
    # NOTE(review): max_length counts the prompt tokens as well, so prompts near
    # 35 tokens leave little or no room for a generated reply.
    output = model.generate(input_ids, max_length=35)
    resp = tokenizer.decode(output[0], skip_special_tokens=True)
    # Previously resp1 stayed undefined (UnboundLocalError) when '%' was missing.
    resp1 = _after_marker(resp)

    names = _candidate_names(prompt or '', items)
    if len(names) < 2:
        # Previously names[1] raised IndexError here.
        return [resp1, 'Name not present']
    first, second = names[0], names[1]
    pos_first, pos_second = resp1.find(first), resp1.find(second)

    # Replace whichever candidate the reply mentions first with the other one.
    if pos_first >= 0 and (pos_second < 0 or pos_first < pos_second):
        found, other, pos = first, second, pos_first
    elif pos_second >= 0:
        found, other, pos = second, first, pos_second
    else:
        return [resp1, 'Name not present']

    force_inp = resp1[:pos] + other
    # split() (no argument) drops empty tokens, so an empty tail appends nothing
    # instead of a stray trailing space.
    remainder = resp1[pos + len(found):].split()
    if remainder:
        # Keep a copula together with the word after it (e.g. "is heavy").
        if remainder[0] in ('is', 'are', 'was', 'were'):
            force_inp += ' ' + ' '.join(remainder[:2])
        else:
            force_inp += ' ' + remainder[0]

    # resp1 (text right after '%') likely begins with a space; strip it so the
    # second prompt does not contain a double space.
    alt = inp + ' ' + force_inp.lstrip()
    input_ids2 = torch.tensor([tokenizer.encode(alt)])
    output2 = model.generate(input_ids2, max_new_tokens=30, min_length=30, do_sample=True)
    resp2 = _after_marker(tokenizer.decode(output2[0], skip_special_tokens=True))
    return [resp1, resp2]
# Web UI: two text boxes feed hello(prompt, items); hello's two-element return
# list fills the two output boxes in order. launch() runs at import time.
# NOTE(review): gr.inputs / gr.outputs and verbose= are the legacy Gradio API
# (deprecated in 3.x and, as far as I know, removed in 4.0; the modern spelling
# is gr.Textbox) - confirm the pinned Gradio version before upgrading.
io = gr.Interface(
    fn=hello,
    inputs=[
        gr.inputs.Textbox(label=caption)
        for caption in ("WinoWhy Prompt", "Answers (optional)")
    ],
    outputs=[
        gr.outputs.Textbox(label=caption)
        for caption in ("Fine-tuned reply", "Alternative reply")
    ],
    title='Anti-Explanations',
    description='Learn more at https://medium.com/nerd-for-tech/searching-for-anti-explanations-418d26816b44',
    verbose=True,
    analytics_enabled=True,
    # thumbnail='https://github.com/MonsoonNLP/gradio-gptnyc',  (disabled)
)
io.launch(debug=True)
|