jorgefio commited on
Commit
1ea99dc
1 Parent(s): cb45be7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -6,6 +6,7 @@ import time
6
 
7
  zero_shot = pipeline("zero-shot-classification")
8
  distilbert = ktrain.load_predictor("models/distilbert-base-uncased-finetuned-internet-provider")
 
9
 
10
  def zero_shot_predict(text):
11
  labels = ["Slow Connection", "Billing", "Setup", "No Connectivity"]
@@ -17,17 +18,24 @@ def distilbert_predict(text):
17
  preds = distilbert.predict_proba(text)
18
  return {label: float(pred) for label, pred in zip(labels, preds)}
19
 
 
 
 
 
 
20
  def predict(text):
21
- with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
22
  zero_shot_future = executor.submit(zero_shot_predict, text)
23
  distilbert_future = executor.submit(distilbert_predict, text)
24
- concurrent.futures.wait([zero_shot_future, distilbert_future])
 
25
  zero_shot_preds = zero_shot_future.result()
26
  distilbert_preds = distilbert_future.result()
27
- return zero_shot_preds, distilbert_preds
 
28
 
29
  input = gr.inputs.Textbox(label="Customer Sentence")
30
- outputs = [gr.outputs.Label(num_top_classes=4, label="Zero-Shot-Classification"), gr.outputs.Label(num_top_classes=4, label="DistilBERT")]
31
  title = "Case Classification"
32
  description = "Comparison of Zero-Shot-Classification and a fine-tuned DistilBERT."
33
  gr.Interface(predict, input, outputs, live=False, live_update=False, title=title, analytics_enabled=False,
 
6
 
7
  zero_shot = pipeline("zero-shot-classification")
8
  distilbert = ktrain.load_predictor("models/distilbert-base-uncased-finetuned-internet-provider")
9
+ distilbert_v2 = ktrain.load_predictor("models/distilbert-base-uncased-finetuned-internet-provider")
10
 
11
  def zero_shot_predict(text):
12
  labels = ["Slow Connection", "Billing", "Setup", "No Connectivity"]
 
18
  preds = distilbert.predict_proba(text)
19
  return {label: float(pred) for label, pred in zip(labels, preds)}
20
 
21
+ def distilbert_v2_predict(text):
22
+ labels = distilbert_v2.get_classes()
23
+ preds = distilbert_v2.predict_proba(text)
24
+ return {label: float(pred) for label, pred in zip(labels, preds)}
25
+
26
  def predict(text):
27
+ with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
28
  zero_shot_future = executor.submit(zero_shot_predict, text)
29
  distilbert_future = executor.submit(distilbert_predict, text)
30
+ distilbert_v2_future = executor.submit(distilbert_v2_predict, text)
31
+ concurrent.futures.wait([zero_shot_future, distilbert_future, distilbert_v2_future])
32
  zero_shot_preds = zero_shot_future.result()
33
  distilbert_preds = distilbert_future.result()
34
+ distilbert_v2_preds = distilbert_v2_future.result()
35
+ return zero_shot_preds, distilbert_preds, distilbert_v2_preds
36
 
37
  input = gr.inputs.Textbox(label="Customer Sentence")
38
+ outputs = [gr.outputs.Label(num_top_classes=4, label="Zero-Shot-Classification"), gr.outputs.Label(num_top_classes=4, label="DistilBERT"), gr.outputs.Label(num_top_classes=4, label="DistilBERT v2")]
39
  title = "Case Classification"
40
  description = "Comparison of Zero-Shot-Classification and a fine-tuned DistilBERT."
41
  gr.Interface(predict, input, outputs, live=False, live_update=False, title=title, analytics_enabled=False,