youj2005 commited on
Commit
7a7170b
1 Parent(s): bb13f04

Exponentiation

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -17,7 +17,7 @@ def predict(context, intent, multi_class):
17
  input_text = "In one word, what is the following describing: " + context
18
  input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
19
  object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
20
- batch = ['I think the ' + object_output + ' are long.', 'I think the ' + object_output + ' are ' + opposite_output, 'I think the ' + object_output + ' are the perfect']
21
  outputs = []
22
  for i, hypothesis in enumerate(batch):
23
  input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
@@ -35,10 +35,12 @@ def predict(context, intent, multi_class):
35
  outputs[2] = outputs[2].flip(dims=[0])
36
  # -> [entailment, neutral, contradiction]
37
  outputs[0] = outputs[0].flip(dims=[0])
38
- pn_tensor = (outputs[0] + outputs[1]).softmax(dim=0)
39
  pn_tensor[1] = pn_tensor[1] * outputs[2][0]
40
  pn_tensor[2] = pn_tensor[2] * outputs[2][1]
41
  pn_tensor[0] = pn_tensor[0] * outputs[2][1]
 
 
42
 
43
  if (multi_class):
44
  pn_tensor = torch.sigmoid(pn_tensor)
 
17
  input_text = "In one word, what is the following describing: " + context
18
  input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
19
  object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
20
+ batch = ['I think the ' + object_output + ' is ' + intent, 'I think the ' + object_output + ' is ' + opposite_output, 'I think the ' + object_output + ' are neither ' + intent + ' nor ' + opposite_output]
21
  outputs = []
22
  for i, hypothesis in enumerate(batch):
23
  input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
 
35
  outputs[2] = outputs[2].flip(dims=[0])
36
  # -> [entailment, neutral, contradiction]
37
  outputs[0] = outputs[0].flip(dims=[0])
38
+ pn_tensor = (outputs[0] + outputs[1])/2
39
  pn_tensor[1] = pn_tensor[1] * outputs[2][0]
40
  pn_tensor[2] = pn_tensor[2] * outputs[2][1]
41
  pn_tensor[0] = pn_tensor[0] * outputs[2][1]
42
+
43
+ pn_tensor = pn_tensor.exp() - 1
44
 
45
  if (multi_class):
46
  pn_tensor = torch.sigmoid(pn_tensor)