Classifier / app.py
Taranosaurus's picture
Small efficiency changes on the classification returns and extra info
96eea2a
raw
history blame contribute delete
No virus
2.93 kB
from transformers import pipeline
import gradio as gr
import torch
# Run on the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label choices offered in the UI, and the subset that is pre-selected.
labels = [
    "merge",
    "revert",
    "fix",
    "feature",
    "update",
    "refactor",
    "test",
    "security",
    "documentation",
    "style",
]
selected_labels = [
    "feature",
    "update",
    "refactor",
    "test",
    "security",
    "documentation",
    "style",
]

# Model checkpoints. "google/pegasus-large" was noted alongside the
# summarisation checkpoint as an alternative.
summary_checkpoint = "facebook/bart-large-cnn"
oracle_checkpoint = "facebook/bart-large-mnli"

# Both pipelines are built once at import time and shared by every request.
summary = pipeline(
    task="summarization",
    model=summary_checkpoint,
    device=device,
)
oracle = pipeline(
    task="zero-shot-classification",
    model=oracle_checkpoint,
    device=device,
)
def do_the_thing(input, hypothesis, labels):
    """Summarise the notes, then zero-shot classify both the notes and the summary.

    Args:
        input: Text to summarise and classify. (The name shadows the builtin
            but is kept so existing callers are unaffected.)
        hypothesis: Hypothesis template passed to the zero-shot classifier.
            Replaced by "This example is {}." when it is None, empty, or
            does not contain a "{}" placeholder.
        labels: Candidate labels to score both texts against.

    Returns:
        A three-item list: the summary text, a label -> score dict for the
        original input, and a label -> score dict for the summary.

    Raises:
        gr.Error: If no candidate labels were supplied.
    """
    # The zero-shot pipeline rejects an empty label list with a bare
    # ValueError; fail early with a message the UI can show, and before
    # spending time on summarisation.
    if not labels:
        raise gr.Error("Select at least one classification label.")

    if not hypothesis or '{}' not in hypothesis:
        hypothesis = "This example is {}."

    summarisation = summary(input, truncation=True)[0]['summary_text']

    # Classify the original text and its summary together in one batch;
    # results come back in the same order as the sequences.
    zsc_results = oracle(
        sequences=[input, summarisation],
        candidate_labels=labels,
        multi_label=False,
        batch_size=2,
        hypothesis_template=hypothesis,
    )
    input_result, summary_result = zsc_results[0], zsc_results[1]

    classifications_input = dict(zip(input_result['labels'], input_result['scores']))
    classifications_summary = dict(zip(summary_result['labels'], summary_result['scores']))

    return [summarisation, classifications_input, classifications_summary]
# Gradio front end: notes in, summary plus two label classifications out.
with gr.Blocks() as frontend:
    # Plain string: there are no placeholders, so no f-prefix is needed.
    gr.Markdown("## Git Commit Classifier\n\nThis tool is to take the notes from a commit, summarise and classify the original and the summary.\n\nTo get the git commit notes, clone the repo and then run `git --no-pager log --all --pretty='format:%B%n────%n'`")
    input_value = gr.TextArea(label="Notes to Summarise")
    btn_submit = gr.Button(value="Summarise and Classify")

    # Classification settings alongside the generated summary.
    with gr.Row():
        with gr.Column():
            input_labels = gr.Dropdown(
                label="Classification Labels",
                choices=labels,
                multiselect=True,
                value=selected_labels,
                interactive=True,
                allow_custom_value=True,
                info="Labels to classify the original text and summary. Select more or add your own.",
            )
            input_hypothesis = gr.Textbox(
                label="Hypothesis Template",
                info="This must include the {} format syntax. Blank and invalid inputs get defaulted to the placeholder text.",
                value="This git commit relates to {} changes.",
                placeholder="This example is {}.",
            )
        with gr.Column():
            output_summary_text = gr.TextArea(label="Summary of Notes")

    # Classification of the original text and of its summary.
    with gr.Row():
        with gr.Column():
            output_original_labels = gr.Label(label="Original Text Classification")
        with gr.Column():
            output_summary_labels = gr.Label(label="Summary Text Classification")

    # The input and output orders here mirror do_the_thing's parameters
    # and its returned list.
    btn_submit.click(
        fn=do_the_thing,
        inputs=[input_value, input_hypothesis, input_labels],
        outputs=[output_summary_text, output_original_labels, output_summary_labels],
    )

frontend.launch()