XciD HF staff committed on
Commit
30c8aac
·
1 Parent(s): 41ec8fd
detector-base.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c74935bd6568940038e6bfcc9c90bf821d7ae4163ebf2327b73db2f641376376
3
+ size 501001061
detector/index.html CHANGED
@@ -2,6 +2,7 @@
2
  <html>
3
  <head>
4
  <title>GPT-2 Output Detector</title>
 
5
  <style type="text/css">
6
  * {
7
  box-sizing: border-box;
@@ -74,7 +75,9 @@ em {
74
  <p>
75
  This is an online demo of the
76
  <a href="https://github.com/openai/gpt-2-output-dataset/tree/master/detector">GPT-2 output detector</a>
77
- model. Enter some text in the text box; the predicted probabilities will be displayed below.
 
 
78
  <u>The results start to get reliable after around 50 tokens.</u>
79
  </p>
80
  <textarea id="textbox" placeholder="Enter text here"></textarea>
@@ -134,7 +137,7 @@ textbox.oninput = () => {
134
  update_graph(null);
135
  return;
136
  }
137
- req.open('GET', '/?' + textbox.value, true);
138
  req.onreadystatechange = () => {
139
  if (req.readyState !== 4) return;
140
  if (req.status !== 200) throw new Error("HTTP status: " + req.status);
@@ -150,5 +153,15 @@ window.addEventListener('DOMContentLoaded', () => {
150
  textbox.focus();
151
  });
152
  </script>
 
 
 
 
 
 
 
 
 
 
153
  </body>
154
  </html>
 
2
  <html>
3
  <head>
4
  <title>GPT-2 Output Detector</title>
5
+ <meta charset="utf-8">
6
  <style type="text/css">
7
  * {
8
  box-sizing: border-box;
 
75
  <p>
76
  This is an online demo of the
77
  <a href="https://github.com/openai/gpt-2-output-dataset/tree/master/detector">GPT-2 output detector</a>
78
+ model, based on the <a href="https://github.com/huggingface/transformers/commit/1c542df7e554a2014051dd09becf60f157fed524"><code>🤗/Transformers</code></a>
79
+ implementation of <a href="https://arxiv.org/abs/1907.11692">RoBERTa</a>.
80
+ Enter some text in the text box; the predicted probabilities will be displayed below.
81
  <u>The results start to get reliable after around 50 tokens.</u>
82
  </p>
83
  <textarea id="textbox" placeholder="Enter text here"></textarea>
 
137
  update_graph(null);
138
  return;
139
  }
140
+ req.open('GET', window.location.href + '?' + textbox.value, true);
141
  req.onreadystatechange = () => {
142
  if (req.readyState !== 4) return;
143
  if (req.status !== 200) throw new Error("HTTP status: " + req.status);
 
153
  textbox.focus();
154
  });
155
  </script>
156
+ <script>
157
+ if (! ['localhost', 'huggingface.test'].includes(window.location.hostname)) {
158
+ (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
159
+ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
160
+ m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
161
+ })(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
162
+ ga('create', 'UA-83738774-5', 'auto');
163
+ ga('send', 'pageview');
164
+ }
165
+ </script>
166
  </body>
167
  </html>
detector/server.py CHANGED
@@ -20,6 +20,30 @@ def log(*args):
20
 
21
  class RequestHandler(SimpleHTTPRequestHandler):
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def do_GET(self):
24
  query = unquote(urlparse(self.path).query)
25
 
@@ -32,6 +56,16 @@ class RequestHandler(SimpleHTTPRequestHandler):
32
 
33
  self.begin_content('application/json;charset=UTF-8')
34
 
 
 
 
 
 
 
 
 
 
 
35
  tokens = tokenizer.encode(query)
36
  all_tokens = len(tokens)
37
  tokens = tokens[:tokenizer.max_len - 2]
@@ -45,12 +79,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
45
 
46
  fake, real = probs.detach().cpu().flatten().numpy().tolist()
47
 
48
- self.wfile.write(json.dumps(dict(
49
- all_tokens=all_tokens,
50
- used_tokens=used_tokens,
51
- real_probability=real,
52
- fake_probability=fake
53
- )).encode())
54
 
55
  def begin_content(self, content_type):
56
  self.send_response(200)
@@ -118,3 +147,4 @@ def main(checkpoint, port=8080, device='cuda' if torch.cuda.is_available() else
118
 
119
  if __name__ == '__main__':
120
  fire.Fire(main)
 
 
20
 
21
  class RequestHandler(SimpleHTTPRequestHandler):
22
 
23
def do_POST(self):
    """Classify the text sent in a JSON request body and reply with JSON.

    The body is expected to be a JSON object with a ``text`` key. On success
    the reply holds ``all_tokens``, ``used_tokens``, ``real_probability`` and
    ``fake_probability``; on any failure it holds a single ``error`` key.
    The response headers (status 200) are sent before the body is parsed, so
    failures are reported in the JSON payload rather than the status code.
    """
    # Media-type parameters are separated by ';' (a ',' yields an invalid
    # header value); this also matches what do_GET sends.
    self.begin_content('application/json;charset=UTF-8')

    try:
        # A missing header yields None; treat it like an empty body instead
        # of crashing after the headers have already gone out.
        content_length = int(self.headers.get('Content-Length') or 0)

        if content_length <= 0:
            payload = {"error": "empty request body"}
        else:
            post_data = json.loads(self.rfile.read(content_length).decode('utf-8'))

            if not isinstance(post_data, dict) or 'text' not in post_data:
                payload = {"error": "missing key 'text'"}
            else:
                all_tokens, used_tokens, fake, real = self.infer(post_data['text'])
                payload = dict(
                    all_tokens=all_tokens,
                    used_tokens=used_tokens,
                    real_probability=real,
                    fake_probability=fake
                )
    except Exception as e:
        # Request boundary: report any failure to the client as JSON.
        payload = {"error": str(e)}

    # Always answer with a JSON document so clients never see an empty body.
    self.wfile.write(json.dumps(payload).encode('utf-8'))
46
+
47
  def do_GET(self):
48
  query = unquote(urlparse(self.path).query)
49
 
 
56
 
57
  self.begin_content('application/json;charset=UTF-8')
58
 
59
+ all_tokens, used_tokens, fake, real = self.infer(query)
60
+
61
+ self.wfile.write(json.dumps(dict(
62
+ all_tokens=all_tokens,
63
+ used_tokens=used_tokens,
64
+ real_probability=real,
65
+ fake_probability=fake
66
+ )).encode())
67
+
68
+ def infer(self, query):
69
  tokens = tokenizer.encode(query)
70
  all_tokens = len(tokens)
71
  tokens = tokens[:tokenizer.max_len - 2]
 
79
 
80
  fake, real = probs.detach().cpu().flatten().numpy().tolist()
81
 
82
+ return all_tokens, used_tokens, fake, real
 
 
 
 
 
83
 
84
  def begin_content(self, content_type):
85
  self.send_response(200)
 
147
 
148
  if __name__ == '__main__':
149
  fire.Fire(main)
150
+
detector/server_get.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from http.server import HTTPServer, SimpleHTTPRequestHandler
4
+ from multiprocessing import Process
5
+ import subprocess
6
+ from transformers import RobertaForSequenceClassification, RobertaTokenizer
7
+ import json
8
+ import fire
9
+ import torch
10
+ from urllib.parse import urlparse, unquote
11
+
12
+
13
+ model: RobertaForSequenceClassification = None
14
+ tokenizer: RobertaTokenizer = None
15
+ device: str = None
16
+
17
def log(*args):
    """Print ``args`` to stderr, prefixed with the bracketed ``RANK`` value.

    The rank is read from the ``RANK`` environment variable on every call;
    when it is unset the prefix is ``[]``.
    """
    rank = os.environ.get('RANK', '')
    prefix = f"[{rank}]"
    print(prefix, *args, file=sys.stderr)
19
+
20
+
21
class RequestHandler(SimpleHTTPRequestHandler):
    """HTTP handler serving the demo page and GET-based detector queries.

    A request without a query string returns ``index.html``; otherwise the
    URL-decoded query string is classified and the result is sent as JSON.
    The handler reads the module-level ``model``, ``tokenizer`` and ``device``
    globals.
    """

    def do_GET(self):
        """Serve the demo page, or classify the text given as the query string."""
        query = unquote(urlparse(self.path).query)

        if not query:
            self.begin_content('text/html')

            html = os.path.join(os.path.dirname(__file__), 'index.html')
            # Send the file's bytes untouched: avoids decoding with the
            # locale's default encoding (the page contains non-ASCII text)
            # and closes the file handle deterministically.
            with open(html, 'rb') as page:
                self.wfile.write(page.read())
            return

        self.begin_content('application/json;charset=UTF-8')

        tokens = tokenizer.encode(query)
        all_tokens = len(tokens)
        # Leave room for the BOS/EOS ids that are added around the sequence below.
        tokens = tokens[:tokenizer.max_len - 2]
        used_tokens = len(tokens)
        tokens = torch.tensor([tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]).unsqueeze(0)
        mask = torch.ones_like(tokens)

        with torch.no_grad():
            logits = model(tokens.to(device), attention_mask=mask.to(device))[0]
            probs = logits.softmax(dim=-1)

        # The first softmax entry is reported as "fake", the second as "real".
        fake, real = probs.detach().cpu().flatten().numpy().tolist()

        self.wfile.write(json.dumps(dict(
            all_tokens=all_tokens,
            used_tokens=used_tokens,
            real_probability=real,
            fake_probability=fake
        )).encode())

    def begin_content(self, content_type):
        """Send a 200 status line plus Content-Type and permissive CORS headers."""
        self.send_response(200)
        self.send_header('Content-Type', content_type)
        self.send_header('Access-Control-Allow-Origin', '*')
        self.end_headers()

    def log_message(self, format, *args):
        """Route the base handler's log messages through the module's ``log``."""
        log(format % args)
63
+
64
+
65
def serve_forever(server, model, tokenizer, device):
    """Publish the model objects as module globals, then serve until stopped.

    ``model`` is moved to ``device`` first; the moved model, ``tokenizer`` and
    ``device`` are then stored under those names in this module's globals
    before ``server.serve_forever()`` is entered.
    """
    log('Process has started; loading the model ...')
    globals().update(
        model=model.to(device),
        tokenizer=tokenizer,
        device=device,
    )

    listen_port = server.server_address[1]
    log(f'Ready to serve at http://localhost:{listen_port}')
    server.serve_forever()
73
+
74
+
75
def main(checkpoint, port=8080, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load a detector checkpoint and serve it over HTTP.

    Args:
        checkpoint: Local path to the checkpoint, or a ``gs://`` URL that is
            first copied into the current directory with ``gsutil``.
        port: TCP port to listen on (all interfaces).
        device: Torch device string the model is moved to in each serving process.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist locally
            (after the optional download).
    """
    if checkpoint.startswith('gs://'):
        print(f'Downloading {checkpoint}', file=sys.stderr)
        subprocess.check_output(['gsutil', 'cp', checkpoint, '.'])
        checkpoint = os.path.basename(checkpoint)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(checkpoint):
        raise FileNotFoundError(f'Checkpoint file not found: {checkpoint}')

    print(f'Loading checkpoint from {checkpoint}')
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    data = torch.load(checkpoint, map_location='cpu')

    model_name = 'roberta-large' if data['args']['large'] else 'roberta-base'
    model = RobertaForSequenceClassification.from_pretrained(model_name)
    tokenizer = RobertaTokenizer.from_pretrained(model_name)

    model.load_state_dict(data['model_state_dict'])
    model.eval()

    print(f'Starting HTTP server on port {port}', file=sys.stderr)
    server = HTTPServer(('0.0.0.0', port), RequestHandler)

    # avoid calling CUDA API before forking; doing so in a subprocess is fine.
    num_workers = int(subprocess.check_output([sys.executable, '-c', 'import torch; print(torch.cuda.device_count())']))

    if num_workers <= 1:
        serve_forever(server, model, tokenizer, device)
        return

    print(f'Launching {num_workers} worker processes...')

    # Each worker inherits the environment at start() time, so the variables
    # are set per worker and the parent's previous values are put back after.
    worker_env_keys = ('RANK', 'CUDA_VISIBLE_DEVICES')
    saved_env = {key: os.environ.get(key) for key in worker_env_keys}
    subprocesses = []

    try:
        for i in range(num_workers):
            os.environ['RANK'] = f'{i}'
            os.environ['CUDA_VISIBLE_DEVICES'] = f'{i}'
            process = Process(target=serve_forever, args=(server, model, tokenizer, device))
            process.start()
            subprocesses.append(process)
    finally:
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    for process in subprocesses:
        process.join()
117
+
118
+
119
+ if __name__ == '__main__':
120
+ fire.Fire(main)