# Spaces: Runtime error
# (Page-status banner captured together with this file; it is not part of the program.)
from typing import List, Optional, Tuple

import clickhouse_connect
from InstructorEmbedding import INSTRUCTOR
from sentence_transformers import SentenceTransformer
# Identifiers of the two embedding models loaded below.
_WIKI_MODEL_ID = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
_ARXIV_MODEL_ID = 'hkunlp/instructor-xl'

# Both models are instantiated at import time, so importing this module
# constructs them before any knowledge base is created.
emb_wiki = SentenceTransformer(_WIKI_MODEL_ID)
emb_arxiv = INSTRUCTOR(_ARXIV_MODEL_ID)
class ArXivKnowledgeBase:
    """Nearest-neighbour retriever over the ``default.ChatArXiv`` table.

    An instance is callable: it embeds a subject string with the supplied
    model, asks the database (through ``clickhouse_connect``) for the rows
    whose ``vector`` column is closest to that embedding, and returns the
    rows rendered as text.

    Subclasses may point the same logic at another table by overriding
    ``table``, ``embedding_col`` and ``must_have_cols`` after calling
    ``super().__init__``.
    """

    # The annotation is quoted so it is not evaluated when the class is
    # defined; any object exposing a compatible ``encode`` method works.
    def __init__(self, embedding: "SentenceTransformer") -> None:
        """Open the database connection and store the embedding model.

        Args:
            embedding: Model whose ``encode(text)`` result (an object with
                ``tolist()``) is used as the query vector.
        """
        # NOTE(review): the endpoint and credentials are hard-coded here;
        # consider loading them from configuration or a secret store.
        self.db = clickhouse_connect.get_client(
            host='msc-4a9e710a.us-east-1.aws.staging.myscale.cloud',
            port=443,
            username='chatdata',
            password='myscale_rocks',
        )
        self.embedding = embedding
        # Fully qualified table that is searched.
        self.table: str = 'default.ChatArXiv'
        # Column holding the stored embedding of each row.
        self.embedding_col: str = "vector"
        # Columns returned for every hit, in addition to the distance.
        self.must_have_cols: List[str] = [
            'id', 'abstract', 'authors', 'categories', 'comment', 'title', 'pubdate',
        ]

    def _build_query(self, q_emb: List[float], where_str: Optional[str], limit: int) -> str:
        """Assemble the SQL text for one nearest-neighbour lookup."""
        q_emb_str = ",".join(map(str, q_emb))
        # A filter made only of whitespace would otherwise produce a bare
        # ``WHERE`` with no predicate, which is not valid SQL.
        predicate = (where_str or "").strip()
        # Coerce so that only an integer can ever reach the LIMIT clause.
        row_limit = int(limit)

        clauses = [
            f"SELECT dist, {','.join(self.must_have_cols)}",
            f"FROM {self.table}",
        ]
        if predicate:
            # NOTE(review): the predicate is inserted verbatim, so callers
            # must not forward untrusted text as ``where_str``.
            clauses.append(f"WHERE {predicate}")
        clauses.append(f"ORDER BY distance({self.embedding_col}, [{q_emb_str}])")
        clauses.append("AS dist ASC")
        clauses.append(f"LIMIT {row_limit}")
        return "\n".join(clauses)

    def __call__(self, subject: str, where_str: Optional[str] = None, limit: int = 5) -> Tuple[str, int]:
        """Return the rows closest to ``subject``.

        Args:
            subject: Text to embed and search for.
            where_str: Optional SQL predicate (without the ``WHERE``
                keyword) restricting the candidate rows. ``None``, an empty
                string or a whitespace-only string means no filter.
            limit: Maximum number of rows to return; converted with ``int``.

        Returns:
            A pair ``(text, count)`` where ``text`` has one line per row
            (the ``str`` of the row's column-name-to-value mapping) and
            ``count`` is the number of rows.

        Raises:
            ValueError, TypeError: If ``limit`` cannot be converted to an
                integer.
        """
        q_emb = self.embedding.encode(subject).tolist()
        q_str = self._build_query(q_emb, where_str, limit)
        docs = list(self.db.query(q_str).named_results())
        return '\n'.join(str(d) for d in docs), len(docs)
class WikiKnowledgeBase(ArXivKnowledgeBase):
    """Knowledge base that searches the ``wiki.Wikipedia`` table.

    Inherits the connection setup and the callable query interface from
    ``ArXivKnowledgeBase``; only the table, the embedding column and the
    returned columns differ.
    """

    def __init__(self, embedding: SentenceTransformer) -> None:
        """Initialise the parent, then retarget it at the Wikipedia table.

        Args:
            embedding: Model used to embed query text.
        """
        super().__init__(embedding)
        # Assigned after the parent constructor so these values replace
        # the ones it sets.
        self.must_have_cols: List[str] = ["text", "title", "views", "url"]
        self.embedding_col = 'emb'
        self.table: str = "wiki.Wikipedia"
if __name__ == '__main__':
    # Manual smoke test: run one Wikipedia lookup and print the result.
    # To query the arXiv table instead, build the knowledge base with
    # ArXivKnowledgeBase(embedding=emb_arxiv).
    kb = WikiKnowledgeBase(embedding=emb_wiki)
    d = kb(subject="When did Steven Jobs die?", where_str="", limit=5)
    print(d)

    # Schema-style dictionary describing an object with a "todos" list of
    # strings. It is bound to ``d`` and not used afterwards in this script.
    d = {
        "components": {
            "schemas": {
                "type": "object",
                "properties": {
                    "todos": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "The list of todos.",
                    },
                },
            },
        },
    }