diff --git a/detector/server.py b/detector/server.py index 34a0c85..8c190e4 100644 --- a/detector/server.py +++ b/detector/server.py @@ -20,6 +20,22 @@ def log(*args): class RequestHandler(SimpleHTTPRequestHandler): + def do_POST(self): + content_length = int(self.headers['Content-Length']) # <--- Gets the size of data + post_data = self.rfile.read(content_length) # <--- Gets the data itself + post_data = json.loads(post_data.decode('utf-8')) + + self.begin_content('application/json;charset=UTF-8') + + all_tokens, used_tokens, real, fake = self.infer(post_data['text']) + + self.wfile.write(json.dumps(dict( + all_tokens=all_tokens, + used_tokens=used_tokens, + real_probability=real, + fake_probability=fake + )).encode()) + def do_GET(self): query = unquote(urlparse(self.path).query) @@ -32,6 +48,16 @@ class RequestHandler(SimpleHTTPRequestHandler): self.begin_content('application/json;charset=UTF-8') + all_tokens, used_tokens, real, fake = self.infer(query) + + self.wfile.write(json.dumps(dict( + all_tokens=all_tokens, + used_tokens=used_tokens, + real_probability=real, + fake_probability=fake + )).encode()) + + def infer(query): tokens = tokenizer.encode(query) all_tokens = len(tokens) tokens = tokens[:tokenizer.max_len - 2] @@ -45,12 +71,7 @@ class RequestHandler(SimpleHTTPRequestHandler): fake, real = probs.detach().cpu().flatten().numpy().tolist() - self.wfile.write(json.dumps(dict( - all_tokens=all_tokens, - used_tokens=used_tokens, - real_probability=real, - fake_probability=fake - )).encode()) + return all_tokens, used_tokens, real, fake def begin_content(self, content_type): self.send_response(200)