File size: 4,726 Bytes
be8596b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python3
# coding=utf-8
#
# Copyright 2020 Institute of Formal and Applied Linguistics, Faculty of
# Mathematics and Physics, Charles University, Czech Republic.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""Word embeddings server class."""

import http.server
import json
import socketserver
import os
import sys
import threading
import urllib.parse

import numpy as np

class WEmbeddingsServer(socketserver.ThreadingTCPServer):

    class WEmbeddingsRequestHandler(http.server.BaseHTTPRequestHandler):
        protocol_version = "HTTP/1.1"

        def respond(request, content_type, code=200):
            request.close_connection = True
            request.send_response(code)
            request.send_header("Connection", "close")
            request.send_header("Content-Type", content_type)
            request.send_header("Access-Control-Allow-Origin", "*")
            request.end_headers()

        def respond_error(request, message, code=400):
            request.respond("text/plain", code)
            request.wfile.write(message.encode("utf-8"))

        def do_POST(request):
            try:
                request.path = request.path.encode("iso-8859-1").decode("utf-8")
                url = urllib.parse.urlparse(request.path)
            except:
                return request.respond_error("Cannot parse request URL.")

            # Handle /wembeddings
            if url.path == "/wembeddings":
                if request.headers.get("Transfer-Encoding", "identity").lower() != "identity":
                    return request.respond_error("Only 'identity' Transfer-Encoding of payload is supported for now.")

                if "Content-Length" not in request.headers:
                    return request.respond_error("The Content-Length of payload is required.")

                try:
                    length = int(request.headers["Content-Length"])
                    data = json.loads(request.rfile.read(length))
                    model, sentences = data["model"], data["sentences"]
                except:
                    import traceback
                    traceback.print_exc(file=sys.stderr)
                    sys.stderr.flush()
                    return request.respond_error("Malformed request.")

                try:
                    with request.server._wembeddings_mutex:
                        sentences_embeddings = request.server._wembeddings.compute_embeddings(model, sentences)
                except:
                    import traceback
                    traceback.print_exc(file=sys.stderr)
                    sys.stderr.flush()
                    return request.respond_error("An error occurred during wembeddings computation.")

                request.respond("application/octet_stream")
                for sentence_embedding in sentences_embeddings:
                    np.lib.format.write_array(request.wfile, sentence_embedding.astype(request.server._dtype), allow_pickle=False)

            # URL not found
            else:
                request.respond_error("No handler for the given URL '{}'".format(url.path), code=404)

        def do_GET(request):
            try:
                request.path = request.path.encode("iso-8859-1").decode("utf-8")
                url = urllib.parse.urlparse(request.path)
            except:
                return request.respond_error("Cannot parse request URL.")

            if url.path == "/status":
                request.respond("application/json")
                request.wfile.write(bytes("""{"status": "UP"}""", "utf-8"))
            # URL not found
            else:
                request.respond_error("No handler for the given URL '{}'".format(url.path), code=404)

    daemon_threads = False

    def __init__(self, port, dtype, wembeddings_lambda):
        self._dtype = dtype

        # Create the WEmbeddings object its mutex
        self._wembeddings = wembeddings_lambda()
        self._wembeddings_mutex = threading.Lock()

        # Initialize the server
        super().__init__(("", port), self.WEmbeddingsRequestHandler)

    def server_bind(self):
        import socket
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if os.name != 'nt':
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        super().server_bind()

    def service_actions(self):
        if isinstance(getattr(self, "_threads", None), list):
            if len(self._threads) >= 1024:
                self._threads = [thread for thread in self._threads if thread.is_alive()]