File size: 6,018 Bytes
315d075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# ---------------------------------------- IMPORTS ---------------------------------------- #

# import flask and flask_cors to host the api
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS

# import the api class
from assets.source import api, non_streamed_format

# import addon
from assets.source.addons import * # here we only use 'create_cloudflare_tunnel' and 'translate' from the addons

# logging module for debugging
import logging

# json module to parse json
from json import loads

# ---------------------------------------- CONFIGURE LOCAL SERVER ---------------------------------------- #

# create the flask app; templates are looked up in assets/templates
app = Flask(__name__, template_folder="assets/templates")

# enable cross-origin requests through flask_cors (library defaults)
CORS(app)

# ---------------------------------------- READ FROM CONFIG FILE ---------------------------------------- #
# Load the server settings. Only the parse happens inside the `with`, so the
# file is already closed before any tunnel is started below. JSON is UTF-8,
# so decode it explicitly instead of relying on the platform's locale encoding.
with open("assets/config.json", "r", encoding="utf-8") as f:
    config_file = loads(f.read())

# copy constants over, falling back to defaults for missing keys
DEBUG: bool = config_file.get("DEBUG", False)
PORT: int = config_file.get("PORT", 5000)
HOST: str = config_file.get("HOST", "0.0.0.0")

# check if user wants to use a global server too; a missing "use_global" key
# means "no" instead of crashing the import with a KeyError
if config_file.get("use_global", False):

    # create a cloudflare tunnel (provided by the addons star-import)
    create_cloudflare_tunnel(PORT)

# ---------------------------------------- LOGGING CONFIG ---------------------------------------- #

# configure the root logger (a no-op if it already has handlers): DEBUG and
# above, each record prefixed with its timestamp
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG)

# ---------------------------------------- ROUTES ---------------------------------------- #

def _delta_content(chunk):
    """Extract the text delta from one chunk yielded by ``api.chat``.

    The chunk is decoded as UTF-8, a leading ``data: `` prefix is removed and
    the rest is parsed as JSON; ``["choices"][0]["delta"]["content"]`` is
    returned when it is a string.

    Returns "" for chunks that carry no text: non-JSON payloads (e.g. a
    ``data: [DONE]`` sentinel, if upstream sends one), deltas without a
    ``content`` key or with a non-string ``content``, and malformed chunks.
    Only decode/parse/lookup errors are caught, so interrupts and generator
    shutdown still propagate.
    """
    try:
        payload = loads(chunk.decode("utf-8").removeprefix("data: "))
        content = payload["choices"][0]["delta"]["content"]
    except (AttributeError, ValueError, LookupError, TypeError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError;
        # LookupError covers KeyError and IndexError
        return ""
    return content if isinstance(content, str) else ""


# chat generation route
@app.route("/chat/completions", methods=["POST"])
def chat():
    """Handle an OpenAI-style chat completion request.

    Reads the JSON body and, when ``use_addons`` is enabled in the config,
    passes the messages through ``message_translation`` and the model name
    through ``translate``. The request is then forwarded to ``api.chat`` with
    ``stream=True``.

    If the body's ``"stream"`` field is truthy, the upstream chunks are relayed
    as ``text/event-stream``, followed by a final ``data: [DONE]`` event.
    Otherwise the text deltas are concatenated and returned as JSON built by
    ``non_streamed_format``.
    """

    # get request data
    data = request.get_json()

    # the addons rewrite messages/model only when enabled; a missing key means disabled
    use_addons = config_file.get("use_addons", False)
    messages = message_translation(data["messages"]) if use_addons else data["messages"]
    model = translate(data["model"]) if use_addons else data["model"]

    # sampling parameters forwarded to api.chat, with this server's defaults
    options = {
        "max_tokens": data.get("max_tokens", 150),
        "top_p": data.get("top_p", 0.99),
        "top_k": data.get("top_k", 50),
        "temperature": data.get("temperature", 0.6),
        "frequency_penalty": data.get("frequency_penalty", 1),
        "presence_penalty": data.get("presence_penalty", 1),
    }

    def generate():
        # both paths request a stream upstream; the non-streamed path aggregates it
        return api.chat(messages, model, stream=True, **options)

    # streaming function. uses text/event-stream instead of application/json
    def stream():

        # relay each upstream chunk as its own event (an SSE event ends with a blank line)
        for chunk in generate():
            yield chunk + b'\n\n'

        # in the end, signal completion; the trailing blank line is required for
        # SSE clients to dispatch this final event
        yield b'data: [DONE]\n\n'

    # check if user wants to stream
    if data.get("stream"):
        logging.info("Streaming requested for model %s\n", model)
        return app.response_class(stream(), mimetype='text/event-stream')

    # even if not, stream upstream but collect the text deltas into one string
    logging.info("Non-streaming requested for model %s\n", model)
    full = "".join(_delta_content(chunk) for chunk in generate())
    return jsonify(non_streamed_format(model, full))


# route that lists the available models
@app.route("/models", methods=["GET"])
def get_models():
    """Return the result of ``api.get_models()`` serialized as JSON."""
    available = api.get_models()
    return jsonify(available)

# root route, handy for checking that the api is online
@app.route("/", methods=["GET"])
def root():
    """Render the ``index.html`` template from the app's template folder."""
    page = "index.html"
    return render_template(page)

# ---------------------------------------- ERROR HANDLING ---------------------------------------- #
def _error_response(message, error_type, error, status_code):
    """Build the ``(response, status_code)`` pair shared by the error handlers.

    The body is a single JSON object of the form::

        {"status": false,
         "error": {"message": ..., "type": ..., "error": "<str(error)>"},
         "hint": "please report issues on the github page"}

    (Passing several positional arguments to ``jsonify`` would serialize them
    as a JSON list instead of one object, so everything goes in one dict.)
    """
    return jsonify({
        "status": False,
        "error": {
            "message": message,
            "type": error_type,
            "error": str(error),
        },
        "hint": "please report issues on the github page",
    }), status_code


@app.errorhandler(403)
def forbidden(error):
    """Return a JSON error body with HTTP status 403."""
    return _error_response(
        'Something went wrong, the API was blocked from sending a request to the DeepInfra API. Please try again later.',
        'forbidden',
        error,
        403,
    )


@app.errorhandler(500)
def internal_server_error(error):
    """Return a JSON error body with HTTP status 500."""
    return _error_response(
        'Something went wrong, the API was unable to complete your request. Please try again later.',
        'internal server error',
        error,
        500,
    )

# ---------------------------------------- START API ---------------------------------------- #

# run the server only when this file is executed directly, not when imported
if __name__ == "__main__":
    run_options = {"debug": DEBUG, "port": PORT, "host": HOST}
    app.run(**run_options)

# Path: app.py