parser / wembedding_service /wembeddings /wembeddings_server.py
anasampa2's picture
Added wembedding_service folder.
be8596b verified
raw
history blame
No virus
4.73 kB
#!/usr/bin/env python3
# coding=utf-8
#
# Copyright 2020 Institute of Formal and Applied Linguistics, Faculty of
# Mathematics and Physics, Charles University, Czech Republic.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Word embeddings server class."""
import http.server
import json
import socketserver
import os
import sys
import threading
import urllib.parse
import numpy as np
class WEmbeddingsServer(socketserver.ThreadingTCPServer):
    """Threaded HTTP server computing word embeddings on request.

    The nested request handler serves two endpoints:

    * ``POST /wembeddings`` -- the body is a JSON object with the keys
      ``model`` and ``sentences``; both values are passed unchanged to
      ``compute_embeddings`` of the wrapped embeddings object. Every array
      it returns is cast to the server dtype and written to the response
      in NumPy ``.npy`` format, one after another.
    * ``GET /status`` -- returns the JSON document ``{"status": "UP"}``.

    Any other URL yields a 404 plain-text error.
    """

    class WEmbeddingsRequestHandler(http.server.BaseHTTPRequestHandler):
        """HTTP handler; reads the shared state from ``request.server``."""

        protocol_version = "HTTP/1.1"

        def respond(request, content_type, code=200):
            """Send the status line and headers of a response.

            The connection is always closed after the response, so no
            Content-Length header is emitted.
            """
            request.close_connection = True
            request.send_response(code)
            request.send_header("Connection", "close")
            request.send_header("Content-Type", content_type)
            request.send_header("Access-Control-Allow-Origin", "*")
            request.end_headers()

        def respond_error(request, message, code=400):
            """Send a complete plain-text error response with ``message``."""
            request.respond("text/plain", code)
            request.wfile.write(message.encode("utf-8"))

        def _parse_url(request):
            """Return the parsed request URL, or None after replying with 400."""
            try:
                # http.server decodes the request line as ISO-8859-1;
                # re-encode it to recover a UTF-8 encoded path.
                request.path = request.path.encode("iso-8859-1").decode("utf-8")
                return urllib.parse.urlparse(request.path)
            except Exception:
                request.respond_error("Cannot parse request URL.")
                return None

        def _log_traceback(request):
            """Print the exception currently being handled to stderr."""
            import traceback
            traceback.print_exc(file=sys.stderr)
            sys.stderr.flush()

        def do_POST(request):
            """Handle ``POST /wembeddings``; reply 404 for any other URL."""
            url = request._parse_url()
            if url is None:
                return

            # URL not found
            if url.path != "/wembeddings":
                return request.respond_error("No handler for the given URL '{}'".format(url.path), code=404)

            # Handle /wembeddings
            if request.headers.get("Transfer-Encoding", "identity").lower() != "identity":
                return request.respond_error("Only 'identity' Transfer-Encoding of payload is supported for now.")

            if "Content-Length" not in request.headers:
                return request.respond_error("The Content-Length of payload is required.")

            try:
                length = int(request.headers["Content-Length"])
                if length < 0:
                    # read(-1) would read until EOF and could block the
                    # handler thread indefinitely on a live connection.
                    raise ValueError("Negative Content-Length: {}".format(length))
                data = json.loads(request.rfile.read(length))
                model, sentences = data["model"], data["sentences"]
            except Exception:
                request._log_traceback()
                return request.respond_error("Malformed request.")

            try:
                # The mutex serializes all calls to compute_embeddings.
                with request.server._wembeddings_mutex:
                    sentences_embeddings = request.server._wembeddings.compute_embeddings(model, sentences)
            except Exception:
                request._log_traceback()
                return request.respond_error("An error occurred during wembeddings computation.")

            # NOTE(review): the registered MIME type is spelled
            # "application/octet-stream"; the value below is kept unchanged
            # because existing clients may depend on it -- confirm.
            request.respond("application/octet_stream")
            dtype = request.server._dtype
            for sentence_embedding in sentences_embeddings:
                np.lib.format.write_array(request.wfile, sentence_embedding.astype(dtype), allow_pickle=False)

        def do_GET(request):
            """Handle ``GET /status``; reply 404 for any other URL."""
            url = request._parse_url()
            if url is None:
                return

            if url.path == "/status":
                request.respond("application/json")
                request.wfile.write(bytes("""{"status": "UP"}""", "utf-8"))
            # URL not found
            else:
                request.respond_error("No handler for the given URL '{}'".format(url.path), code=404)

    # Request handler threads are started as non-daemon threads.
    daemon_threads = False

    def __init__(self, port, dtype, wembeddings_lambda):
        """Create the embeddings object and bind the server socket.

        Args:
            port: TCP port to listen on; the host part of the address is
                the empty string.
            dtype: NumPy dtype every embedding is cast to before it is
                written to a response.
            wembeddings_lambda: zero-argument callable returning the object
                whose ``compute_embeddings(model, sentences)`` is used to
                answer requests. It is called before the socket is bound.
        """
        self._dtype = dtype

        # Create the WEmbeddings object and its mutex
        self._wembeddings = wembeddings_lambda()
        self._wembeddings_mutex = threading.Lock()

        # Initialize the server
        super().__init__(("", port), self.WEmbeddingsRequestHandler)

    def server_bind(self):
        """Set the address/port reuse socket options, then bind."""
        import socket
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not defined on every platform, so probe for it
        # instead of assuming that every non-Windows system provides it.
        if os.name != 'nt' and hasattr(socket, "SO_REUSEPORT"):
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        super().server_bind()

    def service_actions(self):
        """Drop finished handler threads once at least 1024 are tracked."""
        super().service_actions()
        threads = getattr(self, "_threads", None)
        if isinstance(threads, list) and len(threads) >= 1024:
            # Prune in place: the container may be a list subclass whose
            # extra methods the base class relies on, so it must not be
            # replaced by a plain list.
            threads[:] = [thread for thread in threads if thread.is_alive()]