Spaces:
Runtime error
Runtime error
# ---------------------------------------- IMPORTS ---------------------------------------- # | |
# import flask and flask_cors to host the api | |
from flask import Flask, request, jsonify, render_template | |
from flask_cors import CORS | |
# import the api class | |
from assets.source import api, non_streamed_format | |
# import addon | |
from assets.source.addons import * # here we only use 'create_cloudflare_tunnel' and 'translate' from the addons | |
# logging module for debugging | |
import logging | |
# json module to parse json | |
from json import loads | |
# ---------------------------------------- CONFIGURE LOCAL SERVER ---------------------------------------- #
# build the flask app; render_template() looks up templates under assets/templates
app = Flask(__name__, template_folder="assets/templates")
# enable CORS for the whole app
CORS(app)
# ---------------------------------------- READ FROM CONFIG FILE ---------------------------------------- #
# Settings are read once, at import time, from assets/config.json.
# An explicit encoding keeps parsing independent of the platform's locale default.
with open("assets/config.json", "r", encoding="utf-8") as f:
    config_file = loads(f.read())
# copy constants over (each falls back to a default when the key is absent)
DEBUG: bool = config_file.get("DEBUG", False)
PORT: int = config_file.get("PORT", 5000)
HOST: str = config_file.get("HOST", "0.0.0.0")
# check if user wants to use a global server too; a missing "use_global" key
# means "local only" instead of crashing at startup with a KeyError
if config_file.get("use_global", False):
    # create a cloudflare tunnel for the local PORT
    create_cloudflare_tunnel(PORT)
# ---------------------------------------- LOGGING CONFIG ---------------------------------------- #
# timestamp-prefixed records, everything from DEBUG upwards
# (basicConfig only configures the root logger if it has no handlers yet)
_LOG_FORMAT = "%(asctime)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)
# ---------------------------------------- ROUTES ---------------------------------------- # | |
# chat generation route
def chat():
    """Handle a chat-completion request, either streamed or as one JSON body.

    Expects a JSON object with required ``messages`` and ``model`` keys and
    optional ``max_tokens`` (default 150), ``top_p`` (0.99), ``top_k`` (50),
    ``temperature`` (0.6), ``frequency_penalty`` (1), ``presence_penalty`` (1)
    and ``stream``. When ``use_addons`` is enabled in ``config_file``, the
    messages and model are first passed through ``message_translation`` and
    ``translate``.

    Returns:
        * when ``stream`` is truthy, a ``text/event-stream`` response that
          forwards every chunk yielded by ``api.chat`` and ends with a
          ``data: [DONE]`` event;
        * otherwise the JSON built by ``non_streamed_format`` from the model
          name and the concatenated delta text of all chunks;
        * a 400 JSON error when the body is not an object containing both
          ``messages`` and ``model``.

    NOTE(review): no ``@app.route`` registration is visible for this function
    in this file; confirm it is wired up, otherwise it is unreachable.
    """
    data = request.get_json()
    # Reject malformed bodies with a 400 instead of letting a TypeError /
    # KeyError escape as a generic 500.
    if not isinstance(data, dict) or "messages" not in data or "model" not in data:
        return jsonify({
            "status": False,
            "error": "request body must be a JSON object with 'messages' and 'model'",
        }), 400

    # A missing "use_addons" key means "no addons" rather than a KeyError.
    use_addons = config_file.get("use_addons", False)
    messages = message_translation(data["messages"]) if use_addons else data["messages"]
    model = translate(data["model"]) if use_addons else data["model"]

    # Sampling parameters, with the server-side defaults applied.
    sampling = {
        "max_tokens": data.get("max_tokens", 150),
        "top_p": data.get("top_p", 0.99),
        "top_k": data.get("top_k", 50),
        "temperature": data.get("temperature", 0.6),
        "frequency_penalty": data.get("frequency_penalty", 1),
        "presence_penalty": data.get("presence_penalty", 1),
    }

    def generate():
        # Both response modes consume the same upstream chunk stream.
        return api.chat(messages, model, stream=True, **sampling)

    def stream():
        # Streaming body: text/event-stream instead of application/json.
        for chunk in generate():
            # Chunks are presumably already "data: ..." lines; the blank line
            # terminates each server-sent event.
            yield chunk + b"\n\n"
        # The end marker needs its blank line too: per the SSE spec an event
        # left unterminated at end of stream is not dispatched to the client.
        yield b"data: [DONE]\n\n"

    if data.get("stream"):
        logging.info("Streaming requested for model %s\n", model)
        return app.response_class(stream(), mimetype="text/event-stream")

    # Non-streaming: still stream from upstream, but collect the text into one string.
    logging.info("Non-streaming requested for model %s\n", model)
    full = "".join(_delta_content(chunk) for chunk in generate())
    return jsonify(non_streamed_format(model, full))


