File size: 2,313 Bytes
d0f7013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

from typing import List, Optional, Tuple

import clickhouse_connect
from sentence_transformers import SentenceTransformer
from InstructorEmbedding import INSTRUCTOR


# Module-level embedding models; both are constructed when this module is
# imported.
# Encoder handed to WikiKnowledgeBase in the __main__ demo of this file.
emb_wiki = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
# Encoder intended for ArXivKnowledgeBase (see the commented-out demo line).
emb_arxiv = INSTRUCTOR("hkunlp/instructor-xl")

class ArXivKnowledgeBase:
    """Vector-search retriever over the ``default.ChatArXiv`` table.

    Calling an instance embeds the query text, runs a distance-ordered SQL
    query through the ClickHouse client and returns the matching rows
    rendered as text together with the number of rows.
    """

    # The annotation is quoted so that defining the class does not require
    # the model class to be resolvable at definition time.
    def __init__(self, embedding: "SentenceTransformer") -> None:
        """Create the database client and store the retrieval settings.

        Args:
            embedding: Model whose ``encode(text)`` result exposes
                ``tolist()``; used to turn the query text into a vector.
        """
        # NOTE(review): endpoint and credentials are hard-coded in source;
        # consider moving them to configuration.
        self.db = clickhouse_connect.get_client(
            host='msc-4a9e710a.us-east-1.aws.staging.myscale.cloud',
            port=443,
            username='chatdata',
            password='myscale_rocks'
        )
        self.embedding: "SentenceTransformer" = embedding
        self.table: str = 'default.ChatArXiv'
        self.embedding_col = "vector"
        self.must_have_cols: List[str] = ['id', 'abstract', 'authors', 'categories', 'comment', 'title', 'pubdate']

    def __call__(self, subject: str, where_str: Optional[str] = None, limit: int = 5) -> Tuple[str, int]:
        """Return the rows closest to ``subject`` and how many were found.

        Args:
            subject: Text to embed and search for.
            where_str: Optional SQL condition placed after ``WHERE``.
                ``None``, empty and whitespace-only values mean "no filter".
            limit: Maximum number of rows to return; must be convertible
                to a non-negative integer.

        Returns:
            A tuple ``(text, count)`` where ``text`` is the ``str()`` of
            every result row joined with newlines and ``count`` is the
            number of rows.

        Raises:
            ValueError: If ``limit`` is not an integer value or is negative.
        """
        # ``limit`` is spliced into the SQL text, so force it to a plain
        # integer before doing any work.
        try:
            row_limit = int(limit)
        except (TypeError, ValueError) as err:
            raise ValueError(f"limit must be an integer, got {limit!r}") from err
        if row_limit < 0:
            raise ValueError(f"limit must be non-negative, got {row_limit}")

        # NOTE(review): ``where_str`` is inserted verbatim as a SQL fragment
        # and cannot be parameterised here; only pass trusted conditions.
        condition = where_str.strip() if where_str else ""
        where_clause = f"WHERE {condition}" if condition else ""

        q_emb = self.embedding.encode(subject).tolist()
        q_emb_str = ",".join(map(str, q_emb))
        columns = ','.join(self.must_have_cols)

        q_str = f"""
            SELECT dist, {columns}
            FROM {self.table}
            {where_clause}
            ORDER BY distance({self.embedding_col}, [{q_emb_str}])
                AS dist ASC
            LIMIT {row_limit}
            """

        docs = list(self.db.query(q_str).named_results())
        return '\n'.join(str(d) for d in docs), len(docs)
    
class WikiKnowledgeBase(ArXivKnowledgeBase):
    """Knowledge base that queries the ``wiki.Wikipedia`` table.

    Construction is delegated to the parent class; afterwards the table
    name, the vector column and the selected columns are overridden.
    """

    def __init__(self, embedding: SentenceTransformer) -> None:
        """Run the parent initialiser, then apply the Wikipedia settings."""
        super().__init__(embedding)
        # Columns selected for every returned row.
        self.must_have_cols: List[str] = ['text', 'title', 'views', 'url']
        # Column holding the stored vectors.
        self.embedding_col = "emb"
        self.table: str = 'wiki.Wikipedia'


if __name__ == "__main__":
    # Manual demo: run one unfiltered query against the Wikipedia table and
    # print the (text, count) tuple that the knowledge base returns.
    # kb = ArXivKnowledgeBase(embedding=emb_arxiv)
    kb = WikiKnowledgeBase(embedding=emb_wiki)
    d = kb(
        "When did Steven Jobs die?",
        where_str="",
        limit=5,
    )
    print(d)


# Nested schema-style description of an object with a single "todos"
# property (an array of strings).
# NOTE(review): nothing in the visible code reads this value; confirm whether
# it is still needed.
d = {
    "components": {
        "schemas": {
            "type": "object",
            "properties": {
                "todos": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The list of todos.",
                },
            },
        },
    },
}