Pavankalyan committed on
Commit
17283b0
1 Parent(s): 4681ada

Upload 7 files

Files changed (7)
  1. Responses.csv +0 -0
  2. app.py +19 -0
  3. corpus.pt +3 -0
  4. data_process.py +44 -0
  5. main.py +22 -0
  6. requirements.txt +2 -0
  7. retrieval.py +69 -0
Responses.csv ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,19 @@
+ import gradio as gr
+ from data_process import *
+ from retrieval import *
+
+ df = pd.read_csv("Responses.csv")
+ text = list(df["text"].values)
+
+
+ def chitti(query):
+     re_table = search(query, text)
+     return re_table[0][0]  # text of the top-ranked answer
+
+ demo = gr.Interface(
+     fn=chitti,
+     inputs=["text"],
+     outputs=["text"],
+ )
+ demo.launch(share=True)  # share=True also exposes a temporary public URL
+
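Once the app is running, the endpoint can also be exercised programmatically. A minimal sketch using gradio_client, assuming the default local address and endpoint name; the question is made up:

    from gradio_client import Client

    client = Client("http://127.0.0.1:7860")  # assumed default local address
    # "/predict" is the default endpoint name for a single gr.Interface
    print(client.predict("What does the bot know about?", api_name="/predict"))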
corpus.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90d8781fef8d3a3b5a5130ce095c186c076a05ee25e3980cc3cf2577910302b2
+ size 5803755
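corpus.pt is tracked via Git LFS, so the diff shows only the pointer (hash and size) rather than the file itself: it is the cached passage-embedding tensor that get_corpus in retrieval.py writes with torch.save.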
data_process.py ADDED
@@ -0,0 +1,44 @@
+ import os
+ import re
+ import pandas as pd
+
+
+ def merge_text(text_list):
+     i = 0
+     j = 1
+
+     k = len(text_list)
+
+     while j < k:
+         if len(text_list[i].split()) <= 30:  # short chunk: fold it into the next one
+             text_list[j] = text_list[i] + " " + text_list[j]
+             text_list[i] = " "
+         i += 1
+         j += 1
+
+     return [accepted for accepted in text_list if accepted != " "]
+
+
+ def get_text(path):
+     doc_list = sorted(os.listdir(path))
+     text = []
+     for doc in doc_list:
+         sub_text = []
+         with open(os.path.join(path, doc), encoding='utf-8') as f:
+             for line in f.readlines():
+                 temp_text = re.sub("\\n", "", line)
+                 if temp_text != "":
+                     sub_text.append(temp_text)
+
+         sub_text = merge_text(sub_text)
+         text.extend(sub_text)
+     return text
+
+
+ def dataframe(path):
+     text = get_text(path)
+     df = {
+         "text": text
+     }
+     df = pd.DataFrame(df)
+     df.to_csv("Responses.csv")  # cached for reuse by app.py and main.py
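To make merge_text concrete, a worked example with made-up chunks: every chunk of 30 words or fewer is folded into its successor, so a run of short lines collapses into one passage.

    from data_process import merge_text

    # all three chunks are under 30 words, so each is merged forward:
    print(merge_text(["Opening hours:", "Mon-Fri 9-5.", "Sat 10-2."]))
    # -> ['Opening hours: Mon-Fri 9-5. Sat 10-2.']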
main.py ADDED
@@ -0,0 +1,22 @@
+ from data_process import *
+ from retrieval import *
+ import argparse
+
+
+ parser = argparse.ArgumentParser(description="Run the query for the bot")
+ parser.add_argument('--query', help="Question to the bot", type=str, required=True)
+ parser.add_argument('--data_path', help="Path for the stored dataset", type=str, required=True)
+
+ args = parser.parse_args()
+ path = args.data_path
+ query = args.query
+
+ if "Responses.csv" not in os.listdir(os.getcwd()):  # build the passage CSV on first run
+     dataframe(path)
+
+ df = pd.read_csv("Responses.csv")
+ text = list(df["text"].values)
+
+
+ print(search(query, text))
+
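A typical invocation, with a placeholder question and data directory:

    python main.py --query "What are the opening hours?" --data_path ./docs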
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ sentence-transformers
+ pandas
+ gradio
+ tabulate
retrieval.py ADDED
@@ -0,0 +1,63 @@
+ import os
+ import textwrap
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
+ import torch
+ from tabulate import tabulate
+ import time
+
+ model_bi_encoder = "msmarco-distilbert-base-tas-b"
+ model_cross_encoder = "cross-encoder/ms-marco-MiniLM-L-12-v2"
+
+ bi_encoder = SentenceTransformer(model_bi_encoder)
+ bi_encoder.max_seq_length = 512
+
+ cross_encoder = CrossEncoder(model_cross_encoder)
+
+ top_k = 20
+
+
+ def get_corpus(passages):
+     # Embed the passages once and cache the tensor on disk.
+     if "corpus.pt" not in os.listdir(os.getcwd()):
+         corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
+         torch.save(corpus_embeddings, "corpus.pt")
+     else:
+         corpus_embeddings = torch.load("corpus.pt")
+
+     return corpus_embeddings
+
+
+ def search(query, passages):
+     # Stage 1: fast bi-encoder retrieval of the top_k candidate passages.
+     corpus_embeddings = get_corpus(passages)
+     question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
+
+     be = time.process_time()
+     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
+     # print("Time taken by Bi-encoder: " + str(time.process_time() - be))
+
+     hits = hits[0]
+     cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
+
+     # Stage 2: re-rank the candidates with the slower, more accurate cross-encoder.
+     ce = time.process_time()
+     cross_scores = cross_encoder.predict(cross_inp)
+     # print("Time taken by Cross-encoder: " + str(time.process_time() - ce))
+
+     # Sort results by the cross-encoder scores
+     for idx in range(len(cross_scores)):
+         hits[idx]['cross-score'] = cross_scores[idx]
+
+     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
+     result_table = list()
+     for hit in hits[0:5]:
+         ans = "{}".format(passages[hit['corpus_id']].replace("\n", " "))
+         cs = "{}".format(hit['cross-score'])
+         sc = "{}".format(hit['score'])
+         wrapper = textwrap.TextWrapper(width=50)
+         ans = wrapper.fill(text=ans)
+         result_table.append([ans, cs, sc])
+
+     # print(tabulate(result_table, headers=["Answer", "Cross-encoder score", "Bi-encoder score"], tablefmt="fancy_grid"))
+     return result_table
+
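For reference, a direct call to the two-stage retriever outside the app, assuming Responses.csv already exists and using a made-up query; tabulate is used here only for display:

    import pandas as pd
    from tabulate import tabulate
    from retrieval import search

    passages = list(pd.read_csv("Responses.csv")["text"].values)
    # returns the top-5 re-ranked rows: [answer, cross-encoder score, bi-encoder score]
    results = search("hypothetical question about the corpus", passages)
    print(tabulate(results, headers=["Answer", "Cross-encoder score", "Bi-encoder score"], tablefmt="fancy_grid"))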