gpt4 / blocking_api.py
antonovmaxim's picture
instructions
0287610
raw
history blame contribute delete
No virus
4.33 kB
import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from util import build_parameters, try_start_cloudflared
from gpt4 import ask_gpt
# from modules import shared
# from modules.chat import generate_chat_reply
# from modules.text_generation import encode, generate_reply, stop_everything_event
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == '/api/v1/model':
self.send_response(200)
self.end_headers()
response = json.dumps({
'result': 'GPT4 mindsdb OpenAI original'
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def do_POST(self):
content_length = int(self.headers['Content-Length'])
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
if self.path == '/api/v1/generate':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
prompt = body['prompt']
generate_params = build_parameters(body)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = False
# generator = generate_reply(
# prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
answer = ask_gpt(prompt)
response = json.dumps({
'results': [{
'text': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/chat':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
user_input = body['user_input']
history = body['history']
regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False)
generate_params = build_parameters(body, chat=True)
generate_params['stream'] = False
generator = 'error'
# generator = generate_chat_reply(
# user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
answer = history
for a in generator:
answer = a
response = json.dumps({
'results': [{
'history': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/stop-stream':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
# stop_everything_event()
response = json.dumps({
'results': 'error'
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/token-count':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
# tokens = encode(body['prompt'])[0]
response = json.dumps({
'results': [{
'tokens': 'error'
}]
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def _run_server(port: int, share: bool = False, listen: bool = False):
    """Bind the blocking API server and serve requests forever on this thread.

    Args:
        port: TCP port to bind.
        share: when true, also call ``try_start_cloudflared`` (up to 3
            attempts) with an ``on_start`` callback that records the public
            URL in ``main.md`` and prints it.
        listen: bind to all interfaces (``0.0.0.0``) instead of only
            ``127.0.0.1``. Defaults to False, matching the previous behavior,
            where the condition was hard-coded to a false value.
    """
    address = '0.0.0.0' if listen else '127.0.0.1'

    def on_start(public_url: str):
        # Presumably invoked by try_start_cloudflared with the tunnel URL.
        api_url = f'{public_url}/api'
        try:
            with open('main.md', 'r', encoding='utf-8') as f:
                text = f.read()
            text = text.replace("[located in the logs of this container]", api_url)
            with open('main.md', 'w', encoding='utf-8') as f:
                f.write(text)
        except OSError as exc:
            # A missing/unwritable main.md must not hide the public URL.
            print(f'Could not update main.md with the public URL: {exc}')
        print(f'Starting non-streaming server at public url {api_url}')

    # Bind before starting the tunnel so the port is ready when it connects;
    # the context manager closes the socket if serve_forever exits.
    with ThreadingHTTPServer((address, port), Handler) as server:
        if share:
            try:
                try_start_cloudflared(port, max_attempts=3, on_start=on_start)
            except Exception as exc:
                # Best effort: the API keeps serving locally without a tunnel.
                print(f'Could not start cloudflared tunnel: {exc}')
        else:
            print(f'Starting API at http://{address}:{port}/api')
        server.serve_forever()
def start_server(port: int, share: bool = False):
    """Run ``_run_server(port, share)`` on a daemon background thread.

    Returns immediately. Because the thread is a daemon, it does not keep
    the interpreter alive on its own.
    """
    server_thread = Thread(target=_run_server, args=(port, share), daemon=True)
    server_thread.start()