monsoon-nlp committed
Commit 178c299
1 Parent(s): 5202911

formulate 2nd forced sentence

Files changed (1): app.py (+10, -7)
app.py CHANGED
@@ -6,9 +6,12 @@ tokenizer = GPT2Tokenizer.from_pretrained("monsoon-nlp/gpt-winowhy")
 model = GPT2LMHeadModel.from_pretrained("monsoon-nlp/gpt-winowhy", pad_token_id=tokenizer.eos_token_id)
 
 def hello(prompt, items):
-    inp = prompt.strip() + ' %'
+    inp = prompt.strip()
+    if inp[-1] not in ['?', '!', '.']:
+        inp += '.'
+    inp += ' %'
     input_ids = torch.tensor([tokenizer.encode(inp)])
-    output = model.generate(input_ids, max_new_tokens=12)
+    output = model.generate(input_ids, max_new_tokens=20)
     resp = tokenizer.decode(output[0], skip_special_tokens=True)
     if '%' in resp:
         resp1 = resp[resp.index('%') + 1 : ]
@@ -25,11 +28,11 @@ def hello(prompt, items):
     # remove first one which assumedly is a capital
     names = names[1:]
 
-    #if (names[0] in resp1) and ((names[1] not in resp1) or (resp1.index(names[0]) < resp1.index(names[1]))):
-    #    force_inp = inp + resp1[resp1.index(names[0]):] + names[1]
-    #else:
-    #    force_inp = inp + resp1[resp1.index(names[1]):] + names[0]
-    resp2 = ",".join(names) #force_inp
+    if (names[0] in resp1) and ((names[1] not in resp1) or (resp1.index(names[0]) < resp1.index(names[1]))):
+        force_inp = resp1[resp1.index(names[0]):] + names[1]
+    else:
+        force_inp = resp1[resp1.index(names[1]):] + names[0]
+    resp2 = force_inp
     #
     # input_ids2 = torch.tensor([tokenizer.encode(force_inp)])
     # output2 = model.generate(input_ids2, max_new_tokens=8)
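For reference, a minimal standalone sketch of the two behaviors this commit introduces, prompt punctuation normalization and the forced second sentence, with the model call left out. The helper names (normalize_prompt, second_forced_sentence) and the sample prompt, resp1 value, and names list are illustrative assumptions only, not code from app.py:

# Standalone sketch; the real resp1 and names come from the GPT-2 generation
# and name-extraction steps inside hello().

def normalize_prompt(prompt):
    # New in this commit: make sure the prompt ends with punctuation
    # before appending the ' %' separator the fine-tuned model expects.
    inp = prompt.strip()
    if inp[-1] not in ['?', '!', '.']:
        inp += '.'
    return inp + ' %'

def second_forced_sentence(resp1, names):
    # Keep the model's answer from whichever name it mentioned first,
    # then append the other name to set up a second, forced answer.
    if (names[0] in resp1) and ((names[1] not in resp1) or (resp1.index(names[0]) < resp1.index(names[1]))):
        return resp1[resp1.index(names[0]):] + names[1]
    return resp1[resp1.index(names[1]):] + names[0]

print(normalize_prompt("The trophy doesn't fit in the suitcase because it is too big"))
# The trophy doesn't fit in the suitcase because it is too big. %

print(second_forced_sentence(" the trophy is too large. ", ["trophy", "suitcase"]))
# trophy is too large. suitcase

Compared with the commented-out draft in the previous revision, the new force_inp no longer prepends inp, and max_new_tokens for the first generation is raised from 12 to 20.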