# Chitti / load_data.py
# Author: Pavankalyan
# Last commit: "Update load_data.py" (8e74a0f)
import pandas as pd
import os
import json
import re
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
import time
import textwrap
# Retrieval models: a bi-encoder embeds queries and passages for fast
# candidate retrieval; a cross-encoder re-scores (query, passage) pairs.
model_bi_encoder = "msmarco-distilbert-base-tas-b"
model_cross_encoder = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Both models are instantiated once, at import time, and shared module-wide.
cross_encoder = CrossEncoder(model_cross_encoder)
bi_encoder = SentenceTransformer(model_bi_encoder)
# Inputs longer than this many tokens are truncated by the bi-encoder.
bi_encoder.max_seq_length = 512
def collect_data(data_lis, meta_count):
    """Return the file names and links of the entries not yet ingested.

    Args:
        data_lis: DataFrame with (at least) ``file_name`` and ``link`` columns.
        meta_count: Number of leading rows that have already been processed.

    Returns:
        ``(new_files, new_links)``: Series holding rows ``meta_count`` onward,
        re-indexed from 0 so that ``new_files[i]`` is the i-th new entry.
    """
    # .iloc makes the slice explicitly positional; reset_index drops the
    # original labels (meta_count, meta_count + 1, ...) so that callers can
    # index the result with [0], [1], ... without a KeyError.
    new_files = data_lis['file_name'].iloc[meta_count:].reset_index(drop=True)
    new_links = data_lis['link'].iloc[meta_count:].reset_index(drop=True)
    return new_files, new_links
def merge_text(text_list, max_words=5):
    """Fold short lines forward into the line that follows them.

    Walking from the first line to the second-to-last, any line with at most
    ``max_words`` whitespace-separated words is prepended (joined by a space)
    to the next line. Merges cascade, so a run of short lines accumulates
    into the first longer line after it. The last line is never merged
    forward, even if it is short.

    Args:
        text_list: Sequence of strings. It is not modified.
        max_words: Word-count threshold at or below which a line is merged
            into its successor (default 5).

    Returns:
        A new list of the merged lines. Entries that are exactly a single
        space are dropped, because that value marks merged-away slots.
    """
    # Work on a copy so the caller's list is left untouched.
    lines = list(text_list)
    for i in range(len(lines) - 1):
        if len(lines[i].split()) <= max_words:
            lines[i + 1] = lines[i] + " " + lines[i + 1]
            lines[i] = " "  # placeholder filtered out below
    return [line for line in lines if line != " "]
def make_data(new_files, new_links, local_path):
    """Read each new document and split it into merged text passages.

    Args:
        new_files: Sequence of file names, relative to ``local_path``.
        new_links: Sequence of links, parallel to ``new_files``.
        local_path: Directory containing the files.

    Returns:
        ``(text, links)``: two flat, parallel lists. Each passage produced by
        ``merge_text`` is paired with the link of the document it came from.
    """
    text = []
    links = []
    # zip iterates by position, so this works for plain lists and for pandas
    # Series whose index does not start at 0 (e.g. a slice of a larger frame).
    for file_name, link in zip(new_files, new_links):
        with open(os.path.join(local_path, file_name), encoding='utf-8') as f:
            # Remove newline characters; skip lines that are then empty.
            lines = [line.replace("\n", "") for line in f]
        passages = merge_text([line for line in lines if line != ""])
        text.extend(passages)
        links.extend([link] * len(passages))
    return text, links
