import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from transformers import AutoModel
import json
from numpy.linalg import norm
import sqlite3
import urllib
class JinaAIEmbeddingFunction(EmbeddingFunction):
    """
    Chroma embedding function that delegates to a model's ``encode`` method.
    """

    def __init__(self, model):
        super().__init__()
        # Presumably a transformers model whose ``encode`` returns an array-like
        # object with ``.tolist()`` (e.g. numpy) -- confirm for other models.
        self.model = model

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` and return the vectors as plain Python lists."""
        # The parameter keeps the name ``input`` to match the EmbeddingFunction
        # call signature expected by chromadb.
        return self.model.encode(input).tolist()
class ArxivSQL:
    """
    Thin wrapper around a SQLite table of arXiv paper records.

    Rows are written by ``add`` in the column order
    ``(id, topic, title, author, updated, published, link)``.
    All user-supplied values are bound as SQL parameters. The table name
    cannot be bound, so it is interpolated and must come from trusted code.
    """

    def __init__(self, table="arxivsql", name="arxiv_records_sql"):
        # ``name`` is the SQLite database path; ``table`` is the table to use.
        self.con = sqlite3.connect(name)
        self.cur = self.con.cursor()
        self.table = table

    def query(self, title="", author=None):
        """
        Return rows matching a title substring and/or any of several authors.

        Args:
            title: Substring searched in the ``title`` column ("" disables it).
            author: Iterable of substrings; a row matches if its ``author``
                column contains ANY of them. Empty/None disables the filter.
                A single string is treated as one author.
        Returns:
            A list of row tuples.
        """
        clauses = []
        params = []
        if title:
            clauses.append("title like ?")
            params.append("%{}%".format(title))
        if isinstance(author, str):
            author = [author] if author else []
        authors = list(author or [])
        if authors:
            # Parenthesised so the OR-group cannot escape the AND with the title filter.
            clauses.append("(" + " or ".join(["author like ?"] * len(authors)) + ")")
            params.extend("%{}%".format(who) for who in authors)
        where = " and ".join(clauses) if clauses else "1"
        sql = "select * from {} where {}".format(self.table, where)
        return self.cur.execute(sql, params).fetchall()

    def query_id(self, ids=None):
        """
        Return rows whose ``id`` is in ``ids``.

        A single string is treated as one id. An empty/None ``ids`` returns []
        (an empty ``IN ()`` list would be a SQL syntax error).
        """
        if isinstance(ids, str):
            ids = [ids]
        ids = list(ids or [])
        if not ids:
            return []
        placeholders = ", ".join(["?"] * len(ids))
        sql = "select * from {} where id in ({})".format(self.table, placeholders)
        return self.cur.execute(sql, ids).fetchall()

    def add(self, crawl_records):
        """
        Add crawl_records (list) obtained from arxiv_crawlers.

        A record is a list of 8 columns:
        [topic, id, updated, published, title, author, link, summary]
        Each record is inserted independently; a failing record is logged and
        skipped so the remaining records are still inserted.

        Returns:
            A string with one "<error>\\n<record>\\n" entry per failed record
            ("" when every record was inserted).
        """
        errors = []
        sql = "insert into {} values (?, ?, ?, ?, ?, ?, ?)".format(self.table)
        for record in crawl_records:
            try:
                row = (
                    # Drops a fixed 21-character prefix from the id; presumably
                    # "http://arxiv.org/abs/" -- confirm against the crawler.
                    record[1][21:],
                    record[0],
                    # Quotes no longer need rewriting: values are bound, not spliced.
                    record[4],
                    process_authors_str(record[5]),
                    record[2][:10],
                    record[3][:10],
                    record[6],
                )
                self.cur.execute(sql, row)
            except Exception as e:
                # Deliberate best-effort: log the failure and continue with the next record.
                errors.append("{}\n{!r}\n".format(e, record))
        self.con.commit()
        return "".join(errors)
class ArxivChroma:
    """
    Interface to the persistent Chroma collection of arXiv paper chunks.
    Only query and addition are supported; editing and deletion are not.
    """

    def __init__(self, table="arxiv_records", name="arxivdb/"):
        """
        Open the persistent Chroma store at ``name`` and get (or create) the
        collection ``table``. Documents are embedded with the
        jinaai/jina-embeddings-v2-base-en model, cached under ``models``.
        """
        self.client = chromadb.PersistentClient(name)
        self.model = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v2-base-en",
            trust_remote_code=True,
            cache_dir="models",
        )
        self.collection = self.client.get_or_create_collection(
            table,
            embedding_function=JinaAIEmbeddingFunction(model=self.model),
        )

    def query_relevant(self, keywords, query_texts, n_results=3):
        """
        Semantic search for ``query_texts``, restricted to chunks containing
        at least one of ``keywords`` (a list of str, or a single str).
        With no keywords the search is unfiltered.
        """
        if isinstance(keywords, str):
            keywords = [keywords]
        filters = [{"$contains": keyword} for keyword in (keywords or [])]
        # Chroma rejects "$or" with fewer than two operands, so only build it when needed.
        if not filters:
            where_document = None
        elif len(filters) == 1:
            where_document = filters[0]
        else:
            where_document = {"$or": filters}
        return self.collection.query(
            query_texts=query_texts,
            where_document=where_document,
            n_results=n_results,
        )

    def query_exact(self, id):
        """
        Return every stored chunk of the paper ``id``.

        Filters on the ``paper_id`` metadata written by ``add``, so papers
        with any number of chunks are returned in full. The former fixed
        "<id>_0".."<id>_9" lookup silently dropped chunks past index 9.
        """
        return self.collection.get(where={"paper_id": id})

    def add(self, crawl_records):
        """
        Add crawl_records (list) obtained from arxiv_crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, author, link, summary]
        Each record is embedded as overlapping text chunks with ids
        "<paper_id>_<j>" and metadata {"paper_id": <paper_id>}.
        Return the final number of chunks in the collection.
        """
        for record in crawl_records:
            # Drops a fixed 21-character prefix; presumably "http://arxiv.org/abs/" -- confirm.
            paper_id = record[1][21:]
            embed_text = "Topic: {},\nTitle: {},\nSummary: {}".format(
                record[0], record[4], record[7]
            )
            chunks = chunk_text_with_overlap(embed_text)
            if not chunks:
                continue
            self.collection.add(
                documents=chunks,
                metadatas=[{"paper_id": paper_id} for _ in chunks],
                ids=["{}_{}".format(paper_id, j) for j in range(len(chunks))],
            )
        return self.collection.count()
def chunk_text_with_overlap(text, max_char=400, overlap=100):
    """
    Chunk a long text into pieces of whole words, with overlap between
    consecutive chunks.

    Words are joined with single spaces. Each chunk is shorter than
    ``max_char`` characters unless a single word is already that long.
    When a chunk is closed, the whole words that start within its last
    ``overlap`` characters are carried into the next chunk. No word is
    ever cut in half.

    Args:
        text: The long text to be chunked.
        max_char: Chunks are kept strictly shorter than this (default: 400).
        overlap: Size of the trailing window, in characters, whose whole
            words are repeated at the start of the next chunk (default: 100).
            0 or negative disables overlap.
    Returns:
        A list of stripped, non-empty chunks ([] for empty/whitespace text).
    """
    chunks = []
    current_chunk = ""
    for word in text.split():
        if not current_chunk:
            current_chunk = word
            continue
        if len(current_chunk) + 1 + len(word) < max_char:
            current_chunk += " " + word
            continue
        # Adding the word would reach max_char: close the current chunk.
        chunks.append(current_chunk)
        # Carry over the words starting inside the last `overlap` characters.
        # find() returns -1 when the last word spans the whole window; then no
        # whole word fits in the overlap, so the next chunk starts fresh.
        split_point = current_chunk.find(" ", max(len(current_chunk) - overlap, 0))
        tail = current_chunk[split_point + 1:] if split_point != -1 else ""
        if tail and len(tail) + 1 + len(word) < max_char:
            current_chunk = tail + " " + word
        else:
            current_chunk = word
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
def trimming(txt):
    """
    Return the substring from the first "{" through the last "}" of ``txt``.

    Returns "" when ``txt`` has no "{" or no "}" after it. Without this guard,
    a text with no "{" but a trailing "}" would return "}".
    """
    start = txt.find("{")
    end = txt.rfind("}")
    if start == -1 or end < start:
        return ""
    return txt[start:end + 1]
def extract_tag(txt, tagname):
    """
    Return the text between the first ``<tagname>`` and the next ``</tagname>``.

    Only exact ``<tagname>`` opening tags are matched; a tag carrying
    attributes is not. Returns "" when either tag is missing.
    """
    open_tag = "<" + tagname + ">"
    start = txt.find(open_tag)
    if start == -1:
        return ""
    start += len(open_tag)
    # Search for the closing tag only after the opening one.
    end = txt.find("</" + tagname + ">", start)
    if end == -1:
        return ""
    return txt[start:end]
def get_record(extract):
# id = extract[extract.find("