"""Knowledge-base retrieval over MyScale (ClickHouse) vector tables.

Loads two embedding models (one for Wikipedia, one for arXiv) and exposes
callable knowledge-base objects that encode a query, run a vector-similarity
SQL search, and return the matching rows as text.
"""
from typing import List, Optional, Tuple

import clickhouse_connect
from InstructorEmbedding import INSTRUCTOR
from sentence_transformers import SentenceTransformer

# Embedding models are loaded eagerly at import time — a heavyweight
# module-level side effect kept for backward compatibility.
emb_wiki = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
emb_arxiv = INSTRUCTOR('hkunlp/instructor-xl')


class ArXivKnowledgeBase:
    """Vector-similarity search over the ``default.ChatArXiv`` table."""

    def __init__(self, embedding: SentenceTransformer) -> None:
        """Connect to the MyScale cluster and record query configuration.

        :param embedding: model whose ``encode`` produces the query vector.
        """
        # SECURITY NOTE(review): credentials are hard-coded in source.
        # These look like public demo credentials, but they should be
        # supplied via configuration / environment variables.
        self.db = clickhouse_connect.get_client(
            host='msc-4a9e710a.us-east-1.aws.staging.myscale.cloud',
            port=443,
            username='chatdata',
            password='myscale_rocks',
        )
        self.embedding: SentenceTransformer = embedding
        self.table: str = 'default.ChatArXiv'
        # Name of the vector column used in the distance() call.
        self.embedding_col = "vector"
        # Columns always included in the SELECT list.
        self.must_have_cols: List[str] = [
            'id', 'abstract', 'authors', 'categories',
            'comment', 'title', 'pubdate',
        ]

    def __call__(self, subject: str, where_str: Optional[str] = None,
                 limit: int = 5) -> Tuple[str, int]:
        """Return up to *limit* rows nearest to *subject*.

        :param subject: free-text query, encoded into the search vector.
        :param where_str: optional raw SQL filter (without the ``WHERE``
            keyword).  SECURITY NOTE(review): interpolated verbatim into
            the query — never pass untrusted input here.
        :param limit: maximum number of rows to return.
        :return: tuple of (newline-joined row dicts rendered as strings,
            number of rows returned).
        """
        q_emb = self.embedding.encode(subject).tolist()
        q_emb_str = ",".join(map(str, q_emb))
        # None or "" both mean "no filter": emit no WHERE clause at all.
        where_str = f"WHERE {where_str}" if where_str else ""
        # `distance(...) AS dist` is MyScale vector-search syntax; the
        # alias declared in ORDER BY is also usable in the SELECT list.
        q_str = f"""
        SELECT dist, {','.join(self.must_have_cols)}
        FROM {self.table}
        {where_str}
        ORDER BY distance({self.embedding_col}, [{q_emb_str}]) AS dist ASC
        LIMIT {limit}
        """
        docs = list(self.db.query(q_str).named_results())
        return '\n'.join(str(doc) for doc in docs), len(docs)


class WikiKnowledgeBase(ArXivKnowledgeBase):
    """Same retrieval logic, pointed at the ``wiki.Wikipedia`` table."""

    def __init__(self, embedding: SentenceTransformer) -> None:
        super().__init__(embedding)
        self.table: str = 'wiki.Wikipedia'
        self.embedding_col = "emb"
        self.must_have_cols: List[str] = ['text', 'title', 'views', 'url']


if __name__ == '__main__':
    # kb = ArXivKnowledgeBase(embedding=emb_arxiv)
    kb = WikiKnowledgeBase(embedding=emb_wiki)
    d = kb("When did Steven Jobs die?", "", 5)
    print(d)

# NOTE(review): leftover OpenAPI-style schema fragment.  It rebinds the
# module-level name `d` and appears unrelated to the code above — kept to
# preserve the module's public names, but confirm whether it can be removed.
d = {
    "components": {
        "schemas": {
            "type": "object",
            "properties": {
                "todos": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The list of todos.",
                }
            }
        }
    }
}