monsoon-nlp commited on
Commit
3e79d11
1 Parent(s): 5d3aa5d

include verb after name

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -11,7 +11,7 @@ def hello(prompt, items):
11
  inp += '.'
12
  inp += ' %'
13
  input_ids = torch.tensor([tokenizer.encode(inp)])
14
- output = model.generate(input_ids, max_length=30)
15
  resp = tokenizer.decode(output[0], skip_special_tokens=True)
16
  if '%' in resp:
17
  resp1 = resp[resp.index('%') + 1 : ]
@@ -30,8 +30,17 @@ def hello(prompt, items):
30
 
31
  if (names[0] in resp1) and ((names[1] not in resp1) or (resp1.index(names[0]) < resp1.index(names[1]))):
32
  force_inp = resp1[:resp1.index(names[0])] + names[1]
33
- else:
 
34
  force_inp = resp1[:resp1.index(names[1])] + names[0]
 
 
 
 
 
 
 
 
35
  alt = inp + ' ' + force_inp
36
  input_ids2 = torch.tensor([tokenizer.encode(alt)])
37
  output2 = model.generate(input_ids2, max_new_tokens=12, do_sample=True)
 
11
  inp += '.'
12
  inp += ' %'
13
  input_ids = torch.tensor([tokenizer.encode(inp)])
14
+ output = model.generate(input_ids, max_length=35)
15
  resp = tokenizer.decode(output[0], skip_special_tokens=True)
16
  if '%' in resp:
17
  resp1 = resp[resp.index('%') + 1 : ]
 
30
 
31
  if (names[0] in resp1) and ((names[1] not in resp1) or (resp1.index(names[0]) < resp1.index(names[1]))):
32
  force_inp = resp1[:resp1.index(names[0])] + names[1]
33
+ remainder = resp1[resp1.index(names[0]) + len(names[0]):].strip().split(' ')
34
+ elif (names[1] in resp1):
35
  force_inp = resp1[:resp1.index(names[1])] + names[0]
36
+ remainder = resp1[resp1.index(names[1]) + len(names[1]):].strip().split(' ')
37
+ else:
38
+ return [resp1, 'Name not present']
39
+ if len(remainder) > 0:
40
+ if remainder[0] in ['is', 'are', 'was', 'were']:
41
+ force_inp += ' ' + ' '.join(remainder[:2])
42
+ else:
43
+ force_inp += ' ' + remainder[0]
44
  alt = inp + ' ' + force_inp
45
  input_ids2 = torch.tensor([tokenizer.encode(alt)])
46
  output2 = model.generate(input_ids2, max_new_tokens=12, do_sample=True)