Spaces:
Runtime error
Runtime error
import csv | |
import spacy | |
import glob | |
import datetime | |
import pandas as pd | |
import gradio as gr | |
from transformers import pipeline | |
from huggingface_hub import hf_hub_download | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
# --- Model setup -------------------------------------------------------------
# Zero-shot model checkpoints, keyed by language code. Only English is
# configured today; add entries here to support further languages.
models = {
    'en': 'facebook/bart-large-mnli',
}

# Per-language hypothesis template handed to the zero-shot pipeline.
hypothesis_templates = {
    'en': 'This example is {}.',
}

# Ready-to-use zero-shot classification pipelines, one per language.
classifiers = {
    'en': pipeline(
        "zero-shot-classification",
        hypothesis_template=hypothesis_templates['en'],
        model=models['en'],
    ),
}

# Raw NLI model + tokenizer used by the custom premise/hypothesis check.
nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

# Human-readable label for each NLI logit index of bart-large-mnli
# (0 = contradiction, 1 = neutral, 2 = entailment).
labels = ["contradicts_hypothesis", "Neutral", "Entails_hypothesis"]
def prep_examples():
    """Build the example rows displayed in the Gradio UI.

    Returns:
        A list of ``[text, label, multi_label]`` triples. Each example reuses
        its own text as the candidate label and disables multi-label mode.
    """
    # NOTE: the spelling of these strings (e.g. "guranteed") is preserved
    # exactly as shipped — they are user-visible example inputs.
    sample_texts = (
        "EMI can be skipped",
        "minimum package guranteed",
        "100% placement gurantee",
    )
    # Each row: [sequence, candidate label (same text), multi-label flag].
    return [[text, text, False] for text in sample_texts]
def inference_hypothesis(premise, hypothesis, labels):
    """Run NLI over a single (premise, hypothesis) pair.

    Args:
        premise: Source sentence being tested.
        hypothesis: Hypothesis sentence to check against the premise.
        labels: Sequence of three label names indexed by NLI logit position
            (for bart-large-mnli: contradiction, neutral, entailment).

    Returns:
        Tuple of ``(premise, hypothesis, predicted_label, raw_logits)``.
    """
    # `truncation_strategy=` is a deprecated kwarg; `truncation='only_first'`
    # is the supported spelling — it truncates only the premise, keeping the
    # hypothesis intact when the pair exceeds the model's max length.
    x = tokenizer.encode(premise, hypothesis, return_tensors='pt',
                         truncation='only_first')
    logits = nli_model(x.to("cpu"))[0]
    # All three NLI logits: [contradiction, neutral, entailment].
    class_logits = logits[:, [0, 1, 2]]
    probs = class_logits.softmax(dim=1)
    return premise, hypothesis, labels[probs.argmax()], class_logits
def sequence_to_classify(sequence, hypothesis_df, multi_label):
    """Classify ``sequence`` against hypothesis rules from an uploaded CSV.

    Args:
        sequence: The text to analyse.
        hypothesis_df: Uploaded file object (exposes a ``.name`` path)
            pointing at a CSV with columns ``filtering_keyword``,
            ``hypothesis`` and ``expected_inference``.
        multi_label: Unused here; kept so the Gradio interface signature
            stays unchanged.

    Returns:
        Dict mapping the premise text to ``{"hypothesis", "label", "score"}``
        for each rule whose keyword occurs in ``sequence`` and whose NLI
        prediction matches the rule's expected inference.
    """
    # NOTE: the original body contained unreachable statements after the
    # return (they referenced an undefined `response`); removed, along with
    # the unused `classifier`/`label_clean` locals.
    rules = pd.read_csv(hypothesis_df.name)
    inference_output = {}
    # Only run the (comparatively expensive) NLI pass for rows whose trigger
    # keyword actually appears in the input text.
    for keyword, hypothesis, expected in zip(rules.filtering_keyword,
                                             rules.hypothesis,
                                             rules.expected_inference):
        if keyword.lower() in sequence.lower():
            premise, hyp, predicted, logits = inference_hypothesis(
                sequence, hypothesis, labels)
            if predicted == expected:
                # Keyed by premise (the full input text), so later matching
                # rules overwrite earlier ones — same behavior as before.
                inference_output[premise] = {"hypothesis": hyp,
                                             "label": predicted,
                                             "score": logits}
    return inference_output
def csv_to_df(file):
    """Load a CSV file (path or file-like object) into a pandas DataFrame."""
    dataframe = pd.read_csv(file)
    return dataframe
def csv_to_json(df):
    """Serialise a DataFrame to a JSON string: one object per row."""
    records_json = df.to_json(orient="records")
    return records_json
# --- Gradio UI ---------------------------------------------------------------
# Inputs: the transcript text, a CSV upload with the hypothesis rules, and a
# radio toggle selecting multi-label output.
text_input = gr.inputs.Textbox(
    lines=10,
    label="Please enter the text you would like to classify...",
    placeholder="Text here...")
file_input = gr.inputs.File()
multi_label_input = gr.inputs.Radio(
    choices=[False, True],
    label="Multi-label?")

iface = gr.Interface(
    title="Sales Call Analysis AI - NS AI LABS",
    description="Off-the-shelf NLP classifier with no domain or task-specific training.",
    fn=sequence_to_classify,
    inputs=[text_input, file_input, multi_label_input],
    outputs=gr.outputs.Textbox())
iface.launch()