tosanoob committed on
Commit 52a8549
1 Parent(s): 38f06a6

Update chat/arxiv_bot/arxiv_bot_utils.py

Files changed (1):
  1. chat/arxiv_bot/arxiv_bot_utils.py +299 -296
chat/arxiv_bot/arxiv_bot_utils.py CHANGED
@@ -1,297 +1,300 @@
import os
# added in this commit: point the Hugging Face cache at models/; it is set
# before transformers is imported so that the cache location takes effect
os.environ['HF_HOME'] = 'models/'

import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from transformers import AutoModel
import json
from numpy.linalg import norm
import sqlite3
import urllib.request
from django.conf import settings
import Levenshtein

# this module acts as a singleton: one shared embedding model plus
# module-level database interfaces

class JinaAIEmbeddingFunction(EmbeddingFunction):
    """Wrap the Jina embedding model in Chroma's EmbeddingFunction interface."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, input: Documents) -> Embeddings:
        embeddings = self.model.encode(input)
        return embeddings.tolist()

# shared instance of the embedding model
embedding_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
                                            trust_remote_code=True,
                                            cache_dir='models')

# shared instance of JinaAIEmbeddingFunction
ef = JinaAIEmbeddingFunction(embedding_model)

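# --- editor's sketch (not part of the commit); _demo_embedding_function is a
# hypothetical helper. Chroma calls the embedding function with a list of
# document strings and expects one vector (a list of floats) per document.
def _demo_embedding_function():
    vectors = ef(["attention is all you need"])
    assert len(vectors) == 1
    assert isinstance(vectors[0][0], float)
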
# list of topics: map each topic name to the embedding of its description
# (despite the .txt extension, the file holds JSON)
with open("topic_descriptions.txt") as f:
    topic_descriptions = json.load(f)
topics = list(topic_descriptions.keys())
embeddings = [embedding_model.encode(topic_descriptions[key]) for key in topics]

def cos_sim(a, b):
    return (a @ b.T) / (norm(a) * norm(b))

def lev_sim(a, b):
    # Levenshtein edit distance: lower means more similar
    return Levenshtein.distance(a, b)

def choose_topic(summary):
    """Return the topic whose description embedding is most similar to the summary."""
    embed = embedding_model.encode(summary)
    topic = ""
    max_sim = 0.
    for i, key in enumerate(topics):
        sim = cos_sim(embed, embeddings[i])
        if sim > max_sim:
            topic = key
            max_sim = sim
    return topic

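# --- editor's sketch (not part of the commit); _demo_choose_topic is a
# hypothetical helper. choose_topic is a nearest-neighbour lookup: whichever
# description embeds closest (by cosine) to the summary wins.
def _demo_choose_topic():
    summary = "We study transformer language models for machine translation."
    print(choose_topic(summary))   # prints one key of topic_descriptions
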
def authors_list_to_str(authors):
    """input a list of authors, return a single comma-separated string"""
    return ", ".join(authors)

def authors_str_to_list(string):
    """input a string of authors, return a list of authors"""
    authors = []
    # split on " and " surrounded by spaces so that names containing the
    # letters "and" (e.g. "Alexander") are not cut apart
    for author in string.split(" and "):
        if author != "et al.":
            authors.append(author.strip())
    return authors

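# --- editor's sketch (not part of the commit); _demo_author_helpers is a
# hypothetical helper that round-trips the two author utilities.
def _demo_author_helpers():
    assert authors_list_to_str(["A. Vaswani", "N. Shazeer"]) == "A. Vaswani, N. Shazeer"
    assert authors_str_to_list("A. Vaswani and N. Shazeer and et al.") == ["A. Vaswani", "N. Shazeer"]
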
def chunk_texts(text, max_char=400):
    """
    Chunk a long text into several chunks, each about 300-400 characters long,
    while making sure no word is cut in half.
    Args:
        text: The long text to be chunked.
        max_char: The maximum number of characters per chunk (default: 400).
    Returns:
        A list of chunks.
    """
    chunks = []
    current_chunk = ""
    for word in text.split():
        if len(current_chunk) + len(word) + 1 >= max_char:
            # flush the full chunk and start the next one with this word
            chunks.append(current_chunk.strip())
            current_chunk = word
        else:
            current_chunk += " " + word
    chunks.append(current_chunk.strip())
    return chunks

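# --- editor's sketch (not part of the commit); _demo_chunk_texts is a
# hypothetical helper. Chunks never split a word, and a chunk is flushed once
# the next word would push it past max_char.
def _demo_chunk_texts():
    assert chunk_texts("alpha beta gamma delta", max_char=12) == ["alpha beta", "gamma delta"]
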
def trimming(txt):
    """Keep only the outermost {...} block of txt and flatten newlines."""
    start = txt.find("{")
    end = txt.rfind("}")
    return txt[start:end+1].replace("\n", " ")

# crawl data

def extract_tag(txt, tagname):
    # return the text between <tagname> and </tagname>
    return txt[txt.find("<"+tagname+">")+len(tagname)+2 : txt.find("</"+tagname+">")]

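# --- editor's sketch (not part of the commit); _demo_extract_tag is a
# hypothetical helper.
def _demo_extract_tag():
    assert extract_tag("<title>Attention Is All You Need</title>", "title") == "Attention Is All You Need"
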
def get_record(extract):
    """Parse one <entry> block of the arXiv Atom feed into
    [id, updated, published, title, authors, link, summary]."""
    id = extract_tag(extract, "id")
    updated = extract_tag(extract, "updated")
    published = extract_tag(extract, "published")
    title = extract_tag(extract, "title").replace("\n ", "").strip()
    summary = extract_tag(extract, "summary").replace("\n", "").strip()
    authors = []
    while extract.find("<author>") != -1:
        author = extract_tag(extract, "name")
        extract = extract[extract.find("</author>")+9:]
        authors.append(author)
    pattern = '<link title="pdf" href="'
    link_start = extract.find(pattern)
    link = extract[link_start+len(pattern) : extract.find("rel=", link_start)-2]
    return [id, updated, published, title, authors, link, summary]

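# --- editor's sketch (not part of the commit): shape of a parsed entry ---
# get_record(entry_xml) returns
#   [id, updated, published, title, authors, link, summary]
# e.g. ['http://arxiv.org/abs/<id>', '<timestamp>', '<timestamp>',
#       '<title>', ['<author>', ...], '<pdf url>', '<abstract>']
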
def crawl_exact_paper(title, author, max_results=3):
    """Search arXiv for a specific title/author pair; return a list of records
    (with a chosen topic prepended) or an error string."""
    authors = authors_list_to_str(author)
    records = []
    url = 'http://export.arxiv.org/api/query?search_query=ti:{title}+AND+au:{author}&max_results={max_results}'.format(
        title=title, author=authors, max_results=max_results)
    url = url.replace(" ", "%20")
    try:
        arxiv_page = urllib.request.urlopen(url, timeout=100).read()
        xml = str(arxiv_page, encoding="utf-8")
        while xml.find("<entry>") != -1:
            extract = xml[xml.find("<entry>")+7 : xml.find("</entry>")]
            xml = xml[xml.find("</entry>")+8:]
            extract = get_record(extract)
            topic = choose_topic(extract[6])   # classify by summary
            records.append([topic, *extract])
        return records
    except Exception as e:
        return "Error: " + str(e)

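# --- editor's sketch (not part of the commit); _demo_crawl_exact_paper is a
# hypothetical helper and hits the live arXiv export API, so it needs network
# access.
def _demo_crawl_exact_paper():
    records = crawl_exact_paper("Attention Is All You Need", ["Ashish Vaswani"])
    if isinstance(records, str):       # "Error: ..." on failure
        print(records)
    else:
        for topic, paper_id, *_, summary in records:
            print(topic, paper_id, summary[:80])
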
def crawl_arxiv(keyword_list, max_results=100):
    """Search arXiv for any of the given keywords; return a list of records
    (with a chosen topic prepended) or an error string."""
    baseurl = 'http://export.arxiv.org/api/query?search_query='
    records = []
    for i, keyword in enumerate(keyword_list):
        if i == 0:
            url = baseurl + 'all:' + keyword
        else:
            url = url + '+OR+' + 'all:' + keyword
    url = url + '&max_results=' + str(max_results)
    url = url.replace(' ', '%20')
    try:
        arxiv_page = urllib.request.urlopen(url, timeout=100).read()
        xml = str(arxiv_page, encoding="utf-8")
        while xml.find("<entry>") != -1:
            extract = xml[xml.find("<entry>")+7 : xml.find("</entry>")]
            xml = xml[xml.find("</entry>")+8:]
            extract = get_record(extract)
            topic = choose_topic(extract[6])
            records.append([topic, *extract])
        return records
    except Exception as e:
        return "Error: " + str(e)

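# --- editor's sketch (not part of the commit) ---
# crawl_arxiv OR-joins "all:" terms into a single query URL, e.g. for
# ["transformer", "attention"]:
#   http://export.arxiv.org/api/query?search_query=all:transformer+OR+all:attention&max_results=100
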
# This class acts as a module
class ArxivChroma:
    """
    Interface to arxivdb; supports only queries and additions,
    not editing or deletion.
    """
    client = None
    model = None
    collection = None

    @staticmethod
    def connect(table="arxiv_records", name="arxivdb/"):
        ArxivChroma.client = chromadb.PersistentClient(name)
        ArxivChroma.model = embedding_model
        ArxivChroma.collection = ArxivChroma.client.get_or_create_collection(
            table,
            embedding_function=JinaAIEmbeddingFunction(model=ArxivChroma.model))

    @staticmethod
    def query_relevant(keywords, query_texts, n_results=3):
        """
        Perform a query using a list of keywords (str),
        or using a relevant string.
        """
        contains = [{"$contains": keyword.lower()} for keyword in keywords]
        # Chroma's "$or" needs at least two clauses; fall back to the bare
        # clause when only one keyword is given
        where_doc = {"$or": contains} if len(contains) > 1 else contains[0]
        return ArxivChroma.collection.query(
            query_texts=query_texts,
            where_document=where_doc,
            n_results=n_results,
        )

    @staticmethod
    def query_exact(id):
        # chunks of a paper are stored under ids "<paper_id>_0" .. "<paper_id>_9"
        ids = ["{}_{}".format(id, j) for j in range(0, 10)]
        return ArxivChroma.collection.get(ids=ids)

    @staticmethod
    def add(crawl_records):
        """
        Add crawl_records (list) obtained from the arxiv crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, authors, link, summary]
        Return the final length of the database table.
        """
        for record in crawl_records:
            embed_text = """
            Topic: {},
            Title: {},
            Summary: {}
            """.format(record[0], record[4], record[7])
            chunks = chunk_texts(embed_text)
            # record[1][21:] strips the leading "http://arxiv.org/abs/"
            ids = [record[1][21:] + "_" + str(j) for j in range(len(chunks))]
            paper_ids = [{"paper_id": record[1][21:]} for _ in range(len(chunks))]
            ArxivChroma.collection.add(
                documents=chunks,
                metadatas=paper_ids,
                ids=ids,
            )
        return ArxivChroma.collection.count()

    @staticmethod
    def close_connection():
        pass

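# --- editor's sketch (not part of the commit); _demo_arxiv_chroma is a
# hypothetical helper and assumes arxivdb/ is writable and the network is up.
def _demo_arxiv_chroma():
    ArxivChroma.connect()
    records = crawl_arxiv(["transformer"], max_results=5)
    if not isinstance(records, str):
        print(ArxivChroma.add(records))    # chunk count after insertion
        result = ArxivChroma.query_relevant(["transformer"],
                                            ["sequence to sequence models"])
        print(result["ids"], result["documents"])
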
# This class acts as a module
class ArxivSQL:
    table = "arxivsql"
    con = None
    cur = None

    @staticmethod
    def connect(name="db.sqlite3"):
        ArxivSQL.con = sqlite3.connect(name, check_same_thread=False)
        ArxivSQL.cur = ArxivSQL.con.cursor()

    @staticmethod
    def query(title="", author=[], threshold=15):
        """Filter rows by author LIKE patterns, then optionally rank them by
        Levenshtein distance between `title` and the stored title; `threshold`
        caps the allowed edit distance."""
        if len(author) > 0:
            # note: values are interpolated directly, so this is only safe
            # for trusted input
            query_author = " OR ".join([f"authors LIKE '%{a}%'" for a in author])
        else:
            query_author = "1=1"   # no author filter
        # execute the query
        query = f"select * from {ArxivSQL.table} where {query_author}"
        results = ArxivSQL.cur.execute(query).fetchall()
        if len(title) == 0:
            return results
        else:
            sim_score = {}
            for row in results:
                row_id = row[0]
                row_title = row[2]
                score = lev_sim(title, row_title)
                if score < threshold:
                    sim_score[row_id] = score
            sorted_results = sorted(sim_score.items(), key=lambda x: x[1])
            # query_id expects bare ids, not (id, score) pairs
            return ArxivSQL.query_id([row_id for row_id, _ in sorted_results])

    @staticmethod
    def query_id(ids=[]):
        try:
            if len(ids) == 0:
                return None
            query = "select * from {} where id in ({})".format(
                ArxivSQL.table, ",".join("'" + id + "'" for id in ids))
            result = ArxivSQL.cur.execute(query)
            return result.fetchall()
        except Exception as e:
            print(e)
            print("Error query: ", query)

    @staticmethod
    def add(crawl_records):
        """
        Add crawl_records (list) obtained from the arxiv crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, authors, link, summary]
        Return a log of any failed inserts (empty string on full success).
        """
        results = ""
        for record in crawl_records:
            try:
                query = """insert into arxivsql values("{}","{}","{}","{}","{}","{}","{}")""".format(
                    record[1][21:],
                    record[0],
                    record[4].replace('"', "'"),
                    authors_list_to_str(record[5]),
                    record[2][:10],
                    record[3][:10],
                    record[6]
                )
                ArxivSQL.cur.execute(query)
                ArxivSQL.con.commit()
            except Exception as e:
                results += str(e)
                results += "\n" + query + "\n"
        # return after the loop so every record is processed, instead of
        # returning from inside a finally block on the first iteration
        return results
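
# --- editor's sketch (not part of the commit); _demo_pipeline is a
# hypothetical helper showing a minimal end-to-end flow, assuming the
# arxivsql table and topic_descriptions.txt already exist.
def _demo_pipeline():
    ArxivSQL.connect()
    ArxivChroma.connect()
    records = crawl_arxiv(["retrieval augmented generation"], max_results=10)
    if isinstance(records, str):
        print(records)                     # "Error: ..."
        return
    print(ArxivSQL.add(records))           # empty log means every insert succeeded
    ArxivChroma.add(records)
    # metadata lookup by fuzzy title + semantic lookup by query text
    print(ArxivSQL.query(title="Retrieval-Augmented Generation"))
    print(ArxivChroma.query_relevant(["retrieval"], ["grounding LLM answers"]))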