balamurugan commited on
Commit
8468f51
1 Parent(s): e37ea7c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ import pickle
3
+ import pandas as pd
4
+ import gradio as gr
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer, util
7
+ from transformers import pipeline
8
+
9
+ model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
10
+ max_sequence_length = 512
11
+ embeddings_filename = 'df10k_embeddings_msmarco-distilbert-base-v4.npy'
12
+
13
+ nltk.download('punkt')
14
+ filename = 'gs_10k_2021.txt'
15
+
16
+ import os
17
+ textfile = open(filename,'r')
18
+ text_corpus=textfile.read()
19
+
20
+ corpus = []
21
+ sentence_count = []
22
+ sentences = nltk.tokenize.sent_tokenize(text_corpus, language='english')
23
+ sentence_count.append(len(sentences))
24
+ for _,s in enumerate(sentences):
25
+ corpus.append(s)
26
+ print(f'Number of sentences: {len(corpus)}')
27
+
28
+ # Load pre-embedded corpus
29
+ corpus_embeddings = np.load("df10k_embeddings_msmarco-distilbert-base-v4.npy")
30
+ print(f'Number of embeddings: {corpus_embeddings.shape[0]}')
31
+
32
+ # Load embedding model
33
+ model = SentenceTransformer(model_name)
34
+ model.max_seq_length = max_sequence_length
35
+
36
+ def find_sentences(query, hits):
37
+ query_embedding = model.encode(query)
38
+ hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=hits)
39
+ hits = hits[0]
40
+ print(hits)
41
+ print(hits)
42
+
43
+ output = pd.DataFrame(columns=['Text', 'Score'])
44
+ for hit in hits:
45
+ corpus_id = hit['corpus_id']
46
+ # Find source document based on sentence index
47
+ count = 0
48
+ new_row = {
49
+ 'Text': corpus[corpus_id],
50
+ 'Score': '{:.2f}'.format(hit['score'])
51
+ }
52
+ output = output.append(new_row, ignore_index=True)
53
+ print(output)
54
+ return output
55
+
56
+
57
+ def process( query):
58
+ text = query
59
+ return text, find_sentences(text, 2)
60
+
61
+ # if __name__ == "__main__":
62
+ # print(process("Great Opportunity in business"))
63
+ # print(process("LIBOR replacement"))
64
+ # print(process("Marquee"))
65
+
66
+
67
+
68
+ # Gradio inputs
69
+ text_query = gr.inputs.Textbox(lines=1, label='Text input', default='Great Opportunity')
70
+
71
+
72
+ # Gradio outputs
73
+ speech_query = gr.outputs.Textbox(type='auto', label='Query string')
74
+ results = gr.outputs.Dataframe(
75
+ headers=[ 'Text', 'Score'],
76
+ label='Query results')
77
+
78
+ iface = gr.Interface(
79
+ theme='huggingface',
80
+ description='Great Opportunity in business',
81
+ fn=process,
82
+ inputs=[text_query],
83
+ outputs=[speech_query, results],
84
+ examples=[
85
+ ['Great Opportunity in business'],
86
+ ['LIBOR replacement'],
87
+ ['Marquee'],
88
+ ],
89
+ allow_flagging=False
90
+ )
91
+ iface.launch()
92
+