"""HTTP API for querying the MyScale Open Knowledge Base.

Exposes a single Flask-RESTX resource that runs a query against one of the
knowledge bases registered in ``kb_list`` and returns the retrieved documents
together with their count.
"""

import inspect
import json
import os
from io import BytesIO
from typing import List, Type

import gradio as gr
import requests
from flask import Flask, jsonify, render_template, request, send_file
from flask_restx import Api, Resource, fields

from funcs import ArXivKnowledgeBase, WikiKnowledgeBase, emb_arxiv, emb_wiki

app = Flask(__name__)
api = Api(
    app,
    version="0.1",
    terms_url="https://myscale.com/terms/",
    contact_email="support@myscale.com",
    title="MyScale Open Knowledge Base",
    description="An API to get relevant page from MyScale Open Knowledge Base",
)

# Response schema used to marshal the payload returned by `get_related_docs.get`.
query_result = api.model(
    "QueryResult",
    {
        "documents": fields.String,
        "num_retrieved": fields.Integer,
    },
)

# Knowledge-base name -> zero-argument factory. The factory is invoked on every
# request, so a knowledge-base object is built per call rather than at import.
kb_list = {
    "wiki": lambda: WikiKnowledgeBase(embedding=emb_wiki),
    "arxiv": lambda: ArXivKnowledgeBase(embedding=emb_arxiv),
}

# Request arguments accepted by the endpoint.
query_parser = api.parser()
query_parser.add_argument(
    "subject",
    required=True,
    type=str,
    help="a sentence or phrase describes the subject you want to query.",
)
query_parser.add_argument(
    "where_str",
    required=True,
    type=str,
    help="a sql-like where string to build filter",
)
query_parser.add_argument(
    "limit",
    required=False,
    type=int,
    default=4,
    help="desired number of retrieved documents",
)

# Endpoint description published in the generated API documentation. The text is
# runtime data (API consumers read it), so it is kept exactly as authored.
_RELATED_DOCS_DESCRIPTION = (
    "Get some related papers.\nYou should use schema here:\n\n"
    "CREATE TABLE ArXiv (\n"
    " `id` String,\n"
    " `abstract` String, -- abstract of the paper. avoid using this column to do LIKE match\n"
    " `pubdate` DateTime, \n"
    " `title` String, -- title of the paper\n"
    " `categories` Array(String), -- arxiv category of the paper\n"
    " `authors` Array(String), -- authors of the paper\n"
    " `comment` String, -- extra comments of the paper\n"
    "ORDER BY id\n\n"
    "CREATE TABLE Wikipedia (\n"
    " `id` String,\n"
    " `text` String, -- abstract of the wiki page. avoid using this column to do LIKE match\n"
    " `title` String, -- title of the paper\n"
    " `view` Float32,\n"
    " `url` String, -- URL to this wiki page\n"
    "ORDER BY id\n\n"
    "You should avoid using LIKE on long text columns."
)


# The URL must declare the `knowledge_base` variable: Flask passes URL variables
# to the handler as keyword arguments, and `get` requires that argument.
@api.route(
    "/get_related_docs/<string:knowledge_base>",
    doc={"description": _RELATED_DOCS_DESCRIPTION},
)
@api.param("knowledge_base", "Knowledge base used to query. Must be one of ['wiki', 'arxiv']")
class get_related_docs(Resource):
    """Resource that queries one of the knowledge bases listed in ``kb_list``."""

    @api.expect(query_parser)
    @api.marshal_with(query_result)
    def get(self, knowledge_base: str):
        # NOTE: documented with comments instead of a docstring on purpose;
        # Flask-RESTX folds method docstrings into the generated API docs, and
        # the published description above should stay unchanged.
        #
        # knowledge_base: key into `kb_list`, taken from the URL path.
        # Returns a dict with "documents" and "num_retrieved", marshalled
        # through `query_result`.
        # Aborts with HTTP 404 when `knowledge_base` is not a key of `kb_list`.
        args = query_parser.parse_args()

        make_kb = kb_list.get(knowledge_base)
        if make_kb is None:
            # An unknown name is a client error, not a server fault: report it
            # as 404 instead of letting a KeyError surface as HTTP 500.
            api.abort(
                404,
                f"Unknown knowledge base {knowledge_base!r}; "
                f"must be one of {sorted(kb_list)}",
            )

        kb = make_kb()
        print(kb)
        print(args.subject, args.where_str, args.limit)
        # The knowledge-base object is called directly and is unpacked as a
        # (documents, count) pair.
        docs, num_docs = kb(args.subject, args.where_str, args.limit)
        return {"documents": docs, "num_retrieved": num_docs}


if __name__ == "__main__":
    # Debug aid: uncomment to dump the generated API schema.
    # print(json.dumps(api.__schema__))
    app.run(host="0.0.0.0", port=7860)