def _delta_content(chunk: bytes) -> str:
    """Return the ``choices[0].delta.content`` text carried by one stream chunk.

    ``chunk`` is expected to be UTF-8 bytes of the form ``data: {...json...}``.
    Chunks that are not valid JSON, that lack ``choices[0].delta.content``
    (presumably role-only or finish chunks), or whose content is not a string
    (e.g. ``null``) contribute nothing, so ``""`` is returned for them.
    """
    try:
        payload = loads(chunk.decode("utf-8").removeprefix("data: "))
        content = payload["choices"][0]["delta"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError covers both UnicodeDecodeError and JSONDecodeError;
        # anything else is a real bug and is allowed to propagate.
        return ""
    return content if isinstance(content, str) else ""
# route that lists every model the api reports
def get_models():
    """Respond with the result of ``api.get_models()`` serialized as JSON."""
    models = api.get_models()
    return jsonify(models)
# root route: serves a landing page so callers can check the api is online
def root():
    """Render and return the ``index.html`` template."""
    landing_page = render_template("index.html")
    return landing_page
# ---------------------------------------- ERROR HANDLING ---------------------------------------- # | |
def forbidden(error):
    """Build the 403 response used when the upstream DeepInfra request is blocked.

    Args:
        error: the error being handled; its string form is echoed in the body.

    Returns:
        A ``(response, 403)`` tuple with the shared JSON error body.

    NOTE(review): no ``@app.errorhandler(403)`` registration is visible in this
    file; confirm the handler is actually registered.
    """
    return _error_response(
        "Something went wrong, the API was blocked from sending a request to the DeepInfra API. Please try again later.",
        "forbidden",
        error,
        403,
    )


def internal_server_error(error):
    """Build the 500 response used when the request could not be completed.

    Args:
        error: the error being handled; its string form is echoed in the body.

    Returns:
        A ``(response, 500)`` tuple with the shared JSON error body.

    NOTE(review): no ``@app.errorhandler(500)`` registration is visible in this
    file; confirm the handler is actually registered.
    """
    return _error_response(
        "Something went wrong, the API was unable to complete your request. Please try again later.",
        "internal server error",
        error,
        500,
    )


def _error_response(message: str, kind: str, error, status: int):
    """Build the JSON error body shared by ``forbidden`` and ``internal_server_error``.

    The layout is kept as before: ``jsonify`` given several positional
    arguments serializes them as a JSON array, here of three objects
    (``status``, ``error`` and ``hint``).
    """
    return jsonify(
        {"status": False},
        {"error": [
            {"message": message},
            # key was misspelled 'tpye' in earlier versions of these responses
            {"type": kind},
            {"error": f"{error}"},
        ]},
        {"hint": "please report issues on the github page"},
    ), status
# ---------------------------------------- START API ---------------------------------------- #
# run flask's built-in server only when this file is executed directly
if __name__ == "__main__":
    app.run(host=HOST, port=PORT, debug=DEBUG)
# Path: app.py | |