monsoon-nlp commited on
Commit
051ee94
0 Parent(s):

initial demo of 2 outputs

Browse files
Files changed (3) hide show
  1. README.md +9 -0
  2. app.py +55 -0
  3. requirements.txt +2 -0
README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Anti-Explanation
3
+ emoji: ☯️
4
+ colorFrom: back
5
+ colorTo: white
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
+
5
+ tokenizer = GPT2Tokenizer.from_pretrained("monsoon-nlp/gpt-winowhy")
6
+ model = GPT2LMHeadModel.from_pretrained("monsoon-nlp/gpt-winowhy", pad_token_id=tokenizer.eos_token_id)
7
+
8
+ def hello(prompt, items):
9
+ inp = prompt.strip() + ' %'
10
+ input_ids = torch.tensor([tokenizer.encode(inp)])
11
+ output = model.generate(input_ids, max_new_tokens=12)
12
+ resp = tokenizer.decode(output[0], skip_special_tokens=True)
13
+ if '%' in resp:
14
+ resp1 = resp[resp.index('%') + 1 : ]
15
+
16
+ names = []
17
+ if ',' in items:
18
+ for item in items.split(','):
19
+ names.append(item.strip())
20
+ else:
21
+ for word in prompt.split(' '):
22
+ if word[0] == word[0].upper():
23
+ names.append(word)
24
+ if len(names) > 2:
25
+ # remove first one which assumedly is a capital
26
+ names = names[1:]
27
+
28
+ if (names[0] in resp1) and ((names[1] not in resp1) or (resp1.index(names[0]) < resp1.index(names[1]))):
29
+ force_inp = inp + resp1[resp1.index(names[0]):] + names[1]
30
+ else:
31
+ force_inp = inp + resp1[resp1.index(names[1]):] + names[0]
32
+ resp2 = force_inp
33
+ #
34
+ # input_ids2 = torch.tensor([tokenizer.encode(force_inp)])
35
+ # output2 = model.generate(input_ids2, max_new_tokens=8)
36
+ # resp = tokenizer.decode(output2[0], skip_special_tokens=True)
37
+ # resp2 = resp2[resp2.index('%') + 1 : ]
38
+ return [resp1, resp2]
39
+
40
+ io = gr.Interface(fn=hello,
41
+ inputs=[
42
+ gr.inputs.Textbox(label="WinoWhy Prompt"),
43
+ gr.inputs.Textbox(label="Answers (optional)"),
44
+ ],
45
+ outputs=[
46
+ gr.outputs.Textbox(label="Fine-tuned reply"),
47
+ gr.outputs.Textbox(label="Alternative reply"),
48
+ ],
49
+ verbose=True,
50
+ title='Anti-Explanations',
51
+ description='Learn more at https://medium.com/nerd-for-tech/searching-for-anti-explanations-418d26816b44',
52
+ #thumbnail='https://github.com/MonsoonNLP/gradio-gptnyc',
53
+ analytics_enabled=True)
54
+
55
+ io.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
1
+ torch
2
+ transformers==4.9.1