File size: 8,455 Bytes
e4c23fc 98d8f14 e4c23fc 40a792a 6054f54 d4df4f1 e4c23fc 5f42f03 eba9007 5f42f03 e4c23fc 98d8f14 e4c23fc 5f42f03 e4c23fc 40a792a e4c23fc ad6777f e4c23fc 40a792a e4c23fc 40a792a ad6777f e4c23fc df67d7c e4c23fc 9309ea5 e123321 ab87187 3e144e5 ab87187 e4c23fc 255441a e4c23fc b5e86a6 255441a e4c23fc b5e86a6 e4c23fc 5f42f03 e4c23fc 9309ea5 255441a e4c23fc b5e86a6 9309ea5 255441a e4c23fc b5e86a6 e4c23fc f9d77f2 e4c23fc f9d77f2 40a792a ad6777f f9d77f2 e4c23fc f9d77f2 e4c23fc d4df4f1 5fcb7d4 d4df4f1 35ced66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from functools import partial
from flask import request, Response
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
@manager.route('/templates', methods=['GET'])
@login_required
def templates():
    """Return all canvas templates as a JSON list of dictionaries.

    Requires an authenticated user.
    """
    template_dicts = []
    for template in CanvasTemplateService.get_all():
        template_dicts.append(template.to_dict())
    return get_json_result(data=template_dicts)
@manager.route('/list', methods=['GET'])
@login_required
def canvas_list():
    """List the current user's canvases, ordered by ``update_time`` descending.

    Requires an authenticated user; only canvases whose ``user_id`` matches
    the caller are returned.
    """
    owned_records = UserCanvasService.query(user_id=current_user.id)
    canvases = [record.to_dict() for record in owned_records]
    # Sorting on the negated value gives descending order while the stable
    # sort keeps entries with equal ``update_time`` in query order.
    canvases.sort(key=lambda item: item["update_time"] * -1)
    return get_json_result(data=canvases)
@manager.route('/rm', methods=['POST'])
@validate_request("canvas_ids")
@login_required
def rm():
    """Delete the canvases listed in the request body's ``canvas_ids``.

    Ownership of every id is verified before anything is deleted, so a
    request that names a canvas the caller does not own is rejected as a
    whole instead of leaving the ids that preceded it already removed.

    Returns:
        A JSON result with ``data=True`` once all canvases are deleted, or
        ``data=False`` with ``RetCode.OPERATING_ERROR`` if any id does not
        belong to the current user.
    """
    canvas_ids = request.json["canvas_ids"]

    # Pass 1: authorise the full batch before touching the database.
    for canvas_id in canvas_ids:
        if not UserCanvasService.query(user_id=current_user.id, id=canvas_id):
            return get_json_result(
                data=False, retmsg='Only owner of canvas authorized for this operation.',
                retcode=RetCode.OPERATING_ERROR)

    # Pass 2: every id is owned by the caller, so deletion is safe.
    for canvas_id in canvas_ids:
        UserCanvasService.delete_by_id(canvas_id)
    return get_json_result(data=True)
