youj2005 committed on
Commit 2a8fa62
1 Parent(s): 0b73704

Added examples and description

Files changed (2)
  1. app.py +18 -13
  2. gradio_cached_examples/18/log.csv +2 -0
app.py CHANGED
@@ -6,9 +6,9 @@ import torch
 device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
 
 te_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
-te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
-qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
-qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", device_map="auto")
+te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli', device_map="auto")
+qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
+qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
 
 def predict(context, intent, multi_class):
     input_text = "What is the opposite of " + intent + "?"
@@ -20,10 +20,9 @@ def predict(context, intent, multi_class):
     batch = ['The ' + object_output + ' is ' + intent, 'The ' + object_output + ' is ' + opposite_output, 'The ' + object_output + ' is not ' + intent, 'The ' + object_output + ' is not ' + opposite_output]
 
     outputs = []
-    print(intent, opposite_output, object_output)
     for i, hypothesis in enumerate(batch):
-        # input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
-        input_ids = te_model(context, hypothesis).to(device)
+        input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
+
         # -> [contradiction, neutral, entailment]
         logits = te_model(input_ids)[0][0]
 
@@ -36,8 +35,9 @@ def predict(context, intent, multi_class):
 
     # calculate the stochastic vector for it being neither the positive nor the negative class
     perfect_prob = [0, 0]
-    perfect_prob[0] = (outputs[2][0] + outputs[3][1])/2
-    perfect_prob[1] = 1-perfect_prob[2][0]
+    perfect_prob[1] = max(float(outputs[2][0]), float(outputs[3][0]))
+    perfect_prob[0] = 1-perfect_prob[1]
+    # -> [entailment, contradiction] for perfect
 
     # -> [entailment, neutral, contradiction] for positive
     outputs[0] = outputs[0].flip(dims=[0])
@@ -63,7 +63,8 @@ def predict(context, intent, multi_class):
     aggregated[0] = aggregated[0] * perfect_prob[0]
 
     # to exaggerate differences
-    aggregated = aggregated.exp() - 1
+    # this way 0 maps to 0
+    aggregated = aggregated.exp()-1
 
     # multiple true classes
     if (multi_class):
@@ -72,14 +73,18 @@ def predict(context, intent, multi_class):
     else:
         aggregated = aggregated.softmax(dim=0)
     aggregated = aggregated.tolist()
-    return {"agree": aggregated[0], "neutral": aggregated[1], "disagree": aggregated[2]}
+    return {"agree": aggregated[0], "neutral": aggregated[1], "disagree": aggregated[2]}, {"agree": outputs[0][0], "neutral": outputs[0][1], "disagree": outputs[0][2]}
+
+examples = [["The pants fit great, even the waist will fit me fine once I'm back to my normal weight, but the bottom is what's large. You can roll up the bottom part of the legs, or the top at the waist band for hanging out at the house, but if you have one nearby, simply have them re-hemmed.", "long"]]
 
 gradio_app = gr.Interface(
     predict,
-    inputs=[gr.Text(label="Sentence"), gr.Text(label="Class"), gr.Checkbox(label="Allow multiple true classes")],
-    outputs=[gr.Label(num_top_classes=3)],
+    examples=examples,
+    inputs=[gr.Text(label="Statement"), gr.Text(label="Class"), gr.Checkbox(label="Allow multiple true classes")],
+    outputs=[gr.Label(num_top_classes=3, label="With Postprocessing"), gr.Label(num_top_classes=3, label="Without Postprocessing")],
     title="Intent Analysis",
-    description="This model predicts whether or not the **_class_** describes the **_object described in the sentence._**"
+    description="This model predicts whether or not the **_class_** describes the **_object described in the sentence_**. <br /> The two outputs show what TE would predict with and without the postprocessing. An example edge case for normal TE is shown below. <br /> **_It is recommended that you clone the repository to speed up processing time_**.",
+    cache_examples=True
 )
 
 gradio_app.launch(share=True)
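
For reference, a minimal sketch of the zero-shot entailment check that the updated loop now performs with bart-large-mnli; the premise and hypothesis strings here are illustrative, not taken from the app.

import torch
from transformers import BartTokenizer, BartForSequenceClassification

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

premise = "The pants fit great, but the bottom is what's large."  # illustrative
hypothesis = "The pants are long."                                 # illustrative

# Encode the (premise, hypothesis) pair, mirroring the new encode call in the loop
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
with torch.no_grad():
    logits = model(input_ids)[0][0]  # -> [contradiction, neutral, entailment]
probs = logits.softmax(dim=0)
print(dict(zip(["contradiction", "neutral", "entailment"], probs.tolist())))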
gradio_cached_examples/18/log.csv ADDED
@@ -0,0 +1,2 @@
+With Postprocessing,Without Postprocessing,flag,username,timestamp
+"{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.37631523609161377},{""label"":""neutral"",""confidence"":0.3404143750667572},{""label"":""disagree"",""confidence"":0.28327038884162903}]}","{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0.8370960354804993},{""label"":""disagree"",""confidence"":0.12820996344089508},{""label"":""agree"",""confidence"":0.03469394892454147}]}",,,2024-03-10 20:51:53.608441