Anustup committed on
Commit
0919045
1 Parent(s): 646d665

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import spacy
3
+ import glob
4
+ import datetime
5
+ import pandas as pd
6
+ import gradio as gr
7
+ from transformers import pipeline
8
+ from huggingface_hub import hf_hub_download
9
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
10
+
11
# Model checkpoint used for zero-shot classification, keyed by language code.
models = {'en': 'facebook/bart-large-mnli'}

# Hypothesis template handed to the zero-shot pipeline, keyed by language code.
hypothesis_templates = {'en': 'This example is {}.'}

# One zero-shot classification pipeline per supported language.
classifiers = {
    'en': pipeline(
        "zero-shot-classification",
        model=models['en'],
        hypothesis_template=hypothesis_templates['en'],
    ),
}

# Sequence-classification model and tokenizer used directly by
# inference_hypothesis (loaded independently of the pipeline above).
nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

# Label names; inference_hypothesis picks one by the position of the
# highest-scoring logit.
labels = ["contradicts_hypothesis", "Neutral", "Entails_hypothesis"]
26
+
27
def prep_examples():
    """Build the sample rows for the demo.

    Returns:
        A list of ``[text, labels, multi_label]`` rows. The labels entry
        repeats the text itself and the multi-label flag is always False.
    """
    sample_texts = (
        "EMI can be skipped",
        "minimum package guranteed",
        "100% placement gurantee",
    )
    return [[text, text, False] for text in sample_texts]
47
+
48
+
49
def inference_hypothesis(premise, hypothesis, labels):
    """Run the NLI model on a single premise/hypothesis pair.

    Args:
        premise: Text to evaluate.
        hypothesis: Statement tested against the premise.
        labels: Label names, looked up by the position of the winning logit.

    Returns:
        A tuple ``(premise, hypothesis, label, logits)`` where ``label`` is
        the entry of ``labels`` with the highest softmax probability and
        ``logits`` is the tensor of raw scores the label was chosen from.
    """
    # Local import: this module does not import torch at file level.
    import torch

    # `truncation` is the supported replacement for the deprecated
    # `truncation_strategy` keyword; 'only_first' shortens only the premise
    # when the encoded pair is too long.
    input_ids = tokenizer.encode(
        premise,
        hypothesis,
        return_tensors='pt',
        truncation='only_first',
    )

    # Pure inference: no need to record an autograd graph.
    with torch.no_grad():
        logits = nli_model(input_ids.to("cpu"))[0]

    entail_contradiction_logits = logits[:, [0, 1, 2]]
    probs = entail_contradiction_logits.softmax(dim=1)

    # Convert the winning position to a plain int before indexing the list.
    best_index = int(probs.argmax())
    return premise, hypothesis, labels[best_index], entail_contradiction_logits
55
+
56
def sequence_to_classify(sequence, hypothesis_df, multi_label):
    """Test a text against every hypothesis whose keyword occurs in it.

    Args:
        sequence: Text to analyse.
        hypothesis_df: The hypothesis CSV, given either as an object exposing
            a ``.name`` path attribute or as a plain path string. The CSV
            must provide the columns ``filtering_keyword``, ``hypothesis``
            and ``expected_inference``.
        multi_label: Accepted so the signature matches the interface inputs;
            not used by the computation.

    Returns:
        A dict keyed by ``sequence`` whose value holds the ``"hypothesis"``,
        ``"label"`` and ``"score"`` of a hypothesis whose predicted label
        equals its expected inference. Empty when nothing matches.
    """
    if isinstance(hypothesis_df, str):
        csv_path = hypothesis_df
    else:
        csv_path = hypothesis_df.name
    table = pd.read_csv(csv_path)

    # Convert each column once instead of on every loop iteration.
    rows = zip(
        table.filtering_keyword.tolist(),
        table.hypothesis.tolist(),
        table.expected_inference.tolist(),
    )

    lowered_sequence = sequence.lower()
    inference_output = {}
    for keyword, hypothesis, expected in rows:
        # Empty CSV cells are read as NaN (a float); they cannot be keywords.
        if not isinstance(keyword, str):
            continue
        if keyword.lower() not in lowered_sequence:
            continue

        premise, tested, label, score = inference_hypothesis(sequence, hypothesis, labels)
        if label == expected:
            # NOTE(review): the key is the premise, which is the same for
            # every row, so a later match replaces an earlier one. Kept as-is
            # to preserve the return shape; confirm whether all matches
            # should be reported.
            inference_output[premise] = {
                "hypothesis": tested,
                "label": label,
                "score": score,
            }
    return inference_output
81
def csv_to_df(file):
    """Parse the CSV data in *file* into a pandas DataFrame."""
    frame = pd.read_csv(file)
    return frame
83
+
84
def csv_to_json(df):
    """Serialise *df* to a JSON string containing one object per row."""
    records_json = df.to_json(orient="records")
    return records_json
86
+
87
# Input widgets, in the order sequence_to_classify receives its arguments.
text_input = gr.inputs.Textbox(
    lines=10,
    label="Please enter the text you would like to classify...",
    placeholder="Text here...",
)
hypothesis_file_input = gr.inputs.File()
multi_label_input = gr.inputs.Radio(
    choices=[False, True],
    label="Multi-label?",
)

# Web UI wiring the widgets to the classifier; the result is shown as text.
iface = gr.Interface(
    fn=sequence_to_classify,
    inputs=[text_input, hypothesis_file_input, multi_label_input],
    outputs=gr.outputs.Textbox(),
    title="Sales Call Analysis AI - NS AI LABS",
    description="Off-the-shelf NLP classifier with no domain or task-specific training.",
)

iface.launch()