Spaces:
Running
Running
monsoon-nlp
committed on
Commit
•
3e79d11
1
Parent(s):
5d3aa5d
include verb after name
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ def hello(prompt, items):
|
|
11 |
inp += '.'
|
12 |
inp += ' %'
|
13 |
input_ids = torch.tensor([tokenizer.encode(inp)])
|
14 |
-
output = model.generate(input_ids, max_length=
|
15 |
resp = tokenizer.decode(output[0], skip_special_tokens=True)
|
16 |
if '%' in resp:
|
17 |
resp1 = resp[resp.index('%') + 1 : ]
|
@@ -30,8 +30,17 @@ def hello(prompt, items):
|
|
30 |
|
31 |
if (names[0] in resp1) and ((names[1] not in resp1) or (resp1.index(names[0]) < resp1.index(names[1]))):
|
32 |
force_inp = resp1[:resp1.index(names[0])] + names[1]
|
33 |
-
|
|
|
34 |
force_inp = resp1[:resp1.index(names[1])] + names[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
alt = inp + ' ' + force_inp
|
36 |
input_ids2 = torch.tensor([tokenizer.encode(alt)])
|
37 |
output2 = model.generate(input_ids2, max_new_tokens=12, do_sample=True)
|
|
|
11 |
inp += '.'
|
12 |
inp += ' %'
|
13 |
input_ids = torch.tensor([tokenizer.encode(inp)])
|
14 |
+
output = model.generate(input_ids, max_length=35)
|
15 |
resp = tokenizer.decode(output[0], skip_special_tokens=True)
|
16 |
if '%' in resp:
|
17 |
resp1 = resp[resp.index('%') + 1 : ]
|
|
|
30 |
|
31 |
if (names[0] in resp1) and ((names[1] not in resp1) or (resp1.index(names[0]) < resp1.index(names[1]))):
|
32 |
force_inp = resp1[:resp1.index(names[0])] + names[1]
|
33 |
+
remainder = resp1[resp1.index(names[0]) + len(names[0]):].strip().split(' ')
|
34 |
+
elif (names[1] in resp1):
|
35 |
force_inp = resp1[:resp1.index(names[1])] + names[0]
|
36 |
+
remainder = resp1[resp1.index(names[1]) + len(names[1]):].strip().split(' ')
|
37 |
+
else:
|
38 |
+
return [resp1, 'Name not present']
|
39 |
+
if len(remainder) > 0:
|
40 |
+
if remainder[0] in ['is', 'are', 'was', 'were']:
|
41 |
+
force_inp += ' ' + ' '.join(remainder[:2])
|
42 |
+
else:
|
43 |
+
force_inp += ' ' + remainder[0]
|
44 |
alt = inp + ' ' + force_inp
|
45 |
input_ids2 = torch.tensor([tokenizer.encode(alt)])
|
46 |
output2 = model.generate(input_ids2, max_new_tokens=12, do_sample=True)
|