def get_final_data():
    """Bring Responses.csv and corpus.pt up to date and return them.

    Compares the number of rows in ``data_url.csv`` with the count stored in
    ``meta_data.json``. When they differ, the files listed from the stored
    count onward are read from ``Data_final``, split into passages, appended
    to ``Responses.csv`` and embedded into ``corpus.pt``. The stored count is
    updated only after both files have been written.

    Returns:
        ``(corpus_embeddings, data)``. ``data`` is a DataFrame with ``text``
        and ``links`` columns and a 0..N-1 index. New passages are appended
        to ``data`` and to ``corpus_embeddings`` in the same order, so row
        ``i`` of the embeddings is meant to correspond to ``data['text'][i]``
        (assumes pre-existing files were already in sync — TODO confirm).
    """
    # Define all the paths
    meta_path = "meta_data.json"
    data_lis_path = "data_url.csv"
    local_path = "Data_final"
    data_path = "Responses.csv"
    corpus_path = "corpus.pt"

    # Load the list of data files
    data_lis = pd.read_csv(data_lis_path)

    # Create an empty Responses.csv on first run
    if not os.path.exists(data_path):
        pd.DataFrame({"text": [], "links": []}).to_csv(data_path)
    # Responses.csv is written with its index as the first column. Read that
    # column back as the index (instead of as an extra "Unnamed: 0" column that
    # grows on every rewrite) and renumber rows 0..N-1.
    data = pd.read_csv(data_path, index_col=0).reset_index(drop=True)

    act_count = len(data_lis['file_name'])
    with open(meta_path, "r") as json_file:
        meta_data = json.load(json_file)
    meta_count = meta_data["data"]["count"]

    # Rows added during this call (stays empty when nothing new is listed)
    new_rows = pd.DataFrame({"text": [], "links": []})
    if meta_count != act_count:
        new_files, new_links = collect_data(data_lis, meta_count)
        text, links = make_data(new_files, new_links, local_path)
        new_rows = pd.DataFrame({"text": text, "links": links})
        # ignore_index keeps row labels unique; without it old and new rows
        # share labels and a label lookup returns several rows at once.
        data = pd.concat([data, new_rows], ignore_index=True)
        data.to_csv(data_path)

    if not os.path.exists(corpus_path):
        # No cached embeddings yet: embed the whole corpus.
        corpus_embeddings = bi_encoder.encode(
            data["text"].tolist(), convert_to_tensor=True, show_progress_bar=True
        )
        torch.save(corpus_embeddings, corpus_path)
    else:
        corpus_embeddings = torch.load(corpus_path)
        if len(new_rows) > 0:
            new_embeddings = bi_encoder.encode(
                new_rows["text"].tolist(), convert_to_tensor=True, show_progress_bar=True
            )
            corpus_embeddings = torch.cat((corpus_embeddings, new_embeddings), 0)
            torch.save(corpus_embeddings, corpus_path)

    if meta_count != act_count:
        # Record the new count last, so a failure earlier in this function
        # does not mark the new files as already ingested.
        meta_data["data"]["count"] = act_count
        with open(meta_path, "w") as json_file:
            json.dump(meta_data, json_file)

    return corpus_embeddings, data
def search(query, top_k=20, n_results=5):
    """Retrieve and re-rank corpus passages for a query.

    Calls ``get_final_data()`` (which re-reads the data files) on every
    query. It retrieves ``top_k`` candidates with the bi-encoder, re-scores
    them with the cross-encoder and returns the best ``n_results``.

    Args:
        query: The question text.
        top_k: Number of bi-encoder candidates to re-rank (default 20).
        n_results: Maximum number of rows to return (default 5).

    Returns:
        A list of ``[answer, cross_score, bi_score, link]`` rows, all strings,
        ordered by descending cross-encoder score. In the answer, newlines are
        replaced by spaces and the text is wrapped to 50 columns. An empty
        list is returned when retrieval finds no candidates.
    """
    corpus_embeddings, data = get_final_data()
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # semantic_search returns one hit list per query; a single query is passed.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        return []

    # corpus_id is a row position in corpus_embeddings, so look rows up by
    # position (.iloc) rather than by index label.
    texts = data['text']
    link_col = data['links']
    cross_inp = [[query, texts.iloc[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Sort results by the cross-encoder scores
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score
    hits = sorted(hits, key=lambda h: h['cross-score'], reverse=True)

    wrapper = textwrap.TextWrapper(width=50)
    result_table = []
    for hit in hits[:n_results]:
        corpus_id = hit['corpus_id']
        ans = wrapper.fill(text="{}".format(texts.iloc[corpus_id].replace("\n", " ")))
        # "{}".format is kept on purpose: for numpy scalars it renders
        # differently from str().
        cs = "{}".format(hit['cross-score'])
        sc = "{}".format(hit['score'])
        corr_link = "{}".format(link_col.iloc[corpus_id])
        result_table.append([ans, cs, sc, corr_link])
    return result_table