|
import json |
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer |
|
from threading import Thread |
|
|
|
from extensions.api.util import build_parameters, try_start_cloudflared |
|
from modules import shared |
|
from modules.text_generation import encode, generate_reply |
|
|
|
|
|
class Handler(BaseHTTPRequestHandler): |
|
def do_GET(self): |
|
if self.path == '/api/v1/model': |
|
self.send_response(200) |
|
self.end_headers() |
|
response = json.dumps({ |
|
'result': shared.model_name |
|
}) |
|
|
|
self.wfile.write(response.encode('utf-8')) |
|
else: |
|
self.send_error(404) |
|
|
|
def do_POST(self): |
|
content_length = int(self.headers['Content-Length']) |
|
body = json.loads(self.rfile.read(content_length).decode('utf-8')) |
|
|
|
if self.path == '/api/v1/generate': |
|
self.send_response(200) |
|
self.send_header('Content-Type', 'application/json') |
|
self.end_headers() |
|
|
|
prompt = body['prompt'] |
|
generate_params = build_parameters(body) |
|
stopping_strings = generate_params.pop('stopping_strings') |
|
generate_params['stream'] = False |
|
|
|
generator = generate_reply( |
|
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False) |
|
|
|
answer = '' |
|
for a in generator: |
|
answer = a |
|
|
|
response = json.dumps({ |
|
'results': [{ |
|
'text': answer |
|
}] |
|
}) |
|
self.wfile.write(response.encode('utf-8')) |
|
elif self.path == '/api/v1/token-count': |
|
self.send_response(200) |
|
self.send_header('Content-Type', 'application/json') |
|
self.end_headers() |
|
|
|
tokens = encode(body['prompt'])[0] |
|
response = json.dumps({ |
|
'results': [{ |
|
'tokens': len(tokens) |
|
}] |
|
}) |
|
self.wfile.write(response.encode('utf-8')) |
|
else: |
|
self.send_error(404) |
|
|
|
|
|
def _run_server(port: int, share: bool = False): |
|
address = '0.0.0.0' if shared.args.listen else '127.0.0.1' |
|
|
|
server = ThreadingHTTPServer((address, port), Handler) |
|
|
|
def on_start(public_url: str): |
|
print(f'Starting non-streaming server at public url {public_url}/api') |
|
|
|
if share: |
|
try: |
|
try_start_cloudflared(port, max_attempts=3, on_start=on_start) |
|
except Exception: |
|
pass |
|
else: |
|
print( |
|
f'Starting API at http://{address}:{port}/api') |
|
|
|
server.serve_forever() |
|
|
|
|
|
def start_server(port: int, share: bool = False): |
|
Thread(target=_run_server, args=[port, share], daemon=True).start() |
|
|