Update main.py
Browse files
main.py
CHANGED
@@ -18,6 +18,7 @@ parser.add_argument("--port", type=int, help="Set the port.(default: 7860)", def
|
|
18 |
args = parser.parse_args()
|
19 |
|
20 |
base_url = os.getenv('MODEL_BASE_URL')
|
|
|
21 |
|
22 |
@app.route('/api/v1/models', methods=["GET", "POST"])
|
23 |
@app.route('/v1/models', methods=["GET", "POST"])
|
@@ -44,9 +45,11 @@ def model_list():
|
|
44 |
|
45 |
@app.route("/", methods=["GET"])
|
46 |
def index():
|
47 |
-
|
48 |
-
|
49 |
-
f'
|
|
|
|
|
50 |
|
51 |
@app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
|
52 |
@app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
|
@@ -72,6 +75,7 @@ def chat_completions():
|
|
72 |
message_size = len(messages)
|
73 |
|
74 |
prompt = messages[-1].get("content")
|
|
|
75 |
for i in range(message_size - 1):
|
76 |
role_this = messages[i].get("role")
|
77 |
role_next = messages[i + 1].get("role")
|
@@ -89,22 +93,41 @@ def chat_completions():
|
|
89 |
# print(f'{chat_history = }')
|
90 |
# print(f'{prompt = }')
|
91 |
|
92 |
-
fn_index =
|
93 |
|
94 |
-
# gen a random char(
|
95 |
chars = string.ascii_lowercase + string.digits
|
96 |
-
session_hash = "".join(random.choice(chars) for _ in range(
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
}
|
|
|
|
|
|
|
103 |
|
104 |
-
|
105 |
-
response = requests.post(f"{base_url}/queue/join", json=json_prompt)
|
106 |
-
url = f"{base_url}/queue/data?session_hash={session_hash}"
|
107 |
data = requests.get(url, stream=True)
|
|
|
108 |
|
109 |
time_now = int(time.time())
|
110 |
|
@@ -143,7 +166,7 @@ def gen_res_data(data, time_now=0, start=False):
|
|
143 |
if start:
|
144 |
res_data["choices"][0]["delta"] = {"role": "assistant", "content": ""}
|
145 |
else:
|
146 |
-
chat_pair = data["output"]["data"][
|
147 |
if chat_pair == []:
|
148 |
res_data["choices"][0]["finish_reason"] = "stop"
|
149 |
else:
|
@@ -152,5 +175,5 @@ def gen_res_data(data, time_now=0, start=False):
|
|
152 |
|
153 |
|
154 |
if __name__ == "__main__":
|
155 |
-
#
|
156 |
-
gevent.pywsgi.WSGIServer((args.host, args.port), app).serve_forever()
|
|
|
args = parser.parse_args()

# Base URL of the upstream Gradio Space that actually hosts the model
# (e.g. "https://<space>.hf.space"); configurable per deployment via env.
base_url = os.getenv('MODEL_BASE_URL')
print(base_url)  # startup diagnostic: confirm which backend this proxy targets
22 |
|
23 |
@app.route('/api/v1/models', methods=["GET", "POST"])
|
24 |
@app.route('/v1/models', methods=["GET", "POST"])
|
|
|
45 |
|
46 |
@app.route("/", methods=["GET"])
|
47 |
def index():
|
48 |
+
print('index')
|
49 |
+
return Response(f"Hunyuan-Large OpenAI Compatible API<br><br>"+
|
50 |
+
f"Set '{os.getenv("SPACE_URL")}/api' as proxy (or API Domain) in your Chatbot.<br><br>"+
|
51 |
+
f"The complete API is: {os.getenv("SPACE_URL")}/api/v1/chat/completions<br><br>")
|
52 |
+
f"Don't set the Syetem Prompt. It will be ignored."
|
53 |
|
54 |
@app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
|
55 |
@app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
|
|
|
75 |
message_size = len(messages)
|
76 |
|
77 |
prompt = messages[-1].get("content")
|
78 |
+
|
79 |
for i in range(message_size - 1):
|
80 |
role_this = messages[i].get("role")
|
81 |
role_next = messages[i + 1].get("role")
|
|
|
93 |
# print(f'{chat_history = }')
|
94 |
# print(f'{prompt = }')
|
95 |
|
96 |
+
fn_index = 3
|
97 |
|
98 |
+
# gen a random char(10) hash
|
99 |
chars = string.ascii_lowercase + string.digits
|
100 |
+
session_hash = "".join(random.choice(chars) for _ in range(10))
|
101 |
+
|
102 |
+
single_prompt_data = {
|
103 |
+
'data': [
|
104 |
+
prompt,
|
105 |
+
[],
|
106 |
+
],
|
107 |
+
'event_data': None,
|
108 |
+
'fn_index': 1,
|
109 |
+
'trigger_id': 5,
|
110 |
+
'session_hash': session_hash,
|
111 |
+
}
|
112 |
+
response = requests.post(f'{base_url}/gradio_api/run/predict', json=single_prompt_data)
|
113 |
+
|
114 |
+
context_data = {
|
115 |
+
'data': [
|
116 |
+
None,
|
117 |
+
chat_history+[[prompt,None]]
|
118 |
+
],
|
119 |
+
'event_data': None,
|
120 |
+
'fn_index': fn_index,
|
121 |
+
'trigger_id': 5,
|
122 |
+
'session_hash': session_hash,
|
123 |
}
|
124 |
+
response = requests.post(f"{base_url}/gradio_api/queue/join", json=context_data)
|
125 |
+
|
126 |
+
def generate():
|
127 |
|
128 |
+
url = f"{base_url}/gradio_api/queue/data?session_hash={session_hash}"
|
|
|
|
|
129 |
data = requests.get(url, stream=True)
|
130 |
+
#print(data.text)
|
131 |
|
132 |
time_now = int(time.time())
|
133 |
|
|
|
166 |
if start:
|
167 |
res_data["choices"][0]["delta"] = {"role": "assistant", "content": ""}
|
168 |
else:
|
169 |
+
chat_pair = data["output"]["data"][0]
|
170 |
if chat_pair == []:
|
171 |
res_data["choices"][0]["finish_reason"] = "stop"
|
172 |
else:
|
|
|
175 |
|
176 |
|
177 |
if __name__ == "__main__":
|
178 |
+
#app.run(host=args.host, port=args.port, debug=True)
|
179 |
+
gevent.pywsgi.WSGIServer((args.host, args.port), app).serve_forever()
|