|
|
|
|
|
|
|
|
|
@@ -20,6 +20,22 @@ def log(*args): |
|
|
|
class RequestHandler(SimpleHTTPRequestHandler): |
|
|
|
+ def do_POST(self): |
|
+ content_length = int(self.headers['Content-Length']) # <--- Gets the size of data |
|
+ post_data = self.rfile.read(content_length) # <--- Gets the data itself |
|
+ post_data = json.loads(post_data.decode('utf-8')) |
|
+ |
|
+ self.begin_content('application/json;charset=UTF-8') |
|
+ |
|
+ all_tokens, used_tokens, real, fake = self.infer(post_data['text']) |
|
+ |
|
+ self.wfile.write(json.dumps(dict( |
|
+ all_tokens=all_tokens, |
|
+ used_tokens=used_tokens, |
|
+ real_probability=real, |
|
+ fake_probability=fake |
|
+ )).encode()) |
|
+ |
|
def do_GET(self): |
|
query = unquote(urlparse(self.path).query) |
|
|
|
@@ -32,6 +48,16 @@ class RequestHandler(SimpleHTTPRequestHandler): |
|
|
|
self.begin_content('application/json;charset=UTF-8') |
|
|
|
+ all_tokens, used_tokens, real, fake = self.infer(query) |
|
+ |
|
+ self.wfile.write(json.dumps(dict( |
|
+ all_tokens=all_tokens, |
|
+ used_tokens=used_tokens, |
|
+ real_probability=real, |
|
+ fake_probability=fake |
|
+ )).encode()) |
|
+ |
|
+ def infer(query): |
|
tokens = tokenizer.encode(query) |
|
all_tokens = len(tokens) |
|
tokens = tokens[:tokenizer.max_len - 2] |
|
@@ -45,12 +71,7 @@ class RequestHandler(SimpleHTTPRequestHandler): |
|
|
|
fake, real = probs.detach().cpu().flatten().numpy().tolist() |
|
|
|
- self.wfile.write(json.dumps(dict( |
|
- all_tokens=all_tokens, |
|
- used_tokens=used_tokens, |
|
- real_probability=real, |
|
- fake_probability=fake |
|
- )).encode()) |
|
+ return all_tokens, used_tokens, real, fake |
|
|
|
def begin_content(self, content_type): |
|
self.send_response(200) |