Taranosaurus committed on
Commit
a5709b4
1 Parent(s): d38122f

Added faster model switching and truncation to prevent errors on long inputs

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import pipeline
2
  import gradio as gr
3
  import torch
4
 
@@ -7,13 +7,18 @@ if torch.cuda.is_available():
7
  else:
8
  device = torch.device("cpu")
9
 
10
- summary = pipeline(task="summarization", model="facebook/bart-large-cnn", device=device)
11
- oracle = pipeline(task="zero-shot-classification", model="facebook/bart-large-mnli", device=device)
 
 
 
 
12
  labels = ["merge","revert","fix","feature","update","refactor","test","security","documentation","style"]
 
13
 
14
  def do_the_thing(input, labels):
15
  #print(labels)
16
- summarisation = summary(input)[0]['summary_text']
17
  zsc_results = oracle(sequences=[input, summarisation], candidate_labels=labels, multi_label=False, batch_size=2)
18
  classifications_input = {}
19
  for i in range(len(labels)):
@@ -32,7 +37,7 @@ with gr.Blocks() as frontend:
32
  btn_submit = gr.Button(value="Summarise and Classify")
33
  with gr.Row():
34
  with gr.Column():
35
- input_labels = gr.Dropdown(label="Classification Labels", choices=labels, multiselect=True, value=labels, interactive=True, allow_custom_value=True, info="Labels to classify the original text and summary")
36
  with gr.Column():
37
  output_summary_text = gr.TextArea(label="Summary of Notes")
38
  with gr.Row():
 
1
+ from transformers import pipeline, AutoTokenizer
2
  import gradio as gr
3
  import torch
4
 
 
7
  else:
8
  device = torch.device("cpu")
9
 
10
+ summary_checkpoint = "facebook/bart-large-cnn" #"google/pegasus-large"
11
+ oracle_checkpoint = "facebook/bart-large-mnli"
12
+ tokenizer = AutoTokenizer.from_pretrained(summary_checkpoint, device=device)
13
+ summary = pipeline(task="summarization", model=summary_checkpoint, tokenizer=tokenizer, device=device)
14
+
15
+ oracle = pipeline(task="zero-shot-classification", model=oracle_checkpoint, device=device)
16
  labels = ["merge","revert","fix","feature","update","refactor","test","security","documentation","style"]
17
+ selected_labels = ["feature","update","refactor","test","security","documentation","style"]
18
 
19
  def do_the_thing(input, labels):
20
  #print(labels)
21
+ summarisation = summary(input, truncation=True)[0]['summary_text']
22
  zsc_results = oracle(sequences=[input, summarisation], candidate_labels=labels, multi_label=False, batch_size=2)
23
  classifications_input = {}
24
  for i in range(len(labels)):
 
37
  btn_submit = gr.Button(value="Summarise and Classify")
38
  with gr.Row():
39
  with gr.Column():
40
+ input_labels = gr.Dropdown(label="Classification Labels", choices=labels, multiselect=True, value=selected_labels, interactive=True, allow_custom_value=True, info="Labels to classify the original text and summary")
41
  with gr.Column():
42
  output_summary_text = gr.TextArea(label="Summary of Notes")
43
  with gr.Row():