balamurugan's picture
Update app.py
9eb7aab
raw
history blame contribute delete
No virus
2.46 kB
import nltk
import pickle
import pandas as pd
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
# Configuration: embedding model identifier, the sequence-length cap applied
# to that model, and the input files (corpus text and saved embeddings array).
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
max_sequence_length = 512
embeddings_filename = "df10k_embeddings_msmarco-distilbert-base-v4.npy"
filename = "gs_10k_2021.txt"

# Fetch the Punkt data that nltk's sentence tokenizer relies on.
nltk.download("punkt")
import os
# Read the whole corpus file; the context manager closes the handle as soon
# as the text is in memory (the original left it open for the process lifetime).
with open(filename, 'r') as textfile:
    text_corpus = textfile.read()

# Split the text into sentences. `corpus` is the list that search results
# index into, so its order must match the rows of the embeddings array.
sentences = nltk.tokenize.sent_tokenize(text_corpus, language='english')
sentence_count = [len(sentences)]
corpus = list(sentences)
print(f'Number of sentences: {len(corpus)}')

# Load pre-embedded corpus (use the configured constant rather than
# repeating the file name as a second literal).
corpus_embeddings = np.load(embeddings_filename)
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length
def find_sentences(query, hits):
    """Return the corpus sentences most similar to *query*.

    Args:
        query: Search text; encoded with the module-level ``model``.
        hits: Maximum number of matches to return (forwarded as ``top_k``).

    Returns:
        A DataFrame with the columns ``Text`` (the matched sentence taken
        from ``corpus``) and ``Score`` (the similarity score formatted as a
        string with two decimals), one row per match, in the order reported
        by ``util.semantic_search``.
    """
    query_embedding = model.encode(query)
    # semantic_search yields one result list per query; a single query is
    # passed here, so only the first list is relevant.
    matches = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=hits
    )[0]
    print(matches)

    # Collect all rows and build the frame once: DataFrame.append was removed
    # in pandas 2.0 and re-copied the whole frame on every call.
    rows = [
        {
            'Text': corpus[match['corpus_id']],
            'Score': '{:.2f}'.format(match['score']),
        }
        for match in matches
    ]
    output = pd.DataFrame(rows, columns=['Text', 'Score'])
    print(output)
    return output
def process(query):
    """Return the query unchanged together with its two best corpus matches."""
    matches = find_sentences(query, 2)
    return query, matches
# if __name__ == "__main__":
# print(process("Great Opportunity in business"))
# print(process("LIBOR replacement"))
# print(process("Marquee"))
# Gradio inputs: a single-line text box pre-filled with a sample query.
text_query = gr.inputs.Textbox(
    default='Great Opportunity',
    label='Text input',
    lines=1,
)

# Gradio outputs: the echoed query string and a table of matching sentences.
speech_query = gr.outputs.Textbox(label='Query string', type='auto')
results = gr.outputs.Dataframe(
    label='Query results',
    headers=['Text', 'Score'],
)

# Connect process() to the widgets above, then start the app.
iface = gr.Interface(
    fn=process,
    inputs=[text_query],
    outputs=[speech_query, results],
    examples=[
        ['Great Opportunity in business'],
        ['LIBOR replacement'],
        ['Structured products'],
    ],
    description='',
    theme='huggingface',
    allow_flagging=False,
)
iface.launch()