|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Word embeddings server class.""" |
|
|
|
import email.parser |
|
import http.server |
|
import json |
|
import random |
|
import socketserver |
|
import sys |
|
import threading |
|
import time |
|
import urllib.error |
|
import urllib.parse |
|
import urllib.request |
|
|
|
|
|
class FrontendRESTServer(socketserver.TCPServer):
    """TCP server that proxies REST requests to one or more backend servers.

    Each accepted connection is handled in its own (non-daemon) thread; the
    number of live threads is bounded by ``args.max_concurrency`` (see
    ``service_actions``). ``/models`` is answered locally by merging the model
    maps of all backends; every other request is forwarded to a backend.

    Note: the nested request-handler class is also named ``FrontendRESTServer``
    (accessed as ``self.FrontendRESTServer``).
    """

    class Backend:
        """A single backend server and the models it reports."""

        def __init__(self, server):
            """Remember ``server`` and fetch its ``/models`` listing.

            Args:
              server: address inserted into ``http://{server}{url}`` when
                issuing requests (presumably ``host:port``).

            Raises:
              ValueError: if the ``/models`` response does not contain a
                ``models`` dict and a ``default_model`` string.
            """
            self._server = server

            with self.request("/models") as response:
                data = json.loads(response.read())

            # Explicit checks rather than `assert`, which is stripped under `python -O`.
            if not isinstance(data, dict) or not isinstance(data.get("models"), dict):
                raise ValueError("Backend '{}' did not return a 'models' dictionary.".format(server))
            self.models = data["models"]

            if not isinstance(data.get("default_model"), str):
                raise ValueError("Backend '{}' did not return a 'default_model' string.".format(server))
            self.default_model = data["default_model"]

        def request(self, url, data=None, data_content_type=None):
            """Send a request for ``url`` (path + query) to this backend.

            Without ``data`` a GET is issued; otherwise ``data`` is sent as the
            body with ``Content-Type: data_content_type``. Returns the
            ``urlopen`` response; non-2xx statuses raise
            ``urllib.error.HTTPError`` as usual for ``urlopen``.
            """
            headers = {} if data is None else {"Content-Type": data_content_type}
            return urllib.request.urlopen(urllib.request.Request(
                url="http://{}{}".format(self._server, url),
                data=data,
                headers=headers,
            ))

    class FrontendRESTServer(http.server.BaseHTTPRequestHandler):
        """Per-request handler; the handler instance is named ``request`` in its methods."""

        protocol_version = "HTTP/1.1"

        # Used when logging: "\n" is replaced by "\r" and existing "\r" characters are
        # removed (presumably to keep each log record on a single line).
        format_for_log_table = str.maketrans("\n", "\r", "\r")

        def format_for_log(request, data, limit=None):
            """Return ``data`` prepared for the log.

            With ``limit`` <= 0 only the length is logged (``"[NB]"``); with a
            positive ``limit`` longer strings keep their first ``limit // 2`` and
            last ``ceil(limit / 2)`` characters joined by ``" ... "``.
            """
            if limit is not None:
                if limit <= 0:
                    data = "[{}B]".format(len(data))
                elif len(data) > limit:
                    # `-limit // 2` is `(-limit) // 2`, i.e. -ceil(limit / 2).
                    data = data[:limit // 2] + " ... " + data[min(-1, -limit // 2):]
            return data.translate(request.format_for_log_table)

        def respond(request, content_type, code=200, additional_headers=None):
            """Send the status line and headers; the connection is always closed afterwards."""
            request.close_connection = True
            request.send_response(code)
            request.send_header("Connection", "close")
            request.send_header("Content-Type", content_type)
            request.send_header("Access-Control-Allow-Origin", "*")
            for key, value in (additional_headers or {}).items():
                request.send_header(key, value)
            request.end_headers()

        def respond_error(request, message, code=400):
            """Send a plain-text error response containing ``message``."""
            request.respond("text/plain", code)
            request.wfile.write(message.encode("utf-8"))

        def handle_expect_100(request):
            """Reject ``Expect: 100-continue`` requests whose declared body is too large."""
            try:
                request_too_long = int(request.headers["Content-Length"]) > request.server._args.max_request_size
            except (TypeError, ValueError):
                # Missing (None) or non-numeric Content-Length; do_GET reports it later.
                request_too_long = False

            if request_too_long:
                request.respond_error("The payload size is too large.")
                return False
            return super().handle_expect_100()

        def do_GET(request):
            """Handle a GET request (and, via ``do_POST``, a POST request).

            Parameters are collected from the query string and, for POST, from a
            ``multipart/form-data`` or ``application/x-www-form-urlencoded``
            body. ``/models`` is answered locally; any other path is forwarded
            with its original path and body to a backend, retrying other
            backends when a failure happens before any response was sent.
            """
            params, body, body_content_type = {}, None, None

            # http.server decodes the request line as ISO-8859-1; recover the raw
            # bytes and reinterpret them as UTF-8.
            try:
                encoded_path = request.path.encode("iso-8859-1").decode("utf-8")
                url = urllib.parse.urlparse(encoded_path)
                for name, value in urllib.parse.parse_qsl(url.query, encoding="utf-8", keep_blank_values=True, errors="strict"):
                    params[name] = value
            except Exception:
                return request.respond_error("Cannot parse request URL.")

            if request.command == "POST":
                if request.headers.get("Transfer-Encoding", "identity").lower() != "identity":
                    return request.respond_error("Only 'identity' Transfer-Encoding of payload is supported for now.")

                try:
                    content_length = int(request.headers["Content-Length"])
                    # A negative length would make rfile.read() read until EOF,
                    # bypassing max_request_size entirely.
                    if content_length < 0:
                        raise ValueError("negative Content-Length")
                except (TypeError, ValueError):
                    return request.respond_error("The Content-Length of payload is required.")

                if content_length > request.server._args.max_request_size:
                    # Drain the oversized payload in chunks before answering;
                    # stop early if the client closes the stream.
                    while content_length:
                        read = request.rfile.read(min(content_length, 65536))
                        content_length -= len(read) if read else content_length
                    return request.respond_error("The payload size is too large.")

                body = request.rfile.read(content_length)
                body_content_type = request.headers.get("Content-Type", "")

            content_type = request.headers.get("Content-Type", "")
            if content_type.startswith("multipart/form-data"):
                try:
                    # Prepend the Content-Type header (with its boundary) so the email
                    # parser can split the parts.
                    parser = email.parser.BytesFeedParser()
                    parser.feed(b"Content-Type: " + content_type.encode("ascii") + b"\r\n\r\n")
                    parser.feed(body)
                    for part in parser.close().get_payload():
                        name = part.get_param("name", header="Content-Disposition")
                        if name:
                            params[name] = part.get_payload(decode=True).decode("utf-8")
                except Exception:
                    return request.respond_error("Cannot parse the multipart/form-data payload.")

            elif content_type.startswith("application/x-www-form-urlencoded"):
                try:
                    for name, value in urllib.parse.parse_qsl(
                            body.decode("utf-8"), encoding="utf-8", keep_blank_values=True, errors="strict"):
                        params[name] = value
                except Exception:
                    return request.respond_error("Cannot parse the application/x-www-form-urlencoded payload.")

            if request.server._args.log_data:
                print(url.path, " ".join(request.headers.get_all("X-Forwarded-For", [])),
                      *["{}:{}".format(key, request.format_for_log(value)) for key, value in params.items() if key != "data"],
                      "data:" + request.format_for_log(params.get("data", ""), request.server._args.log_data),
                      sep="\t", file=sys.stderr, flush=True)

            if url.path == "/models":
                # Merge the model maps of all backends; the default model comes
                # from the first backend.
                response = {
                    "models": {name: value for backend in request.server.backends for name, value in backend.models.items()},
                    "default_model": request.server.backends[0].default_model,
                }
                request.respond("application/json")
                request.wfile.write(json.dumps(response, indent=1).encode("utf-8"))
                return

            # Candidate backends: when the requested model is a known alias, only the
            # backends serving the aliased model (or all of them if none does).
            backends = request.server.backends.copy()
            model = params.get("model", request.server.backends[0].default_model)
            if model in request.server.aliases:
                resolved_model = request.server.aliases[model]
                backends = [backend for backend in request.server.backends if resolved_model in backend.models] or backends

            # Once the status line has been sent, errors can no longer be reported
            # with a proper status code and retrying on another backend is unsafe.
            started_responding = False
            try:
                if not backends:
                    raise RuntimeError("No backends found!")
                while backends:
                    backend = random.choice(backends) if len(backends) > 1 else backends[0]
                    backends.remove(backend)
                    try:
                        with backend.request(request.path, body, body_content_type) as response:
                            while True:
                                data = response.read(32768)
                                if not started_responding:
                                    started_responding = True
                                    billing_infclen = response.getheader("X-Billing-Input-NFC-Len", None)
                                    headers = {"X-Billing-Input-NFC-Len": billing_infclen} if billing_infclen is not None else {}
                                    request.respond(response.getheader("Content-Type", "application/json"), code=response.code,
                                                    additional_headers=headers)
                                if len(data) == 0:
                                    break
                                request.wfile.write(data)
                    except urllib.error.HTTPError as error:
                        # Relay the backend's error response as-is (no retry).
                        if not started_responding:
                            started_responding = True
                            request.respond(error.headers.get("Content-Type", "text/plain"), code=error.code)
                            request.wfile.write(error.file.read())
                            break
                        raise
                    except Exception:
                        if backends and not started_responding:
                            import traceback
                            traceback.print_exc(file=sys.stderr)
                            print("The above error occurred during request processing on '{}',".format(backend._server),
                                  "but more backends are available, retrying.", file=sys.stderr, flush=True)
                            continue
                        raise
                    break
            except Exception:
                import traceback
                traceback.print_exc(file=sys.stderr)
                sys.stderr.flush()

                if not started_responding:
                    request.respond_error("An internal error occurred during processing.")
                else:
                    request.wfile.write(b'",\n"An internal error occurred during processing, producing incorrect JSON!"')

        def do_POST(request):
            """Handle POST requests exactly like GET requests (the body is read in do_GET)."""
            return request.do_GET()

    def __init__(self, args):
        """Create the server from parsed command-line ``args``.

        Uses ``args.backends``, ``args.aliases``, ``args.port`` here and
        ``args.max_concurrency``, ``args.max_request_size``, ``args.log_data``
        while serving.

        Raises:
          ValueError: if a non-comment line of the aliases file does not have
            3 or 4 whitespace-separated columns, or a backend reports an
            invalid ``/models`` response.
        """
        self._args = args

        self.backends = [self.Backend(backend) for backend in args.backends]

        # Aliases file: blank lines and lines starting with "#" are skipped. Only the
        # first column is used here: a ":"-separated list of names. Every name and
        # every "-"-delimited prefix of it maps to the first name; earlier entries
        # win because of setdefault.
        self.aliases = {}
        if args.aliases is not None:
            with open(args.aliases, "r", encoding="utf-8") as aliases_file:
                for line in aliases_file:
                    line = line.rstrip("\r\n")
                    if not line or line.startswith("#"):
                        continue
                    columns = line.split()
                    if len(columns) not in (3, 4):
                        raise ValueError("Expected 3-4 columns in the aliases file: line '{}'".format(line))
                    names = columns[0].split(":")
                    for name in names:
                        pieces = name.split("-")
                        for dropped in range(len(pieces)):
                            prefix = "-".join(pieces[:len(pieces) - dropped])
                            self.aliases.setdefault(prefix, names[0])

        self._threads = []
        super().__init__(("", self._args.port), self.FrontendRESTServer)

    def server_bind(self):
        """Enable address/port reuse on the listening socket before binding."""
        import socket
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not available on every platform.
        if hasattr(socket, "SO_REUSEPORT"):
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        super().server_bind()

    def server_activate(self):
        """Start listening with a backlog of 256 connections."""
        self.socket.listen(256)

    def process_request_thread(self, request, client_address):
        """Thread body: handle one connection, report errors, always shut it down."""
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Handle each connection in a new non-daemon thread (joined in server_close)."""
        thread = threading.Thread(target=self.process_request_thread, args=(request, client_address), daemon=False)
        self._threads.append(thread)
        thread.start()

    def service_actions(self):
        """Throttle: wait until fewer than ``max_concurrency`` request threads are alive.

        Called by ``serve_forever`` on every loop iteration, so blocking here
        stops new connections from being accepted.
        """
        if len(self._threads) >= self._args.max_concurrency:
            self._threads = [thread for thread in self._threads if thread.is_alive()]

        while len(self._threads) >= self._args.max_concurrency:
            time.sleep(0.1)
            self._threads = [thread for thread in self._threads if thread.is_alive()]

    def server_close(self):
        """Close the listening socket and wait for all request threads to finish."""
        super().server_close()
        for thread in self._threads:
            thread.join()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse arguments, start the server in a background
    # thread, and shut it down gracefully on SIGINT or SIGUSR1.
    import argparse
    import signal

    parser = argparse.ArgumentParser()
    parser.add_argument("port", type=int, help="Port to use")
    parser.add_argument("backends", type=str, nargs="+", help="Backends to use")
    parser.add_argument("--aliases", default=None, type=str, help="Path to model aliases")
    parser.add_argument("--logfile", default=None, type=str, help="Log path")
    parser.add_argument("--log_data", default=None, type=int, help="Log that much bytes of every request data")
    parser.add_argument("--max_concurrency", default=256, type=int, help="Maximum concurrency")
    parser.add_argument("--max_request_size", default=4096*1024, type=int, help="Maximum request size")
    args = parser.parse_args()

    # All diagnostics (including the per-request log) go to stderr, optionally
    # redirected to an append-mode log file.
    if args.logfile is not None:
        sys.stderr = open(args.logfile, "a", encoding="utf-8")

    server = FrontendRESTServer(args)

    # Block SIGINT/SIGUSR1 *before* starting any thread. New threads inherit the
    # creating thread's signal mask, so the server thread and every request thread
    # it spawns also block them, and the signals stay pending until sigwait()
    # below consumes them. Blocking only after the server thread has started lets
    # the kernel deliver the signal to that thread instead, in which case
    # sigwait() may never return.
    signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGUSR1])

    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()

    print("Started Frontend REST server on port {}.".format(args.port), file=sys.stderr)
    print("To stop it gracefully, either send SIGINT (Ctrl+C) or SIGUSR1.", file=sys.stderr, flush=True)

    signal.sigwait([signal.SIGINT, signal.SIGUSR1])
    print("Initiating shutdown of the Frontend REST server.", file=sys.stderr, flush=True)
    # Stop accepting new connections, then wait for in-flight request threads.
    server.shutdown()
    print("Stopped handling new requests, processing all current ones.", file=sys.stderr, flush=True)
    server.server_close()
    print("Finished shutdown of the Frontend REST server.", file=sys.stderr, flush=True)
|
|