Spaces:
Running
Running
antonovmaxim
committed on
Commit
·
883ac62
1
Parent(s):
2eac015
all files
Browse files- README.md +1 -1
- blocking_api.py +131 -0
- gpt4.py +32 -0
- requirements.txt +3 -0
- script.py +10 -0
- util.py +92 -0
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: green
|
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.33.1
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
|
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.33.1
|
8 |
+
app_file: script.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
blocking_api.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
3 |
+
from threading import Thread
|
4 |
+
|
5 |
+
from util import build_parameters, try_start_cloudflared
|
6 |
+
from gpt4 import ask_gpt
|
7 |
+
# from modules import shared
|
8 |
+
# from modules.chat import generate_chat_reply
|
9 |
+
# from modules.text_generation import encode, generate_reply, stop_everything_event
|
10 |
+
|
11 |
+
|
12 |
+
class Handler(BaseHTTPRequestHandler):
    """Blocking (non-streaming) JSON API handler.

    Endpoints:
      GET  /api/v1/model       -> {'result': <model name>}
      POST /api/v1/generate    -> {'results': [{'text': ...}]} via ask_gpt()
      POST /api/v1/chat        -> {'results': [{'history': ...}]} (chat
                                  generation is disabled; echoes the history)
      POST /api/v1/stop-stream -> stub, {'results': 'error'}
      POST /api/v1/token-count -> stub, {'results': [{'tokens': 'error'}]}
    """

    def do_GET(self):
        if self.path == '/api/v1/model':
            self.send_response(200)
            # Declare JSON explicitly, matching every POST endpoint below.
            self.send_header('Content-Type', 'application/json')
            self.end_headers()
            response = json.dumps({
                'result': 'GPT4 mindsdb OpenAI original'
            })

            self.wfile.write(response.encode('utf-8'))
        else:
            self.send_error(404)

    def do_POST(self):
        # BUG FIX: self.headers['Content-Length'] returns None when the
        # header is absent, so int(...) raised TypeError.  Default to an
        # empty body instead.
        content_length = int(self.headers.get('Content-Length', 0))
        body = json.loads(self.rfile.read(content_length).decode('utf-8')) if content_length else {}

        if self.path == '/api/v1/generate':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            prompt = body['prompt']
            generate_params = build_parameters(body)
            # Popped for parity with the original webui API; unused while
            # generate_reply is disabled.
            stopping_strings = generate_params.pop('stopping_strings')
            generate_params['stream'] = False

            # generate_reply is disabled; answers come from the DB-backed stub.
            answer = ask_gpt(prompt)

            response = json.dumps({
                'results': [{
                    'text': answer
                }]
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/chat':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            user_input = body['user_input']
            history = body['history']
            regenerate = body.get('regenerate', False)
            _continue = body.get('_continue', False)

            generate_params = build_parameters(body, chat=True)
            generate_params['stream'] = False
            # BUG FIX: generate_chat_reply is disabled and the old stub did
            # ``generator = 'error'`` then ``for a in generator: answer = a``,
            # which iterates the string character by character and always
            # returned the single letter 'r' as the history.  Echo the
            # submitted history unchanged instead.
            answer = history

            response = json.dumps({
                'results': [{
                    'history': answer
                }]
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/stop-stream':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            # stop_everything_event() is disabled; report the stub marker.
            response = json.dumps({
                'results': 'error'
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/token-count':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            # Token counting requires the disabled encode(); stub marker.
            response = json.dumps({
                'results': [{
                    'tokens': 'error'
                }]
            })

            self.wfile.write(response.encode('utf-8'))
        else:
            self.send_error(404)
109 |
+
|
110 |
+
def _run_server(port: int, share: bool = False):
    """Run the blocking API server (blocks forever).

    Args:
        port: TCP port to bind on localhost.
        share: when True, also try to expose the server via a
            cloudflared tunnel (best-effort).
    """
    # BUG FIX: the original expression was ``'0.0.0.0' if 0 else '127.0.0.1'``
    # — the condition was the constant 0, so the server always bound to
    # localhost.  Keep that behavior, minus the dead conditional.
    address = '127.0.0.1'

    server = ThreadingHTTPServer((address, port), Handler)

    def on_start(public_url: str):
        # Announced once cloudflared reports the public tunnel URL.
        print(f'Starting non-streaming server at public url {public_url}/api')

    if share:
        try:
            try_start_cloudflared(port, max_attempts=3, on_start=on_start)
        except Exception:
            # Tunnel startup is best-effort; the local server still runs.
            pass
    else:
        print(
            f'Starting API at http://{address}:{port}/api')

    server.serve_forever()
|
128 |
+
|
129 |
+
|
130 |
+
def start_server(port: int, share: bool = False):
    """Launch the blocking API in a background daemon thread."""
    worker = Thread(target=_run_server, args=[port, share], daemon=True)
    worker.start()
|
gpt4.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pymysql
|
2 |
+
import os
|
3 |
+
|
4 |
+
def connect_to_db():
    """Open a pymysql connection to the MindsDB cloud endpoint.

    Credentials are read from the LOGIN and PASSWORD environment
    variables.  Returns the connection object, or None when the
    connection attempt fails (the error is printed, not raised).
    """
    try:
        return pymysql.connect(
            host="cloud.mindsdb.com",
            port=3306,
            user=os.environ['LOGIN'],
            password=os.environ['PASSWORD'],
            db="mindsdb",
            charset="utf8mb4",
            cursorclass=pymysql.cursors.DictCursor,
        )
    except Exception as e:
        print(f"Error connecting to database: {e}")
        return None
|
18 |
+
print('---Trying to connect---')
# Shared module-level connection used by ask_gpt() below.
connection = connect_to_db()
# BUG FIX: '---CONNECTED---' used to be printed unconditionally, even
# when connect_to_db() failed and returned None.
if connection is not None:
    print('---CONNECTED---')
else:
    print('---CONNECTION FAILED---')
|
21 |
+
|
22 |
+
def ask_gpt(text):
    """Look up the stored GPT-4 response for ``text`` in mindsdb.gpt4.

    Args:
        text: the prompt to look up (single quotes are normalized first).

    Returns:
        The response string, or None when no matching row exists.
    """
    # print('ASK GPT:', text)
    # Legacy sanitization from when the query was string-interpolated.
    # Kept so lookups still match rows stored with the same transform —
    # the parameterized query below no longer needs it for safety.
    text = text.replace("'", '"')

    # Plain string: the old ``f`` prefix had nothing to interpolate; %s is
    # a pymysql placeholder, not Python string formatting.
    query = "SELECT response FROM mindsdb.gpt4 WHERE text=%s"

    # Close the cursor deterministically instead of leaking it.
    with connection.cursor() as cursor:
        cursor.execute(query, (text, ))
        result = cursor.fetchone()

    # BUG FIX: fetchone() returns None when there is no matching row; the
    # old code crashed with TypeError on result['response'] in that case.
    return result['response'] if result is not None else None
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
flask_cloudflared==0.0.12
|
2 |
+
websockets==11.0.2
|
3 |
+
pymysql
|
script.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import blocking_api
|
2 |
+
# import extensions.api.streaming_api as streaming_api
|
3 |
+
# from modules import shared
|
4 |
+
|
5 |
+
|
6 |
+
def setup():
    """Start the blocking API and keep the main thread alive.

    The API runs in a daemon thread, so the process must not return;
    sleeping in the keep-alive loop avoids the 100%-CPU busy-wait of
    the original ``while True: pass``.
    """
    import time  # local import keeps this fix self-contained

    blocking_api.start_server(7860, share=True)
    # streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api)
    while True:
        time.sleep(1)


setup()
|
util.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import traceback
|
3 |
+
from threading import Thread
|
4 |
+
from typing import Callable, Optional
|
5 |
+
|
6 |
+
# from modules import shared
|
7 |
+
# from modules.chat import load_character_memoized
|
8 |
+
# Stand-ins for the disabled ``modules.shared`` / ``modules.chat`` imports.
# BUG FIX: the originals (``shared = True``; ``load_character_memoized =
# lambda: 0``) made every ``build_parameters(body, chat=True)`` call crash:
# ``True`` has no ``.settings`` attribute, and the lambda accepted no
# arguments.  These stubs keep the chat branch runnable.
class _SharedStub:
    # Covers exactly the settings keys build_parameters reads.
    # NOTE(review): values are neutral placeholders — confirm against the
    # real webui defaults before relying on them.
    settings = {
        'name1': 'You',
        'name2': 'Assistant',
        'stop_at_newline': False,
        'chat_prompt_size': 2048,
        'chat_generation_attempts': 1,
        'chat-instruct_command': '',
    }


shared = _SharedStub()


def load_character_memoized(*args, **kwargs):
    """Stub for modules.chat.load_character_memoized.

    Returns the 6-tuple shape the callers unpack:
    (name1, name2, picture, greeting, context, turn_template).
    """
    return ('', '', '', '', '', '')
|
10 |
+
def build_parameters(body, chat=False):
    """Translate a request body into generation parameters.

    Args:
        body: decoded JSON request payload; unknown keys are ignored and
            missing keys fall back to the defaults below.
        chat: when True, also resolve character/instruction-template
            settings and merge the chat-specific parameters.

    Returns:
        A dict of generation parameters.
    """
    get = body.get

    params = {
        'max_new_tokens': int(get('max_new_tokens', get('max_length', 200))),
        'do_sample': bool(get('do_sample', True)),
        'temperature': float(get('temperature', 0.5)),
        'top_p': float(get('top_p', 1)),
        'typical_p': float(get('typical_p', get('typical', 1))),
        'epsilon_cutoff': float(get('epsilon_cutoff', 0)),
        'eta_cutoff': float(get('eta_cutoff', 0)),
        'tfs': float(get('tfs', 1)),
        'top_a': float(get('top_a', 0)),
        'repetition_penalty': float(get('repetition_penalty', get('rep_pen', 1.1))),
        'encoder_repetition_penalty': float(get('encoder_repetition_penalty', 1.0)),
        'top_k': int(get('top_k', 0)),
        'min_length': int(get('min_length', 0)),
        'no_repeat_ngram_size': int(get('no_repeat_ngram_size', 0)),
        'num_beams': int(get('num_beams', 1)),
        'penalty_alpha': float(get('penalty_alpha', 0)),
        'length_penalty': float(get('length_penalty', 1)),
        'early_stopping': bool(get('early_stopping', False)),
        'mirostat_mode': int(get('mirostat_mode', 0)),
        'mirostat_tau': float(get('mirostat_tau', 5)),
        'mirostat_eta': float(get('mirostat_eta', 0.1)),
        'seed': int(get('seed', -1)),
        'add_bos_token': bool(get('add_bos_token', True)),
        'truncation_length': int(get('truncation_length', get('max_context_length', 2048))),
        'ban_eos_token': bool(get('ban_eos_token', False)),
        'skip_special_tokens': bool(get('skip_special_tokens', True)),
        'custom_stopping_strings': '',  # intentionally left blank
        'stopping_strings': get('stopping_strings', []),
    }

    if chat:
        character = get('character')
        instruction_template = get('instruction_template')
        name1, name2, _, greeting, context, _ = load_character_memoized(
            character, str(get('your_name', shared.settings['name1'])),
            shared.settings['name2'], instruct=False)
        name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(
            instruction_template, '', '', instruct=True)
        params.update({
            'stop_at_newline': bool(get('stop_at_newline', shared.settings['stop_at_newline'])),
            'chat_prompt_size': int(get('chat_prompt_size', shared.settings['chat_prompt_size'])),
            'chat_generation_attempts': int(get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
            'mode': str(get('mode', 'chat')),
            'name1': name1,
            'name2': name2,
            'context': context,
            'greeting': greeting,
            'name1_instruct': name1_instruct,
            'name2_instruct': name2_instruct,
            'context_instruct': context_instruct,
            'turn_template': turn_template,
            'chat-instruct_command': str(get('chat-instruct_command', shared.settings['chat-instruct_command'])),
        })

    return params
|
65 |
+
|
66 |
+
|
67 |
+
def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
    """Kick off the cloudflared tunnel in a background daemon thread."""
    tunnel_thread = Thread(
        target=_start_cloudflared,
        args=[port, max_attempts, on_start],
        daemon=True,
    )
    tunnel_thread.start()
|
70 |
+
|
71 |
+
|
72 |
+
def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
73 |
+
try:
|
74 |
+
from flask_cloudflared import _run_cloudflared
|
75 |
+
except ImportError:
|
76 |
+
print('You should install flask_cloudflared manually')
|
77 |
+
raise Exception(
|
78 |
+
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
|
79 |
+
|
80 |
+
for _ in range(max_attempts):
|
81 |
+
try:
|
82 |
+
public_url = _run_cloudflared(port, port + 1)
|
83 |
+
|
84 |
+
if on_start:
|
85 |
+
on_start(public_url)
|
86 |
+
|
87 |
+
return
|
88 |
+
except Exception:
|
89 |
+
traceback.print_exc()
|
90 |
+
time.sleep(3)
|
91 |
+
|
92 |
+
raise Exception('Could not start cloudflared.')
|