rocket-yg commited on
Commit
e733676
1 Parent(s): c7f3d38

added zero shot

Browse files
Files changed (1) hide show
  1. app.py +31 -3
app.py CHANGED
@@ -3,7 +3,8 @@ from transformers import pipeline
3
 
4
  get_completion = pipeline("summarization",model="sshleifer/distilbart-cnn-12-6")
5
  get_ner = pipeline("ner", model="dslim/bert-base-NER")
6
- get_caption = pipeline("image-to-text")
 
7
  def summarize_text(input):
8
  output = get_completion(input)
9
  return output[0]['summary_text']
@@ -27,6 +28,17 @@ def named_entity_recognition(input):
27
  merged_output = merge_tokens(output)
28
  return {"text": input, "entities": output}
29
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  interface_summarise = gr.Interface(fn=summarize_text,
32
  inputs=[gr.Textbox(label="Text to summarise", lines=5)],
@@ -45,11 +57,27 @@ interface_ner = gr.Interface(fn=named_entity_recognition,
45
  "My name is Bose and I am a physicist living in Delhi"
46
  ])
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  demo = gr.TabbedInterface([
49
  interface_summarise,
50
- interface_ner],
 
51
  ["Text Summary ",
52
- "Named Entity Recognition"
 
53
  ])
54
 
55
  if __name__ == "__main__":
 
3
 
4
  get_completion = pipeline("summarization",model="sshleifer/distilbart-cnn-12-6")
5
  get_ner = pipeline("ner", model="dslim/bert-base-NER")
6
+ get_zero = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")
7
+
8
  def summarize_text(input):
9
  output = get_completion(input)
10
  return output[0]['summary_text']
 
28
  merged_output = merge_tokens(output)
29
  return {"text": input, "entities": output}
30
 
31
+ def zero_shot_pred(text,check_labels):
32
+ output = get_zero(text,check_labels)
33
+ return output
34
+
35
+ def label_score_dict(text,check_labels):
36
+ zero_shot_out = zero_shot_pred(text,check_labels)
37
+ out = {}
38
+ for i,j in zip(zero_shot_out['labels'],zero_shot_out['scores']):
39
+ out.update({i:j})
40
+ print(out)
41
+ return out
42
 
43
  interface_summarise = gr.Interface(fn=summarize_text,
44
  inputs=[gr.Textbox(label="Text to summarise", lines=5)],
 
57
  "My name is Bose and I am a physicist living in Delhi"
58
  ])
59
 
60
+ interface_zero_shot=gr.Interface(fn=label_score_dict,
61
+ inputs=[
62
+ gr.Textbox(label="Text to classify", lines=2),
63
+ gr.Textbox(label="Check for labels")
64
+ ],
65
+ outputs=gr.Label(num_top_classes=4),
66
+ title="Zero-Shot Preds using DeBERTa-v3-base-mnli",
67
+ description="Classify sentence on self defined target vars",
68
+ examples=[
69
+ ["Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.",
70
+ "mobile, website, billing, account access"],
71
+ # "My name is Bose and I am a physicist living in Delhi"
72
+ ])
73
+
74
  demo = gr.TabbedInterface([
75
  interface_summarise,
76
+ interface_ner,
77
+ interface_zero_shot],
78
  ["Text Summary ",
79
+ "Named Entity Recognition",
80
+ "Zero Shot Classifications"
81
  ])
82
 
83
  if __name__ == "__main__":