File size: 3,117 Bytes
d0f7013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3dbd3c
d0f7013
 
 
 
 
b3dbd3c
d0f7013
b3dbd3c
 
 
 
d0f7013
b3dbd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0f7013
 
b3dbd3c
 
d0f7013
 
 
 
 
 
 
b3dbd3c
d0f7013
 
b3dbd3c
 
d0f7013
b3dbd3c
d0f7013
 
 
 
 
 
5733770
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import inspect
import os
import json
from io import BytesIO
from typing import List, Type

from flask import Flask, jsonify, render_template, request, send_file
from flask_restx import Resource, Api, fields
from funcs import emb_wiki, emb_arxiv, WikiKnowledgeBase, ArXivKnowledgeBase

# WSGI application object; the flask_restx Api below attaches itself to it.
app = Flask(__name__)

# Descriptive metadata handed to flask_restx when the Api is created.
_api_metadata = {
    "title": "MyScale Open Knowledge Base",
    "description": "An API to get relevant page from MyScale Open Knowledge Base",
    "version": "0.1",
    "terms_url": "https://myscale.com/terms/",
    "contact_email": "support@myscale.com",
}
api = Api(app, **_api_metadata)

# Response model shared by the query endpoints: the retrieved documents
# (marshalled as a string) and how many of them were retrieved.
_query_result_fields = {
    "documents": fields.String,
    "num_retrieved": fields.Integer,
}
query_result = api.model("QueryResult", _query_result_fields)

def _build_wiki_kb():
    """Construct the Wikipedia knowledge base with its embedding model."""
    return WikiKnowledgeBase(embedding=emb_wiki)


def _build_arxiv_kb():
    """Construct the ArXiv knowledge base with its embedding model."""
    return ArXivKnowledgeBase(embedding=emb_arxiv)


# Maps a knowledge-base key to a zero-argument factory; calling the factory
# constructs the knowledge base, so nothing is instantiated at import time.
kb_list = {
    "wiki": _build_wiki_kb,
    "arxiv": _build_arxiv_kb,
}

# Request parser for the query endpoints. The arguments are declared as data
# and registered in order: `subject` and `where_str` are mandatory, `limit`
# is optional and defaults to 4.
query_parser = api.parser()
for _arg_name, _arg_options in (
    (
        "subject",
        {
            "required": True,
            "type": str,
            "help": "a sentence or phrase describes the subject you want to query.",
        },
    ),
    (
        "where_str",
        {
            "required": True,
            "type": str,
            "help": "a sql-like where string to build filter",
        },
    ),
    (
        "limit",
        {
            "required": False,
            "type": int,
            "default": 4,
            "help": "desired number of retrieved documents",
        },
    ),
):
    query_parser.add_argument(_arg_name, **_arg_options)
# Keep the module namespace free of the loop temporaries.
del _arg_name, _arg_options


# Endpoint returning ArXiv papers related to a subject.
#
# The route description doubles as the schema hint for API consumers, so the
# CREATE TABLE statement it shows is kept well-formed: the column list is
# closed with ")" before ORDER BY and carries no dangling comma.
#
# NOTE: the class and its handler are documented with comments, not
# docstrings, because flask_restx folds handler docstrings into the generated
# OpenAPI summary/description and that would alter the published schema.
@api.route(
    "/get_related_arxiv",
    doc={
        "description": (
            "Get some related papers.\nYou should use schema here:\n\n"
            "CREATE TABLE ArXiv (\n"
            "    `id` String,\n"
            "    `abstract` String, \n"
            "    `pubdate` DateTime, \n"
            "    `title` String, \n"
            "    `categories` Array(String), -- arxiv category\n"
            "    `authors` Array(String),\n"
            "    `comment` String\n"
            ")\n"
            "ORDER BY id\n\n"
        ),
    },
)
class get_related_arxiv(Resource):
    @api.expect(query_parser)
    @api.marshal_with(query_result)
    @api.doc(id='get_related_arxiv')
    def get(self):
        # Parse `subject`, `where_str` and `limit` from the request.
        params = query_parser.parse_args()
        # Build the ArXiv knowledge base through its registered factory.
        knowledge_base = kb_list['arxiv']()
        # NOTE(review): `where_str` is forwarded exactly as received from the
        # client; presumably the knowledge base embeds it in a query, so
        # confirm it is validated or restricted there.
        documents, retrieved = knowledge_base(
            params.subject, params.where_str, params.limit
        )
        return {"documents": documents, "num_retrieved": retrieved}


# Endpoint returning Wikipedia pages related to a subject.
#
# The route description doubles as the schema hint for API consumers, so the
# CREATE TABLE statement it shows is kept well-formed: the column list is
# closed with ")" before ORDER BY and carries no dangling comma.
#
# NOTE: the class and its handler are documented with comments, not
# docstrings, because flask_restx folds handler docstrings into the generated
# OpenAPI summary/description and that would alter the published schema.
@api.route(
    "/get_related_wiki",
    doc={
        "description": (
            "Get some related wiki pages.\nYou should use schema here:\n\n"
            "CREATE TABLE Wikipedia (\n"
            "    `id` String,\n"
            "    `text` String,\n"
            "    `title` String,\n"
            "    `view` Float32,\n"
            "    `url` String -- URL to this wiki page\n"
            ")\n"
            "ORDER BY id\n\n"
            "You should avoid using LIKE on long text columns."
        ),
    },
)
class get_related_wiki(Resource):
    @api.expect(query_parser)
    @api.marshal_with(query_result)
    @api.doc(id='get_related_wiki')
    def get(self):
        # Parse `subject`, `where_str` and `limit` from the request.
        params = query_parser.parse_args()
        # Build the Wikipedia knowledge base through its registered factory.
        knowledge_base = kb_list['wiki']()
        # NOTE(review): `where_str` is forwarded exactly as received from the
        # client; presumably the knowledge base embeds it in a query, so
        # confirm it is validated or restricted there.
        documents, retrieved = knowledge_base(
            params.subject, params.where_str, params.limit
        )
        return {"documents": documents, "num_retrieved": retrieved}


if __name__ == "__main__":
    # Debug aid: uncomment to dump the generated API schema.
    # print(json.dumps(api.__schema__))

    # Bind to all interfaces on port 7860 when run as a script.
    _server_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**_server_options)