File size: 2,531 Bytes
051ee94
 
 
 
 
 
 
 
178c299
 
 
 
051ee94
3e79d11
051ee94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178c299
d98d0a8
3e79d11
 
d98d0a8
3e79d11
 
 
 
 
 
 
 
d98d0a8
 
6e9d2dd
684df16
 
051ee94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the GPT-2 checkpoint fine-tuned on WinoWhy explanations.
# EOS is reused as the pad token so generate() can pad without a
# dedicated pad token (GPT-2 ships with none).
tokenizer = GPT2Tokenizer.from_pretrained("monsoon-nlp/gpt-winowhy")
model = GPT2LMHeadModel.from_pretrained("monsoon-nlp/gpt-winowhy", pad_token_id=tokenizer.eos_token_id)

def hello(prompt, items):
    """Generate a WinoWhy explanation and a forced "anti-explanation".

    The prompt is terminated with the fine-tuning marker ' %', passed to the
    model, and the text after '%' is taken as the explanation. The two
    candidate names are then swapped at the first mention to build an
    alternative ("anti") prompt for a second, sampled generation.

    Args:
        prompt: The WinoWhy sentence/question to explain.
        items: Optional comma-separated candidate names; when absent, the
            names are guessed from capitalized words in the prompt.

    Returns:
        [fine-tuned reply, alternative reply] — the second element is an
        error string when no swap could be made.
    """
    inp = prompt.strip()
    if not inp:
        # Guard: original code indexed inp[-1] and crashed on empty input.
        return ['', 'Empty prompt']
    if inp[-1] not in ['?', '!', '.']:
        inp += '.'
    # ' %' is the separator the model was fine-tuned with; everything the
    # model emits after it is the explanation.
    inp += ' %'
    input_ids = torch.tensor([tokenizer.encode(inp)])
    output = model.generate(input_ids, max_length=35)
    resp = tokenizer.decode(output[0], skip_special_tokens=True)
    # Fall back to the full decoded text if the '%' marker was dropped —
    # the original left resp1 unbound here, causing UnboundLocalError below.
    resp1 = resp[resp.index('%') + 1:] if '%' in resp else resp

    # Collect the two candidate names: explicit comma-separated list, or
    # capitalized words guessed from the prompt.
    names = []
    if ',' in items:
        names = [item.strip() for item in items.split(',')]
    else:
        for word in prompt.split(' '):
            # `word` can be '' when the prompt has double spaces; guard
            # before indexing word[0] (original crashed with IndexError).
            if word and word[0] == word[0].upper():
                names.append(word)
        if len(names) > 2:
            # remove first one which assumedly is a sentence-initial capital
            names = names[1:]
    if len(names) < 2:
        # Original indexed names[0]/names[1] unconditionally -> IndexError.
        return [resp1, 'Could not find two candidate names']

    # Swap whichever name appears first in the explanation for the other
    # name, keeping at most a couple of following words as continuation seed.
    if (names[0] in resp1) and ((names[1] not in resp1) or (resp1.index(names[0]) < resp1.index(names[1]))):
        force_inp = resp1[:resp1.index(names[0])] + names[1]
        remainder = resp1[resp1.index(names[0]) + len(names[0]):].strip().split(' ')
    elif names[1] in resp1:
        force_inp = resp1[:resp1.index(names[1])] + names[0]
        remainder = resp1[resp1.index(names[1]) + len(names[1]):].strip().split(' ')
    else:
        return [resp1, 'Name not present']
    if len(remainder) > 0:
        if remainder[0] in ['is', 'are', 'was', 'were']:
            # Keep the copula plus one more word so the continuation parses.
            force_inp += ' ' + ' '.join(remainder[:2])
        else:
            force_inp += ' ' + remainder[0]
    # Re-prompt with the swapped-name prefix and sample a continuation.
    alt = inp + ' ' + force_inp
    input_ids2 = torch.tensor([tokenizer.encode(alt)])
    output2 = model.generate(input_ids2, max_new_tokens=30, min_length=30, do_sample=True)
    resp2 = tokenizer.decode(output2[0], skip_special_tokens=True)
    # Same marker-missing guard as above (original .index('%') could raise
    # ValueError here).
    resp2 = resp2[resp2.index('%') + 1:] if '%' in resp2 else resp2
    return [resp1, resp2]

# Wire the generator into a two-input / two-output Gradio UI.
# NOTE(review): gr.inputs / gr.outputs is the legacy (pre-3.0) Gradio
# namespace — this file presumably pins an old gradio version; confirm
# before upgrading.
io = gr.Interface(fn=hello,
    inputs=[
        gr.inputs.Textbox(label="WinoWhy Prompt"),
        # Optional comma-separated candidate names; hello() guesses them
        # from capitalized prompt words when left blank.
        gr.inputs.Textbox(label="Answers (optional)"),
    ],
    outputs=[
	   gr.outputs.Textbox(label="Fine-tuned reply"),
	   gr.outputs.Textbox(label="Alternative reply"),
    ],
    verbose=True,
    title='Anti-Explanations',
    description='Learn more at https://medium.com/nerd-for-tech/searching-for-anti-explanations-418d26816b44',
    #thumbnail='https://github.com/MonsoonNLP/gradio-gptnyc',
    analytics_enabled=True)

# debug=True blocks and streams server logs to the console.
io.launch(debug=True)