youj2005 commited on
Commit
bb13f04
1 Parent(s): 9bba617

Fix multi-class

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import gradio as gr
2
  from transformers import BartForSequenceClassification, BartTokenizer
3
- import torch.nn.functional as F
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
 
 
 
5
 
6
  te_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
7
  te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
@@ -10,15 +12,15 @@ qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", de
10
 
11
  def predict(context, intent, multi_class):
12
  input_text = "In one word, what is the opposite of: " + intent + "?"
13
- input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids
14
  opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
15
  input_text = "In one word, what is the following describing: " + context
16
- input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids
17
  object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
18
  batch = ['I think the ' + object_output + ' are long.', 'I think the ' + object_output + ' are ' + opposite_output, 'I think the ' + object_output + ' are the perfect']
19
  outputs = []
20
  for i, hypothesis in enumerate(batch):
21
- input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt')
22
  # -> [contradiction, neutral, entailment]
23
  logits = te_model(input_ids)[0][0]
24
 
@@ -38,20 +40,19 @@ def predict(context, intent, multi_class):
38
  pn_tensor[2] = pn_tensor[2] * outputs[2][1]
39
  pn_tensor[0] = pn_tensor[0] * outputs[2][1]
40
 
41
- pn_tensor = F.normalize(pn_tensor, p=1, dim=0)
42
  if (multi_class):
43
- pn_tensor = F.normalize(pn_tensor, p=1, dim=0)
44
  else:
45
  pn_tensor = pn_tensor.softmax(dim=0)
46
  pn_tensor = pn_tensor.tolist()
47
- return {"entailment": pn_tensor[0], "neutral": pn_tensor[1], "contradiction": pn_tensor[2]}
48
 
49
  gradio_app = gr.Interface(
50
  predict,
51
- inputs=[gr.Text("Sentence"), gr.Text("Class"), gr.Checkbox("Allow multiple true classes")],
52
  outputs=[gr.Label(num_top_classes=3)],
53
  title="Intent Analysis",
 
54
  )
55
 
56
- if __name__ == "__main__":
57
- gradio_app.launch()
 
1
  import gradio as gr
2
  from transformers import BartForSequenceClassification, BartTokenizer
 
3
  from transformers import T5Tokenizer, T5ForConditionalGeneration
4
+ import torch
5
+
6
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
7
 
8
  te_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
9
  te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
 
12
 
13
  def predict(context, intent, multi_class):
14
  input_text = "In one word, what is the opposite of: " + intent + "?"
15
+ input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
16
  opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
17
  input_text = "In one word, what is the following describing: " + context
18
+ input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
19
  object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
20
  batch = ['I think the ' + object_output + ' are long.', 'I think the ' + object_output + ' are ' + opposite_output, 'I think the ' + object_output + ' are the perfect']
21
  outputs = []
22
  for i, hypothesis in enumerate(batch):
23
+ input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
24
  # -> [contradiction, neutral, entailment]
25
  logits = te_model(input_ids)[0][0]
26
 
 
40
  pn_tensor[2] = pn_tensor[2] * outputs[2][1]
41
  pn_tensor[0] = pn_tensor[0] * outputs[2][1]
42
 
 
43
  if (multi_class):
44
+ pn_tensor = torch.sigmoid(pn_tensor)
45
  else:
46
  pn_tensor = pn_tensor.softmax(dim=0)
47
  pn_tensor = pn_tensor.tolist()
48
+ return {"agree": pn_tensor[0], "neutral": pn_tensor[1], "disagree": pn_tensor[2]}
49
 
50
  gradio_app = gr.Interface(
51
  predict,
52
+ inputs=[gr.Text(label="Sentence"), gr.Text(label="Class"), gr.Checkbox(label="Allow multiple true classes")],
53
  outputs=[gr.Label(num_top_classes=3)],
54
  title="Intent Analysis",
55
+ description="This model predicts whether or not the **class** describes the **object described in the sentence.**"
56
  )
57
 
58
+ gradio_app.launch()