from typing import List, Optional, Tuple

import clickhouse_connect
from sentence_transformers import SentenceTransformer
from InstructorEmbedding import INSTRUCTOR

# Embedding models: a multilingual MPNet model for Wikipedia and an
# instruction-tuned Instructor model for arXiv abstracts.
emb_wiki = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
emb_arxiv = INSTRUCTOR('hkunlp/instructor-xl')
class ArXivKnowledgeBase:
    def __init__(self, embedding: SentenceTransformer) -> None:
        # Client for the MyScale (ClickHouse-compatible) vector database.
        self.db = clickhouse_connect.get_client(
            host='msc-4a9e710a.us-east-1.aws.staging.myscale.cloud',
            port=443,
            username='chatdata',
            password='myscale_rocks'
        )
        self.embedding: SentenceTransformer = embedding
        self.table: str = 'default.ChatArXiv'
        self.embedding_col = "vector"
        self.must_have_cols: List[str] = ['id', 'abstract', 'authors', 'categories', 'comment', 'title', 'pubdate']
    def __call__(self, subject: str, where_str: Optional[str] = None, limit: int = 5) -> Tuple[str, int]:
        # Embed the query, then run a vector similarity search over the table,
        # optionally narrowed by a SQL WHERE clause.
        q_emb = self.embedding.encode(subject).tolist()
        q_emb_str = ",".join(map(str, q_emb))
        where_str = f"WHERE {where_str}" if where_str else ""
        q_str = f"""
            SELECT dist, {','.join(self.must_have_cols)}
            FROM {self.table}
            {where_str}
            ORDER BY distance({self.embedding_col}, [{q_emb_str}]) AS dist ASC
            LIMIT {limit}
            """
        docs = list(self.db.query(q_str).named_results())
        return '\n'.join(str(d) for d in docs), len(docs)
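
# A hypothetical usage sketch: `where_str` is spliced directly into the SQL, so
# retrieval can be narrowed with metadata filters. The column name and value
# below (`pubdate` and the date literal) are assumptions based on `must_have_cols`:
#
#   kb = ArXivKnowledgeBase(embedding=emb_arxiv)
#   context, n_docs = kb(
#       "instruction-tuned text embeddings",
#       where_str="pubdate > '2021-01-01'",
#       limit=3,
#   )
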
class WikiKnowledgeBase(ArXivKnowledgeBase):
    def __init__(self, embedding: SentenceTransformer) -> None:
        super().__init__(embedding)
        # Same client and query logic; only the table and columns differ.
        self.table: str = 'wiki.Wikipedia'
        self.embedding_col = "emb"
        self.must_have_cols: List[str] = ['text', 'title', 'views', 'url']
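
# Other tables can be exposed the same way by overriding the table/column
# attributes. A minimal sketch with a hypothetical table `default.MyDocs`:
#
#   class MyDocsKnowledgeBase(ArXivKnowledgeBase):
#       def __init__(self, embedding: SentenceTransformer) -> None:
#           super().__init__(embedding)
#           self.table = 'default.MyDocs'        # hypothetical table name
#           self.embedding_col = 'vector'
#           self.must_have_cols = ['text', 'title']
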
if __name__ == '__main__':
    # kb = ArXivKnowledgeBase(embedding=emb_arxiv)
    kb = WikiKnowledgeBase(embedding=emb_wiki)
    d = kb("When did Steven Jobs die?", "", 5)
    print(d)
    # OpenAPI-style schema fragment describing a `todos` array of strings.
    d = {
        "components": {
            "schemas": {
                "type": "object",
                "properties": {
                    "todos": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "The list of todos.",
                    }
                }
            }
        }
    }
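    # A minimal sketch, assuming this fragment is meant to be merged into a
    # complete OpenAPI 3 document (title/version below are placeholders):
    #
    #   spec = {"openapi": "3.0.0",
    #           "info": {"title": "ChatData", "version": "0.1.0"},
    #           "paths": {},
    #           **d}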