StaticFace committed on
Commit
c205f25
·
verified ·
1 Parent(s): a628211

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -20
app.py CHANGED
@@ -1,25 +1,34 @@
1
  import os
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import gradio as gr
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
5
 
6
  MODEL_ID = "MoritzLaurer/deberta-v3-large-zeroshot-v2.0"
7
 
8
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
9
- torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", "2")))
10
  torch.set_num_interop_threads(1)
11
 
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
13
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
14
  model.eval()
15
 
16
- clf = pipeline(
17
- task="zero-shot-classification",
18
- model=model,
19
- tokenizer=tokenizer,
20
- device=-1,
21
- framework="pt",
22
- )
23
 
24
  def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
25
  text = (text or "").strip()
@@ -33,22 +42,30 @@ def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
33
  if not candidate_labels:
34
  return {"error": "Enter at least 1 label (comma-separated)."}
35
 
 
36
  with torch.inference_mode():
37
- out = clf(
38
- sequences=text,
39
- candidate_labels=candidate_labels,
40
- hypothesis_template=hypothesis_template,
41
- multi_label=bool(multi_label),
42
- )
43
-
44
- pairs = list(zip(out["labels"], out["scores"]))
 
 
 
 
 
 
 
45
  pairs.sort(key=lambda x: x[1], reverse=True)
46
  pairs = pairs[: max(1, int(top_k))]
47
 
48
  return {
 
49
  "top": {"label": pairs[0][0], "confidence_pct": round(pairs[0][1] * 100, 2)},
50
  "all": [{"label": k, "confidence_pct": round(v * 100, 2)} for k, v in pairs],
51
- "raw": out,
52
  }
53
 
54
  demo = gr.Interface(
@@ -61,7 +78,7 @@ demo = gr.Interface(
61
  gr.Slider(label="Top-K to show", minimum=1, maximum=25, value=5, step=1),
62
  ],
63
  outputs=gr.JSON(label="Output"),
64
- title="Zero-Shot Classification (DeBERTa v3 Large, MoritzLaurer)",
65
  flagging_mode="never",
66
  )
67
 
 
import os

# Core count of the host; reused for every threading knob and echoed in the
# JSON output of the app.
CPU_THREADS = 16

# These env vars must be set BEFORE torch/numpy are imported so the native
# BLAS/OpenMP runtimes read them at load time.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
for _thread_var in (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_var] = str(CPU_THREADS)

import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_ID = "MoritzLaurer/deberta-v3-large-zeroshot-v2.0"

# All cores for intra-op work; a single inter-op thread avoids
# oversubscription on a CPU-only box.
torch.set_num_threads(CPU_THREADS)
torch.set_num_interop_threads(1)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()

# Case-insensitive view of the model's NLI labels, so the entailment logit
# can be located regardless of label capitalization.
label2id = {k.lower(): v for k, v in model.config.label2id.items()}
# NOTE(review): the fallback index 2 presumes a 3-label NLI head
# (contradiction/neutral/entailment) — confirm against this checkpoint.
entail_id = label2id.get("entailment", 2)
+ def _softmax(x):
29
+ x = x - np.max(x)
30
+ e = np.exp(x)
31
+ return e / np.sum(e)
32
 
33
  def run_zero_shot(text, labels, hypothesis_template, multi_label, top_k):
34
  text = (text or "").strip()
 
42
  if not candidate_labels:
43
  return {"error": "Enter at least 1 label (comma-separated)."}
44
 
45
+ scores = []
46
  with torch.inference_mode():
47
+ for lab in candidate_labels:
48
+ hyp = hypothesis_template.format(lab)
49
+ inputs = tokenizer(text, hyp, return_tensors="pt", truncation=True)
50
+ logits = model(**inputs).logits[0].float().cpu().numpy()
51
+ score = float(_softmax(logits)[entail_id])
52
+ scores.append(score)
53
+
54
+ scores_np = np.array(scores, dtype=np.float32)
55
+
56
+ if bool(multi_label):
57
+ out_scores = scores_np
58
+ else:
59
+ out_scores = _softmax(scores_np)
60
+
61
+ pairs = list(zip(candidate_labels, out_scores.tolist()))
62
  pairs.sort(key=lambda x: x[1], reverse=True)
63
  pairs = pairs[: max(1, int(top_k))]
64
 
65
  return {
66
+ "cpu_threads": CPU_THREADS,
67
  "top": {"label": pairs[0][0], "confidence_pct": round(pairs[0][1] * 100, 2)},
68
  "all": [{"label": k, "confidence_pct": round(v * 100, 2)} for k, v in pairs],
 
69
  }
70
 
71
  demo = gr.Interface(
 
78
  gr.Slider(label="Top-K to show", minimum=1, maximum=25, value=5, step=1),
79
  ],
80
  outputs=gr.JSON(label="Output"),
81
+ title="Zero-Shot Classification (DeBERTa v3 Large, 16-core CPU)",
82
  flagging_mode="never",
83
  )
84