jonathanjordan21 commited on
Commit
e711658
1 Parent(s): a3ee485

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -5
app.py CHANGED
@@ -1,7 +1,42 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- gr.Interface.load(
4
- "spaces/jonathanjordan21/lmd_chatbot_embedding",
5
- inputs=["text", gr.Slider(0.0, 1.0), "text", gr.Checkbox(label='Allow multiple true classes')],
6
- outputs="json"
7
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
+ import numpy as np
4
+ import pandas as pd
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ from InstructorEmbedding import INSTRUCTOR
7
 
8
+ # pipe = pipeline(model="facebook/bart-large-mnli")
9
+ pipe = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7")
10
+ model = INSTRUCTOR('hkunlp/instructor-large')
11
+
12
+ df = pd.read_csv('intent.csv', delimiter=';')
13
+
14
+ data = [
15
+ [
16
+ f'Represent the document for retrieval of {x["description"]} information : ',
17
+ x["message"]
18
+ ] for _,x in df.iterrows()
19
+ ]
20
+
21
+ corpus_embeddings = model.encode(data)
22
+
23
+
24
+ def predict(question, lower_threshold, tags, multi_label):
25
+ query = [['Represent the question for retrieving supporting documents: ',question]]
26
+ query_embeddings = model.encode(query)
27
+ similarities = cosine_similarity(query_embeddings,corpus_embeddings)
28
+ retrieved_doc_id = np.argmax(similarities)
29
+
30
+ if similarities[0][retrieved_doc_id] < float(lower_threshold):
31
+ ans = pipe(question, candidate_labels=[x.strip() for x in tags.split(",") if x.strip()!=""], multi_label=multi_label)
32
+ ans['query_similarity_score'] = similarities[0][retrieved_doc_id]
33
+ return ans
34
+ return {"chatbot_response" : data[retrieved_doc_id][-1], 'query_similarity_score' : similarities[0][retrieved_doc_id]}
35
+
36
+
37
+
38
+
39
+
40
+ gr.Interface(fn=predict,
41
+ inputs=["text", gr.Slider(0.0, 1.0), "text", gr.Checkbox(label='Allow multiple true classes')],
42
+ outputs="json").launch()