@manager.route('/set', methods=['POST'])
@validate_request("dsl", "title")
@login_required
def save():
    """Create a new canvas or update an existing one owned by the caller.

    The request body must contain ``dsl`` and ``title``. ``dsl`` may be sent
    either as a JSON string or as an already-parsed JSON value; it is stored
    in parsed form. When the body has no ``id`` a new canvas is created,
    otherwise the canvas with that ``id`` is updated.

    Returns:
        A JSON result echoing the (normalised) request on success; an error
        result when the title is already used by the caller, when the save
        fails, or when the caller does not own the canvas being updated.
    """
    req = request.json
    req["user_id"] = current_user.id

    # A DSL that arrived as an object is already in parsed form; only a
    # string payload needs decoding.
    if isinstance(req["dsl"], str):
        req["dsl"] = json.loads(req["dsl"])

    if "id" not in req:
        # Persist the same stripped title that the duplicate check looks up.
        # Storing the raw title would let " name " slip past a later check
        # for "name" (and vice versa).
        req["title"] = req["title"].strip()
        if UserCanvasService.query(user_id=current_user.id, title=req["title"]):
            return server_error_response(ValueError("Duplicated title."))
        req["id"] = get_uuid()
        if not UserCanvasService.save(**req):
            return get_data_error_result(retmsg="Fail to save canvas.")
        return get_json_result(data=req)

    # Update path: only the owner may modify an existing canvas.
    if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
        return get_json_result(
            data=False, retmsg='Only owner of canvas authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)
    UserCanvasService.update_by_id(req["id"], req)
    return get_json_result(data=req)
@manager.route('/get/<canvas_id>', methods=['GET'])
@login_required
def get(canvas_id):
    """Return the canvas identified by ``canvas_id`` as a dictionary.

    Applies the same owner-only rule as the other canvas endpoints in this
    module, so a logged-in user cannot read another user's canvas by
    guessing or obtaining its id.

    Args:
        canvas_id: Identifier of the canvas, taken from the URL path.

    Returns:
        A JSON result carrying the canvas, a data-error result when no such
        canvas exists, or ``data=False`` with ``RetCode.OPERATING_ERROR``
        when the caller is not the owner.
    """
    e, c = UserCanvasService.get_by_id(canvas_id)
    if not e:
        return get_data_error_result(retmsg="canvas not found.")
    if not UserCanvasService.query(user_id=current_user.id, id=canvas_id):
        return get_json_result(
            data=False, retmsg='Only owner of canvas authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)
    return get_json_result(data=c.to_dict())
@manager.route('/completion', methods=['POST'])
@validate_request("id")
@login_required
def run():
    """Run a canvas owned by the caller and return (or stream) its answer.

    Request body fields read here:
        id: Identifier of the canvas to run (required).
        stream: When truthy (the default) the answer is sent as a
            ``text/event-stream`` response; otherwise a single JSON result.
        message: Optional user message appended to the canvas conversation
            and passed to the canvas as user input.
        message_id: Optional id attached to the user and assistant
            messages; a fresh uuid is used when absent.

    After the answer is produced, the assistant message (and any reference)
    is appended to the canvas and the canvas DSL is written back to storage.

    Returns:
        A streaming ``Response`` or a JSON result with ``answer`` and
        ``reference``; an error result when the canvas is missing, is not
        owned by the caller, or the run fails or yields nothing.
    """
    req = request.json
    stream = req.get("stream", True)
    e, cvs = UserCanvasService.get_by_id(req["id"])
    if not e:
        return get_data_error_result(retmsg="canvas not found.")
    if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
        return get_json_result(
            data=False, retmsg='Only owner of canvas authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)

    # Canvas is constructed from a JSON string.
    if not isinstance(cvs.dsl, str):
        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

    final_ans = {"reference": [], "content": ""}
    message_id = req.get("message_id", get_uuid())
    try:
        canvas = Canvas(cvs.dsl, current_user.id)
        if "message" in req:
            canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
            # TODO: follow-up questions are passed through unchanged; the
            # history-based rewrite (full_question) is not wired in.
            canvas.add_user_input(req["message"])
        answer = canvas.run(stream=stream)
        # NOTE(review): debug dump of the whole canvas to stdout on every
        # request; consider moving to the project logger.
        print(canvas)
    except Exception as e:
        return server_error_response(e)

    # Explicit checks instead of ``assert`` so they still run under
    # ``python -O``. Streaming runs must hand back a ``partial`` to call.
    if answer is None or (stream and not isinstance(answer, partial)):
        return server_error_response(RuntimeError("Nothing. Is it over?"))

    def persist_answer():
        """Record the assistant reply on the canvas and save the updated DSL."""
        canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
        if final_ans.get("reference"):
            canvas.reference.append(final_ans["reference"])
        cvs.dsl = json.loads(str(canvas))
        UserCanvasService.update_by_id(req["id"], cvs.to_dict())

    if stream:
        def sse():
            """Yield server-sent events: one per partial answer, then a terminator."""
            try:
                for ans in answer():
                    # Keep the latest value of every field so the final
                    # state can be persisted once the stream is exhausted.
                    final_ans.update(ans)
                    ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
                    yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"

                persist_answer()
            except Exception as e:
                # Report the failure in-band; the stream is already open.
                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                           ensure_ascii=False) + "\n\n"
            # Final event marks the end of the stream.
            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

        resp = Response(sse(), mimetype="text/event-stream")
        resp.headers.add_header("Cache-control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp

    # Blocking mode: the answer's "content" entries are joined by newlines.
    final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
    persist_answer()
    return get_json_result(data={"answer": final_ans["content"], "reference": final_ans.get("reference", [])})
@manager.route('/reset', methods=['POST'])
@validate_request("id")
@login_required
def reset():
    """Reset a canvas owned by the caller and persist the reset DSL.

    The canvas named by the request's ``id`` is loaded, ``Canvas.reset()``
    is called on it, and the resulting DSL is saved and returned.

    Returns:
        A JSON result carrying the reset DSL; an error result when the
        canvas is missing, is not owned by the caller, or any step raises.
    """
    req = request.json
    try:
        e, user_canvas = UserCanvasService.get_by_id(req["id"])
        if not e:
            return get_data_error_result(retmsg="canvas not found.")
        if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
            return get_json_result(
                data=False, retmsg='Only owner of canvas authorized for this operation.',
                retcode=RetCode.OPERATING_ERROR)

        # Canvas is constructed from a JSON string. Serialise only when the
        # stored DSL is not already a string, so a string DSL is not
        # double-encoded (same guard as the completion endpoint).
        dsl = user_canvas.dsl
        if not isinstance(dsl, str):
            dsl = json.dumps(dsl, ensure_ascii=False)

        canvas = Canvas(dsl, current_user.id)
        canvas.reset()
        req["dsl"] = json.loads(str(canvas))
        UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
        return get_json_result(data=req["dsl"])
    except Exception as e:
        return server_error_response(e)
@manager.route('/test_db_connect', methods=['POST'])
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
def test_db_connect():
    """Check that a database is reachable with the supplied credentials.

    The request body supplies ``db_type`` (``mysql``, ``mariadb`` or
    ``postgresql``), ``database``, ``username``, ``host``, ``port`` and
    ``password``. A connection is opened and closed straight away.

    Returns:
        A JSON result with a success message, or a server-error result when
        ``db_type`` is not supported or the connection attempt raises.
    """
    req = request.json
    db_type = req["db_type"]
    if db_type in ("mysql", "mariadb"):
        database_cls = MySQLDatabase
    elif db_type == 'postgresql':
        database_cls = PostgresqlDatabase
    else:
        # Reject unknown types explicitly; otherwise the connection object
        # is never bound and the caller gets an unrelated UnboundLocalError.
        return server_error_response(ValueError("Unsupported database type."))

    try:
        db = database_cls(req["database"], user=req["username"], host=req["host"], port=req["port"],
                          password=req["password"])
        db.connect()
        db.close()
        return get_json_result(data="Database Connection Successful!")
    except Exception as e:
        return server_error_response(e)
|