antonovmaxim commited on
Commit
883ac62
·
1 Parent(s): 2eac015
Files changed (6) hide show
  1. README.md +1 -1
  2. blocking_api.py +131 -0
  3. gpt4.py +32 -0
  4. requirements.txt +3 -0
  5. script.py +10 -0
  6. util.py +92 -0
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: green
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.33.1
8
- app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
 
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.33.1
8
+ app_file: script.py
9
  pinned: false
10
  license: mit
11
  ---
blocking_api.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
3
+ from threading import Thread
4
+
5
+ from util import build_parameters, try_start_cloudflared
6
+ from gpt4 import ask_gpt
7
+ # from modules import shared
8
+ # from modules.chat import generate_chat_reply
9
+ # from modules.text_generation import encode, generate_reply, stop_everything_event
10
+
11
+
12
+ class Handler(BaseHTTPRequestHandler):
13
+ def do_GET(self):
14
+ if self.path == '/api/v1/model':
15
+ self.send_response(200)
16
+ self.end_headers()
17
+ response = json.dumps({
18
+ 'result': 'GPT4 mindsdb OpenAI original'
19
+ })
20
+
21
+ self.wfile.write(response.encode('utf-8'))
22
+ else:
23
+ self.send_error(404)
24
+
25
+ def do_POST(self):
26
+ content_length = int(self.headers['Content-Length'])
27
+ body = json.loads(self.rfile.read(content_length).decode('utf-8'))
28
+
29
+ if self.path == '/api/v1/generate':
30
+ self.send_response(200)
31
+ self.send_header('Content-Type', 'application/json')
32
+ self.end_headers()
33
+
34
+ prompt = body['prompt']
35
+ generate_params = build_parameters(body)
36
+ stopping_strings = generate_params.pop('stopping_strings')
37
+ generate_params['stream'] = False
38
+
39
+ # generator = generate_reply(
40
+ # prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
41
+
42
+ answer = ask_gpt(prompt)
43
+
44
+ response = json.dumps({
45
+ 'results': [{
46
+ 'text': answer
47
+ }]
48
+ })
49
+
50
+ self.wfile.write(response.encode('utf-8'))
51
+
52
+ elif self.path == '/api/v1/chat':
53
+ self.send_response(200)
54
+ self.send_header('Content-Type', 'application/json')
55
+ self.end_headers()
56
+
57
+ user_input = body['user_input']
58
+ history = body['history']
59
+ regenerate = body.get('regenerate', False)
60
+ _continue = body.get('_continue', False)
61
+
62
+ generate_params = build_parameters(body, chat=True)
63
+ generate_params['stream'] = False
64
+ generator = 'error'
65
+ # generator = generate_chat_reply(
66
+ # user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
67
+
68
+ answer = history
69
+ for a in generator:
70
+ answer = a
71
+
72
+ response = json.dumps({
73
+ 'results': [{
74
+ 'history': answer
75
+ }]
76
+ })
77
+
78
+ self.wfile.write(response.encode('utf-8'))
79
+
80
+ elif self.path == '/api/v1/stop-stream':
81
+ self.send_response(200)
82
+ self.send_header('Content-Type', 'application/json')
83
+ self.end_headers()
84
+
85
+ # stop_everything_event()
86
+
87
+ response = json.dumps({
88
+ 'results': 'error'
89
+ })
90
+
91
+ self.wfile.write(response.encode('utf-8'))
92
+
93
+ elif self.path == '/api/v1/token-count':
94
+ self.send_response(200)
95
+ self.send_header('Content-Type', 'application/json')
96
+ self.end_headers()
97
+
98
+ # tokens = encode(body['prompt'])[0]
99
+ response = json.dumps({
100
+ 'results': [{
101
+ 'tokens': 'error'
102
+ }]
103
+ })
104
+
105
+ self.wfile.write(response.encode('utf-8'))
106
+ else:
107
+ self.send_error(404)
108
+
109
+
110
+ def _run_server(port: int, share: bool = False):
111
+ address = '0.0.0.0' if 0 else '127.0.0.1'
112
+
113
+ server = ThreadingHTTPServer((address, port), Handler)
114
+
115
+ def on_start(public_url: str):
116
+ print(f'Starting non-streaming server at public url {public_url}/api')
117
+
118
+ if share:
119
+ try:
120
+ try_start_cloudflared(port, max_attempts=3, on_start=on_start)
121
+ except Exception:
122
+ pass
123
+ else:
124
+ print(
125
+ f'Starting API at http://{address}:{port}/api')
126
+
127
+ server.serve_forever()
128
+
129
+
130
+ def start_server(port: int, share: bool = False):
131
+ Thread(target=_run_server, args=[port, share], daemon=True).start()
gpt4.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pymysql
2
+ import os
3
+
4
+ def connect_to_db():
5
+ try:
6
+ connection = pymysql.connect(
7
+ host="cloud.mindsdb.com",
8
+ user=os.environ['LOGIN'],
9
+ password=os.environ['PASSWORD'],
10
+ port=3306,
11
+ db="mindsdb",
12
+ charset="utf8mb4",
13
+ cursorclass=pymysql.cursors.DictCursor)
14
+ return connection
15
+ except Exception as e:
16
+ print(f"Error connecting to database: {e}")
17
+ return None
18
+ print('---Trying to connect---')
19
+ connection = connect_to_db()
20
+ print('---CONNECTED---')
21
+
22
+ def ask_gpt(text):
23
+ # print('ASK GPT:', text)
24
+ text = text.replace("'", '"')
25
+
26
+ query = f"SELECT response FROM mindsdb.gpt4 WHERE text=%s"
27
+
28
+ cursor = connection.cursor()
29
+ cursor.execute(query, (text, ))
30
+ result = cursor.fetchone()
31
+ response = result['response']
32
+ return response
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ flask_cloudflared==0.0.12
2
+ websockets==11.0.2
3
+ pymysql
script.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import blocking_api
2
+ # import extensions.api.streaming_api as streaming_api
3
+ # from modules import shared
4
+
5
+
6
+ def setup():
7
+ blocking_api.start_server(7860, share=True)
8
+ # streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api)
9
+ while True: pass
10
+ setup()
util.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import traceback
3
+ from threading import Thread
4
+ from typing import Callable, Optional
5
+
6
+ # from modules import shared
7
+ # from modules.chat import load_character_memoized
8
+ shared = True
9
+ load_character_memoized = lambda: 0
10
+ def build_parameters(body, chat=False):
11
+
12
+ generate_params = {
13
+ 'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
14
+ 'do_sample': bool(body.get('do_sample', True)),
15
+ 'temperature': float(body.get('temperature', 0.5)),
16
+ 'top_p': float(body.get('top_p', 1)),
17
+ 'typical_p': float(body.get('typical_p', body.get('typical', 1))),
18
+ 'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)),
19
+ 'eta_cutoff': float(body.get('eta_cutoff', 0)),
20
+ 'tfs': float(body.get('tfs', 1)),
21
+ 'top_a': float(body.get('top_a', 0)),
22
+ 'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
23
+ 'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
24
+ 'top_k': int(body.get('top_k', 0)),
25
+ 'min_length': int(body.get('min_length', 0)),
26
+ 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
27
+ 'num_beams': int(body.get('num_beams', 1)),
28
+ 'penalty_alpha': float(body.get('penalty_alpha', 0)),
29
+ 'length_penalty': float(body.get('length_penalty', 1)),
30
+ 'early_stopping': bool(body.get('early_stopping', False)),
31
+ 'mirostat_mode': int(body.get('mirostat_mode', 0)),
32
+ 'mirostat_tau': float(body.get('mirostat_tau', 5)),
33
+ 'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
34
+ 'seed': int(body.get('seed', -1)),
35
+ 'add_bos_token': bool(body.get('add_bos_token', True)),
36
+ 'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
37
+ 'ban_eos_token': bool(body.get('ban_eos_token', False)),
38
+ 'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
39
+ 'custom_stopping_strings': '', # leave this blank
40
+ 'stopping_strings': body.get('stopping_strings', []),
41
+ }
42
+
43
+ if chat:
44
+ character = body.get('character')
45
+ instruction_template = body.get('instruction_template')
46
+ name1, name2, _, greeting, context, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), shared.settings['name2'], instruct=False)
47
+ name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
48
+ generate_params.update({
49
+ 'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
50
+ 'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
51
+ 'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
52
+ 'mode': str(body.get('mode', 'chat')),
53
+ 'name1': name1,
54
+ 'name2': name2,
55
+ 'context': context,
56
+ 'greeting': greeting,
57
+ 'name1_instruct': name1_instruct,
58
+ 'name2_instruct': name2_instruct,
59
+ 'context_instruct': context_instruct,
60
+ 'turn_template': turn_template,
61
+ 'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
62
+ })
63
+
64
+ return generate_params
65
+
66
+
67
+ def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
68
+ Thread(target=_start_cloudflared, args=[
69
+ port, max_attempts, on_start], daemon=True).start()
70
+
71
+
72
+ def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
73
+ try:
74
+ from flask_cloudflared import _run_cloudflared
75
+ except ImportError:
76
+ print('You should install flask_cloudflared manually')
77
+ raise Exception(
78
+ 'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
79
+
80
+ for _ in range(max_attempts):
81
+ try:
82
+ public_url = _run_cloudflared(port, port + 1)
83
+
84
+ if on_start:
85
+ on_start(public_url)
86
+
87
+ return
88
+ except Exception:
89
+ traceback.print_exc()
90
+ time.sleep(3)
91
+
92
+ raise Exception('Could not start cloudflared.')