XciD HF staff committed on
Commit
8969f81
1 Parent(s): f33fad2

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +1 -0
  2. Dockerfile +17 -0
  3. README.md +11 -1
  4. backend/API.py +165 -0
  5. backend/GPUHandler.py +159 -0
  6. backend/Model.py +299 -0
  7. backend/PPLM.py +723 -0
  8. backend/README.md +69 -0
  9. backend/Utils.py +57 -0
  10. backend/id/.gitignore +2 -0
  11. backend/id/id.ts +17 -0
  12. backend/id/package.json +12 -0
  13. backend/id/tsconfig.json +10 -0
  14. backend/id/yarn.lock +426 -0
  15. backend/install.sh +6 -0
  16. backend/launch.sh +50 -0
  17. backend/machine_configurations/neuralgenv2.json +13 -0
  18. backend/machine_configurations/transformer-autocomplete.json +11 -0
  19. backend/requirements.txt +4 -0
  20. backend/run_pplm_discrim_train.py +582 -0
  21. entrypoint.sh +7 -0
  22. front/.vscode/settings.json +7 -0
  23. front/assets/Icon-info.svg +9 -0
  24. front/assets/Salesforce_logo.svg +83 -0
  25. front/assets/Uber_logo.svg +11 -0
  26. front/assets/cross-collab.svg +14 -0
  27. front/assets/github-buttons.js +9 -0
  28. front/assets/huggingface_logo.svg +47 -0
  29. front/assets/icon-back.svg +15 -0
  30. front/assets/icon-publish.svg +16 -0
  31. front/assets/iconmonstr-download-14.svg +13 -0
  32. front/assets/iconmonstr-media-control-55.svg +13 -0
  33. front/assets/iconmonstr-share-11-purple.svg +10 -0
  34. front/assets/iconmonstr-share-11.svg +13 -0
  35. front/assets/oval.svg +17 -0
  36. front/assets/tail-spin.svg +32 -0
  37. front/assets/thumbnail-large-distilgpt2.png +0 -0
  38. front/assets/thumbnail-large-pplm.png +0 -0
  39. front/assets/thumbnail-large.png +0 -0
  40. front/assets/unicorn-tweaked.svg +1 -0
  41. front/favicon.ico +0 -0
  42. front/js-src/Api.ts +153 -0
  43. front/js-src/Mention.ts +441 -0
  44. front/js-src/controller.ts +319 -0
  45. front/js-src/lib/Log.ts +1 -0
  46. front/js-src/lib/Utils.ts +76 -0
  47. front/js-src/modals.ts +134 -0
  48. front/js-src/quill.d.ts +181 -0
  49. front/js-src/vanilla-tilt.ts +371 -0
  50. front/less/mixins/bfc.less +3 -0
.dockerignore ADDED
@@ -0,0 +1 @@
 
 
1
+ Dockerfile
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM node:14

# nginx serves the built front-end; gettext-base provides envsubst
# (presumably used by entrypoint.sh to template the nginx config — confirm).
# The chowns let a non-root uid-1000 user run nginx (HF Spaces convention).
RUN apt-get update && \
    apt-get install -y nginx gettext-base && \
    rm -rf /var/lib/apt/lists/* && \
    chown -R 1000:1000 /etc/nginx && \
    chown -R 1000:1000 /var/log/nginx && \
    chown -R 1000:1000 /var/lib/nginx

WORKDIR /app/transformer-autocomplete
ADD . .

# Build the three JS sub-projects: front-end bundle, grunt assets, server.
RUN cd front && npm install && npx tsc && npm run build:prod
RUN cd grunt && npm install && npx grunt
RUN cd server && npm install && npx tsc

ENTRYPOINT ["./entrypoint.sh"]
README.md CHANGED
@@ -4,7 +4,17 @@ emoji: 🏆
4
  colorFrom: green
5
  colorTo: gray
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
4
  colorFrom: green
5
  colorTo: gray
6
  sdk: docker
7
+ app_port: 8080
8
  pinned: false
9
  ---
10
 
11
+ # transformer-autocomplete
12
+
13
+ Autocompletion based on GPT-2
14
+
15
+ #### How to compile the front (to test the front with any server)
16
+
17
+ 1. Update the API endpoint in `front/js-src/Api.ts`
18
+ 2. compile the TS to pure JS with `cd front; tsc` or through vscode (you can launch it in watch mode if needed)
19
+ 3. pack the js into a single file (we use rollup) with `npm run watch`
20
+
backend/API.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+ import falcon
3
+ from falcon.http_status import HTTPStatus
4
+ import json
5
+ import requests
6
+ import time
7
+ from Model import generate_completion
8
+ import sys
9
+
10
+
11
class AutoComplete(object):
    """Falcon resource serving POST /autocomplete.

    Reads the generation parameters from the JSON request body, applying a
    default for every missing field, and delegates to
    Model.generate_completion. The response body is a JSON object with the
    generated ``sentences`` and the elapsed ``time`` in seconds.
    """

    def on_post(self, req, resp, single_endpoint=True, x=None, y=None):
        # `single_endpoint`, `x` and `y` are unused; they exist only so the
        # /autocomplete/{x} and /autocomplete/{x}/{y} routes can share this
        # handler.
        json_data = json.loads(req.bounded_stream.read())

        resp.status = falcon.HTTP_200

        start = time.time()

        # "context" is the only mandatory field.
        try:
            context = json_data["context"].rstrip()
        except KeyError:
            resp.body = "The context field is required"
            resp.status = falcon.HTTP_422
            return

        # NOTE: the original code also read a "samples" field into an unused
        # local variable; it was never forwarded, so it is dropped here.
        # All remaining optional fields fall back to server-side defaults via
        # dict.get instead of one try/except KeyError per field.

        print(json_data)

        sentences = generate_completion(
            context,
            length=json_data.get('gen_length', 20),
            max_time=json_data.get('max_time', -1),
            model_name=json_data.get('model_size', "small"),
            temperature=json_data.get('temperature', 0.7),
            max_tokens=json_data.get('max_tokens', 256),
            top_p=json_data.get('top_p', 0.95),
            top_k=json_data.get('top_k', 40),

            # CTRL
            repetition_penalty=json_data.get('repetition_penalty', 0.02),

            # PPLM
            stepsize=json_data.get('step_size', 0.02),
            bag_of_words_or_discrim=json_data.get('bow_or_discrim', "kitchen"),
            gm_scale=json_data.get('gm_scale'),
            kl_scale=json_data.get('kl_scale'),
            num_iterations=json_data.get('num_iterations'),
            use_sampling=json_data.get('use_sampling')
        )

        resp.body = json.dumps({"sentences": sentences, 'time': time.time() - start})

        resp.status = falcon.HTTP_200
        sys.stdout.flush()
132
+
133
+
134
class Request(Thread):
    """Thread that POSTs `data` to `end_point` and exposes the response text.

    `join()` blocks until the request finishes and returns the body text of
    the response (raises AttributeError if `run` never completed).
    """

    def __init__(self, end_point, data):
        Thread.__init__(self)
        self.end_point = end_point   # target URL
        self.data = data             # JSON-serializable payload
        self.ret = None              # requests.Response, set by run()

    def run(self):
        print("Requesting with url", self.end_point)
        self.ret = requests.post(url=self.end_point, json=self.data)

    def join(self, timeout=None):
        # Forward the optional timeout instead of silently discarding it —
        # the original override did not accept Thread.join's parameter.
        Thread.join(self, timeout)
        return self.ret.text
148
+
149
+
150
class HandleCORS(object):
    """Falcon middleware: wildcard CORS headers plus preflight short-circuit."""

    def process_request(self, req, resp):
        # Open everything up — origin, methods and headers are all wildcarded.
        for header_name in ('Access-Control-Allow-Origin',
                            'Access-Control-Allow-Methods',
                            'Access-Control-Allow-Headers'):
            resp.set_header(header_name, '*')
        # Answer CORS preflight requests immediately with an empty 200.
        if req.method == 'OPTIONS':
            raise HTTPStatus(falcon.HTTP_200, body='\n')
157
+
158
+
159
# Wire the single resource to Falcon. The same handler instance serves the
# bare endpoint and the one/two path-parameter variants (the parameters are
# accepted but unused by AutoComplete.on_post).
autocomplete = AutoComplete()
app = falcon.API(middleware=[HandleCORS()])
app.add_route('/autocomplete', autocomplete)
app.add_route('/autocomplete/{x}', autocomplete)
app.add_route('/autocomplete/{x}/{y}', autocomplete)

# WSGI entry point name expected by gunicorn/uwsgi.
application = app
backend/GPUHandler.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config,
4
+ OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
5
+ XLNetLMHeadModel, XLNetTokenizer,
6
+ TransfoXLLMHeadModel, TransfoXLTokenizer,
7
+ CTRLLMHeadModel, CTRLTokenizer)
8
+
9
# Registry of every servable model: tokenizer/model classes, an estimated GPU
# memory footprint ("size" — presumably MB, matching the units used by
# GPU.total_memory; confirm), the pretrained checkpoint and the public
# identifier used in API requests.
#
# BUG FIX: the original dict declared the "pplm" key TWICE; Python keeps only
# the last duplicate, so the first entry (gpt2-large, size 3000) was dead
# code and has been removed.
model_metadata = {
    "gpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "gpt2",
        "identifier": "gpt2/small"
    }, "gpt": {
        "tokenizer": OpenAIGPTTokenizer,
        "model": OpenAIGPTLMHeadModel,
        "size": 550,
        "checkpoint": "openai-gpt",
        "identifier": "gpt"
    }, "xlnet": {
        "tokenizer": XLNetTokenizer,
        "model": XLNetLMHeadModel,
        "size": 550,
        "checkpoint": "xlnet-base-cased",
        "identifier": "xlnet"
    }, "gpt2/arxiv-nlp": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "arxiv-nlp-v1",
        "identifier": "gpt2/arxiv-nlp"
    }, "gpt2/medium": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 1500,
        "checkpoint": "gpt2-medium",
        "identifier": "gpt2/medium"
    }, "gpt2/large": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3300,
        "checkpoint": "gpt2-large",
        "identifier": "gpt2/large"
    }, "distilgpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 350,
        "checkpoint": "distilgpt2",
        "identifier": "distilgpt2/small"
    }, "ctrl": {
        "tokenizer": CTRLTokenizer,
        "model": CTRLLMHeadModel,
        "size": 6300,
        "checkpoint": "ctrl",
        "identifier": "ctrl"
    }, "gpt2/xl": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 7000,
        "checkpoint": "gpt2-xl",
        "identifier": "gpt2/xl"
    }, "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 4000,
        "checkpoint": "gpt2-medium",
        "identifier": "pplm",
        # PPLM needs the hidden states of every layer to perturb the past.
        "configuration_options": {
            "config": GPT2Config,
            "options": {
                "output_hidden_states": True
            }
        }
    }
}

# Fixed per-GPU memory reserved for CUDA context / allocator overhead
# (same units as the "size" entries above).
memory_overhead = 500
86
+
87
class GPU:
    """Bookkeeping for one CUDA device: which models live on it and how much
    of its memory budget they consume (units follow model_metadata "size")."""

    def __init__(self, id):
        self.id = id
        self.models = []
        device_name = "cuda:{}".format(id)
        # Physical capacity converted to the registry's units, minus a
        # 1000-unit safety margin.
        self.total_memory = torch.cuda.get_device_properties(device_name).total_memory / 1_000_000 - 1_000
        print("INIT GPU WITH DEVICE", device_name)

    def register_model(self, model, cached_path=None):
        """Claim this GPU for `model` if it fits; return True on success."""
        if self.total_memory_used() + model["size"] >= self.total_memory:
            return False
        model["device"] = "cuda:{}".format(self.id)
        if cached_path:
            model["cached_path"] = cached_path
        self.models.append(model)
        return True

    def total_memory_used(self):
        """Memory claimed so far, including the fixed per-GPU overhead."""
        return memory_overhead + sum(model["size"] for model in self.models)

    def __repr__(self):
        loaded = [(model["checkpoint"], model["size"]) for model in self.models]
        usage = str(round(100 * (self.total_memory_used() / self.total_memory))) + "%"
        return str(loaded + [usage] + ["cuda:{}".format(self.id)])
117
+
118
+
119
class GPUHandler:
    """Distributes the requested models across this worker's GPUs with a
    first-fit strategy, failing fast when they cannot all fit."""

    def __init__(self, ids, model_list, gpu_ids, cached_models=None):
        # `ids` is unused but kept for call-site compatibility.
        if cached_models is None:
            cached_models = {}

        self.gpus = [GPU(id) for id in gpu_ids]
        print("GPU handler initiated with {} gpus.".format(len(self.gpus)))

        # Dry-run the placement on throwaway GPU objects first.
        self.sanity_check([model_metadata[model] for model in model_list])

        for model in model_list:
            self.register_model(model_metadata[model], cached_models.get(model))

    def register_model(self, model, cached_path=None):
        """First-fit placement of `model` on a GPU.

        Raises ValueError when no GPU has room. BUG FIX: the original
        compared the loop index to len(self.gpus) after the loop — a
        condition that can never be true (the index maxes out at len-1) —
        so placement failures were silently ignored. for/else now raises
        reliably.
        """
        for gpu in self.gpus:
            if gpu.register_model(model, cached_path):
                print("Registered model", model, "in GPU", gpu)
                break
        else:
            raise ValueError("Could not load model", model["checkpoint"])

    def sanity_check(self, model_list):
        """Simulate the placement on fresh GPU objects; raise if any model
        cannot be placed anywhere."""
        temp_gpus = [GPU(id) for id in range(len(self.gpus))]

        for model in model_list:
            current_gpu_index = 0
            while current_gpu_index < len(temp_gpus):
                if not temp_gpus[current_gpu_index].register_model(model):
                    current_gpu_index += 1
                else:
                    break

            # Here the exhausted-while comparison is correct: the index
            # equals len(temp_gpus) only when no GPU accepted the model.
            if current_gpu_index >= len(temp_gpus):
                raise RuntimeError("SANITY CHECK FAILED")

        print("Current layout", temp_gpus)

    def __repr__(self):
        return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"
backend/Model.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from transformers import (GPT2LMHeadModel, GPT2Tokenizer,
3
+ OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
4
+ XLNetLMHeadModel, XLNetTokenizer,
5
+ TransfoXLLMHeadModel, TransfoXLTokenizer,
6
+ CTRLLMHeadModel, CTRLTokenizer)
7
+
8
+ from Utils import forward, create_context
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from math import floor
12
+ import requests
13
+ import json
14
+ import os
15
+ from PPLM import run_model as run_pplm, DISCRIMINATOR_MODELS_PARAMS
16
+ from GPUHandler import GPUHandler
17
+
18
# Long neutral passage prepended to short contexts for models that need a
# primed context (see Utils.create_context, which receives it).
PADDING_TEXT = """With eyes for the most part downcast and, if ever they lighted on a fellow creature, at once and
furtively averted, Bernard hastened across the roof. He was like a man pursued, but pursued by enemies he does not
wish to see, lest they should seem more hostile even than he had supposed, and he himself be made to feel guiltier
and even more helplessly alone. That horrible Benito Hoover!’ And yet the man had meant well enough. Which only made
it, in a way, much worse. Those who meant well behaved in the same way as those who meant badly. Even Lenina was making
him suffer. He remembered those weeks of timid indecision, during which he had looked and longed and despaired of ever
having the courage to ask her. Dared he face the risk of being humiliated by a contemptuous refusal? But if she were to
say yes, what rapture! Well, now she had said it and he was still wretched—wretched that she should have thought it
such a perfect afternoon for Obstacle Golf, that she should have trotted away to join Henry Foster, that she should
have found him funny for not wanting to talk of their most private affairs in public. Wretched, in a word, because she
had behaved as any healthy and virtuous English girl ought to behave and not in some other, abnormal, extraordinary
way. <eod> </s> <eos>"""

try:
    # Each worker asks the local PID server (backend/id) for its index, then
    # derives the GPU ids it owns from the GPU_PER_WORKER env variable.
    PID = int(requests.get(url="http://localhost:3000").json())
    N_GPU = torch.cuda.device_count()
    GPU_PER_WORKER = int(os.getenv("GPU_PER_WORKER"))
    GPU_IDS = list(range(PID * GPU_PER_WORKER, (PID + 1) * GPU_PER_WORKER))
    print("Successfully init thread with id {}. The GPU ids attributed are: {}".format(PID, GPU_IDS))

    with open(os.getenv("FILE")) as json_file:
        data = json.load(json_file)
        models = data["models_to_load"]
        cached_models = data.get("cached_models")
# BUG FIX: the original wrote `except requests.exceptions.ConnectionError or
# TypeError:`, which evaluates to ConnectionError alone — the TypeError raised
# by int(os.getenv(...)) when the env var is unset was never caught. An
# exception tuple catches both, as intended.
except (requests.exceptions.ConnectionError, TypeError):
    if __name__ == "__main__":
        # Development fallback: single worker, single GPU, PPLM only.
        PID = 0
        N_GPU = torch.cuda.device_count()
        GPU_PER_WORKER = 1
        GPU_IDS = [0]
        print("Successfully init development thread with id {}. The GPU ids attributed are: {}".format(PID, GPU_IDS))
        models = ["pplm"]
        cached_models = None
    else:
        raise requests.exceptions.ConnectionError("The PID server is not running.")


# First argument (`ids`) is unused by GPUHandler; int() == 0 is a placeholder.
handler = GPUHandler(int(), models, GPU_IDS, cached_models)
models = {}

# Instantiate every model/tokenizer on the GPU the handler assigned it to.
for gpu in handler.gpus:
    for model in gpu.models:
        model_name = model["identifier"]
        print(f"Loading {model_name} model and tokenizer")
        models[model_name] = model

        if model.get("cached_path"):
            print("Loading {} from local path.".format(model_name))
            model_checkpoint_path = model["cached_path"]
        else:
            model_checkpoint_path = model["checkpoint"]

        if "configuration_options" in models[model_name]:
            # e.g. PPLM needs output_hidden_states=True on its GPT2Config.
            configuration_options = models[model_name]["configuration_options"]
            print("Specific configuration options", configuration_options["options"])

            config = configuration_options["config"].from_pretrained(model_checkpoint_path)

            for option_key, option_value in configuration_options["options"].items():
                setattr(config, option_key, option_value)

            models[model_name]["model"] = models[model_name]["model"].from_pretrained(model_checkpoint_path, config=config).to(models[model_name]["device"])
        else:
            models[model_name]["model"] = models[model_name]["model"].from_pretrained(model_checkpoint_path).to(models[model_name]["device"])

        # Tokenizers always load from the canonical checkpoint name, even
        # when the model weights came from a local cached path.
        models[model_name]["tokenizer"] = models[model_name]["tokenizer"].from_pretrained(models[model_name]["checkpoint"])
        models[model_name]["model"].eval()

print("All models successfully loaded.")
88
+
89
+
90
def top_k_top_p_filtering(batch_logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Apply top-k and/or nucleus (top-p) filtering to a batch of logits.

    Each row is filtered independently; masked entries are set to
    `filter_value`. Rows are modified in place (views of the input), and a
    freshly assembled batch tensor is returned — same contract as before.

    :param batch_logits: 2D tensor of logits, one row per sample
    :param top_k: keep only the k highest-probability tokens when > 0
    :param top_p: keep the smallest set of tokens whose cumulative
        probability reaches top_p when > 0.0
    :param filter_value: value written into masked positions
    :return: the filtered batch of logits
    """
    filtered_rows = []
    for row_index in range(batch_logits.size(0)):
        logits = batch_logits[row_index]
        assert logits.dim() == 1  # one sample at a time; batch handled by the outer loop
        effective_k = min(top_k, logits.size(-1))  # never ask topk for more than vocab size
        if effective_k and effective_k > 0:
            # Everything strictly below the k-th largest logit is masked.
            kth_value = torch.topk(logits, effective_k)[0][..., -1, None]
            logits[logits < kth_value] = filter_value

        if top_p and top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Tokens past the nucleus get masked...
            remove_mask = cumulative_probs > top_p
            # ...but shift right one slot so the first token crossing the
            # threshold itself survives.
            remove_mask[..., 1:] = remove_mask[..., :-1].clone()
            remove_mask[..., 0] = 0

            logits[sorted_indices[remove_mask]] = filter_value

        filtered_rows.append(logits)

    # Reassemble the batch (replaces the original's `locals()` accumulator).
    return torch.stack(filtered_rows, dim=0)
129
+
130
+
131
def check_tensor_for_eot(output, eot_token, dot_token):
    """Return True when every row of `output` contains the EOT or the dot token."""
    rows = output.tolist()
    return all(eot_token in row or dot_token in row for row in rows)
133
+
134
+
135
def truncate_after_eot(output, eot_tokens):
    """Cut each row of `output` right after its first end-of-text token.

    Rows with no EOT token are returned unchanged; rows with one are
    truncated at the earliest occurrence and terminated with eot_tokens[0].
    Returns a list of token-id lists.
    """
    truncated = []
    for row in output:
        items = row.tolist()
        if any(token in items for token in eot_tokens):
            cut_index = find_min_value_in_array(items, eot_tokens)
            truncated.append(items[:cut_index] + [eot_tokens[0]])
        else:
            truncated.append(items)
    return truncated
145
+
146
+
147
def find_min_value_in_array(array, values):
    """Return the smallest index in `array` at which any of `values` occurs.

    Values absent from `array` are ignored. Raises ValueError when none of
    the values is present (min() of an empty sequence), as before.
    """
    occurrences = [array.index(value) for value in values if value in array]
    return min(occurrences)
156
+
157
+
158
# @lru_cache()
def generate_completion(
    raw_text,
    length=-1,
    max_time=-1,
    model_name="small",
    temperature=1,
    max_tokens=256,
    top_p=0.0,
    top_k=0,
    batch_size=3,
    repetition_penalty=1.2,

    # PPLM
    bag_of_words_or_discrim=None,
    stepsize=0.02,
    gamma=1.5,
    num_iterations=3,
    window_length=5,
    kl_scale=0.01,
    gm_scale=0.95,
    use_sampling=False
):
    """Generate `batch_size` completions of `raw_text` with the requested model.

    Looks the model up in the module-level `models` registry (falling back to
    gpt2/small on an unknown name). PPLM model names are delegated entirely
    to run_pplm; every other model runs the sampling loop below.

    Returns a list of {"value", "time", "tokens"} dicts, or an error string
    on OOM. `length=-1` means 100 tokens; `max_time=-1` means no time budget.
    """
    start = time.time()

    try:
        print("Running with model", model_name)
        model, tokenizer, device = models[model_name]["model"], models[model_name]["tokenizer"], models[model_name]["device"]
    except KeyError:
        # Unknown model name: degrade gracefully instead of failing the request.
        print("Error. Defaulting to small model.")
        model, tokenizer, device = models["gpt2/small"]["model"], models["gpt2/small"]["tokenizer"], models["gpt2/small"]["device"]

    if "pplm" in model_name:
        # "discrim:label_index" selects a discriminator; anything else is a
        # bag-of-words topic name.
        if ":" in bag_of_words_or_discrim:
            discrim, discrim_label = bag_of_words_or_discrim.split(":")
            discrim_label = DISCRIMINATOR_MODELS_PARAMS[discrim]["class_id"][int(discrim_label)]
            bag_of_words = None

            # Hardcoded parameters for the discriminator
            gamma = 1.0

            print("Running PPLM with discriminator:", discrim, discrim_label)
        else:
            bag_of_words = bag_of_words_or_discrim
            discrim = None
            discrim_label = None

            # Hardcoded parameters for the BOW
            gamma = 1.5
            window_length = 5

            print("Running PPLM with bag of words:", bag_of_words)

        print("kl", kl_scale, "gm", gm_scale, "sampling", use_sampling, "window length", window_length, "gamma", gamma, "temperature", temperature)

        return run_pplm(
            model, tokenizer, device, raw_text,
            max_time=max_time,
            discrim=discrim,
            discrim_label=discrim_label,
            num_samples=batch_size,
            bag_of_words=bag_of_words,
            length=length,
            temperature=temperature,
            top_k=top_k,
            stepsize=stepsize,
            gamma=gamma,
            num_iterations=num_iterations,
            window_length=window_length,
            kl_scale=kl_scale,
            gm_scale=gm_scale,
            use_sampling=use_sampling
        )

    # Tokenize the context (padded for models that need it) and fetch the
    # stop tokens for this tokenizer.
    context_tokens, eot_token, dot_token = create_context(model_name, tokenizer, raw_text, PADDING_TEXT, max_tokens=max_tokens)

    if length == -1:
        length = 100

    # One identical context row per sample in the batch.
    context = torch.tensor(context_tokens, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    prev = context
    past = None

    with torch.no_grad():
        for _ in range(length):
            try:
                output = forward(model_name, model, prev, past, device=device)
            except RuntimeError:
                return "ERROR 500: OOM. TransfoXL asked for too much memory."

            # NOTE(review): `past` is always reset to None here, so the
            # model re-encodes the full sequence each step — confirm this
            # is intentional (no KV-cache reuse).
            logits, past = output if len(output) > 2 else output[0], None

            # Clamp temperature away from zero to avoid division blow-up.
            logits = logits[:, -1, :] / max(temperature, 0.001)

            if "ctrl" in model_name:
                # CTRL-style repetition penalty on already-generated tokens.
                for i in range(batch_size):
                    for j in set(prev[i].tolist()):
                        logits[i, j] /= repetition_penalty

            logits = top_k_top_p_filtering(logits, top_p=top_p, top_k=top_k)
            log_probs = F.softmax(logits, dim=-1)
            token = torch.multinomial(log_probs, num_samples=1)

            prev = torch.cat((prev, token), dim=1)

            # Check that there is no eot token in all of the sentence, else breaks.
            if check_tensor_for_eot(prev[:, len(context_tokens):], eot_token, dot_token) or (max_time != -1 and time.time() - start + 0.1 > max_time):
                break

    out = prev[:, len(context_tokens):]
    # Remove the words following the eot tokens.
    out = truncate_after_eot(out, list(filter(lambda t: t is not None, [dot_token, eot_token])))
    end = time.time()

    # Remove empty sentences and duplicates
    generations = list(set(filter(lambda x: len(x) > 0, [" " + tokenizer.decode(single_generation).strip() for single_generation in out])))

    # NOTE(review): after dedup, generations[i] no longer aligns with out[i],
    # so "tokens" may describe a different sample than "value" — confirm.
    sentences = [
        {"value": generations[i], "time": end - start, "tokens": len(out[i])} for i in range(len(generations))
    ]

    # print(end - start, [len(out[i]) for i in range(len(generations))])

    return sentences
284
+
285
+
286
if __name__ == "__main__":
    # Manual smoke test: PPLM steered by the "sentiment" discriminator,
    # class index 2 ("very_positive" in PPLM.DISCRIMINATOR_MODELS_PARAMS).
    print(generate_completion(
        "My dog died",
        length=30, model_name="pplm", batch_size=3, top_k=10, top_p=0.9,
        bag_of_words_or_discrim="sentiment:2",
        stepsize=0.03,
        gamma=1,
        num_iterations=3,
        window_length=5,
        kl_scale=0.01,
        gm_scale=0.95,
        max_time=-1,
        use_sampling=False
    ))
backend/PPLM.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # coding=utf-8
3
+ # Copyright 2018 The Uber AI Team Authors.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ Example command with bag of words:
19
+ python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
20
+
21
+ Example command with discriminator:
22
+ python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
23
+ """
24
+
25
+ import json
26
+ from operator import add
27
+ from typing import List, Optional, Tuple, Union
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ from torch.autograd import Variable
33
+ from tqdm import trange
34
+ from transformers.file_utils import cached_path
35
+ import time
36
+
37
+ from run_pplm_discrim_train import ClassificationHead
38
+
39
# Perturbation modes: bag-of-words, discriminator head, or both combined.
PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
# Numerical guards: tiny epsilon for probabilities, huge magnitude for masking logits.
SMALL_CONST = 1e-15
BIG_CONST = 1e10

# Downloadable word lists for bag-of-words steering, keyed by topic name.
BAG_OF_WORDS_ARCHIVE_MAP = {
    'kitchen': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/kitchen.txt",
    'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
    'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
    'monsters': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/monsters.txt",
    'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
    'positive_words': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/positive_words.txt",
    'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
    'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
    'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
    'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}

# Pretrained classifier heads used as discriminators. "class_vocab" maps
# label name -> index, "class_id" is the inverse, and "pretrained_model" is
# the LM whose hidden states the head was trained on.
DISCRIMINATOR_MODELS_PARAMS = {
    "clickbait": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifierhead.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_clickbait": 0, "clickbait": 1},
        "class_id": {0: "non_clickbait", 1: "clickbait"},
        "default_class": 1,
        "pretrained_model": "gpt2-medium",
    },
    "sentiment": {
        "url": "http://s.yosinski.com/SST_classifier_head.pt",
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
        "class_id": {2: "very_positive", 3: "very_negative"},
        "default_class": 3,
        "pretrained_model": "gpt2-medium",
    },
    "toxicity": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/toxicity_classifierhead.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_toxic": 0, "toxic": 1},
        "class_id": {0: "non_toxic", 1: "toxic"},
        "default_class": 0,
        "pretrained_model": "gpt2-medium",
    },
}
87
+
88
+
89
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
    """Move tensor `x` to `device` and wrap it in an autograd Variable.

    NOTE(review): torch.autograd.Variable and its `volatile` flag are
    deprecated/removed in modern PyTorch — confirm against the torch version
    pinned in backend/requirements.txt.
    """
    if torch.cuda.is_available() and device == 'cuda':
        x = x.cuda()
    elif device != 'cuda':
        x = x.to(device)
    return Variable(x, requires_grad=requires_grad, volatile=volatile)
95
+
96
+
97
def top_k_filter(logits, k, probs=False):
    """Mask everything but the k largest entries of each row.

    With probs=True the masked entries become 0 (inputs are probabilities);
    otherwise they are pushed to -BIG_CONST so exp() underflows to zero and
    they cannot contribute to a softmax denominator. k == 0 disables
    filtering and returns `logits` unchanged.
    """
    if k == 0:
        return logits

    # Per-row k-th largest value, broadcast back to the full row shape.
    kth_values = torch.topk(logits, k)[0][:, -1].view(-1, 1).expand_as(logits)
    below_threshold = logits < kth_values

    if probs:
        return torch.where(below_threshold,
                           torch.ones_like(logits) * 0.0, logits)
    return torch.where(below_threshold,
                       torch.ones_like(logits) * -BIG_CONST,
                       logits)
115
+
116
def perturb_past(
    past,
    model,
    last,
    unpert_past=None,
    unpert_logits=None,
    accumulated_hidden=None,
    grad_norms=None,
    stepsize=0.01,
    one_hot_bows_vectors=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    num_iterations=3,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    kl_scale=0.01,
    device='cuda',
):
    """Perturb the model's cached key/value `past` toward an attribute.

    Core PPLM step: for `num_iterations` rounds, run the model with the
    current perturbation added to `past`, compute an attribute loss
    (bag-of-words and/or discriminator) plus a KL term against the
    unperturbed distribution, backprop into the perturbation, and take a
    normalized gradient step of size `stepsize`.

    Returns (pert_past, new_accumulated_hidden, grad_norms, losses_per_iter).
    # NOTE(review): `past` is assumed to be the tuple of per-layer
    # key/value tensors shaped (2, batch, heads, seq, head_dim) — the
    # 5-way unpack below relies on that; confirm against the model used.
    """
    # Perturbation accumulator, kept as numpy so each iteration re-wraps
    # it in a fresh leaf Variable for autograd.
    grad_accumulator = [
        (np.zeros(p.shape).astype("float32"))
        for p in past
    ]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if decay:
        # Linearly increasing weights over the window (oldest -> newest).
        decay_mask = torch.arange(
            0.,
            1.0 + SMALL_CONST,
            1.0 / (window_length)
        )[1:]
    else:
        decay_mask = 1.0

    # TODO fix this comment (SUMANTH)
    # Build a mask so only the most recent `window_length` positions of
    # the past receive gradient; older positions are frozen.
    _, batch_size, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([window_length])
            + tuple(past[0].shape[-1:])
        )

        zeros_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([curr_length - window_length])
            + tuple(past[0].shape[-1:])
        )

        ones_mask = torch.ones(ones_key_val_shape)
        # Apply the decay weights along the sequence axis (permute so the
        # sequence dim is last, scale, then permute back).
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat(
            (ones_mask, torch.zeros(zeros_key_val_shape)),
            dim=-2
        ).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)

    # accumulate perturbations for num_iterations
    loss_per_iter = []
    losses_per_iter = []
    new_accumulated_hidden = None
    for i in range(num_iterations):
        curr_perturbation = [
            to_var(torch.from_numpy(p_), requires_grad=True, device=device)
            for p_ in grad_accumulator
        ]

        # Compute hidden using perturbed past
        perturbed_past = list(map(add, past, curr_perturbation))
        _, _, _, curr_length, _ = curr_perturbation[0].shape
        all_logits, _, all_hidden = model(last, past=perturbed_past)
        hidden = all_hidden[-1]
        # Running sum of last-layer hidden states; fed to the classifier.
        new_accumulated_hidden = accumulated_hidden + torch.sum(
            hidden,
            dim=1
        ).detach()
        # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
        logits = all_logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        loss = 0.0
        losses = torch.zeros(batch_size, device=device)
        loss_list = []
        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
            # Bag-of-words loss: negative log probability mass assigned to
            # the bag's tokens.
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_losses = -torch.log(torch.sum(bow_logits, dim=-1))
                losses += bow_losses
                bow_loss = torch.sum(bow_losses)  # sum over batches
                loss += bow_loss
                loss_list.append(bow_loss)

        if loss_type == 2 or loss_type == 3:
            # Discriminator loss: roll the unperturbed model forward
            # `horizon_length` steps on expected embeddings, then classify
            # the mean hidden state.
            ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
            # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()
            for _ in range(horizon_length):
                # Expected embedding under the current token distribution.
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(
                    past=curr_unpert_past,
                    inputs_embeds=inputs_embeds
                )
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
                    curr_hidden, dim=1)

            prediction = classifier(new_accumulated_hidden /
                                    (curr_length + 1 + horizon_length))

            label = torch.tensor(batch_size * [class_label],
                                 device=device,
                                 dtype=torch.long)
            discrim_losses = ce_loss(prediction, label)
            losses += discrim_losses
            discrim_loss = discrim_losses.sum(-1)
            loss += discrim_loss
            loss_list.append(discrim_loss)

        kl_loss = 0.0
        if kl_scale > 0.0:
            # KL(perturbed || unperturbed) keeps generation fluent; tiny
            # probabilities are floored at SMALL_CONST to avoid log(0).
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            unpert_probs = (
                unpert_probs + SMALL_CONST *
                (unpert_probs <= SMALL_CONST).float().to(device).detach()
            )
            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
                device).detach()
            corrected_probs = probs + correction.detach()
            kl_losses = kl_scale * (
                (corrected_probs * (corrected_probs / unpert_probs).log()).sum(-1)
            )
            losses += kl_losses
            kl_loss = kl_losses.sum()
            loss += kl_loss

        loss_per_iter.append(loss.data.cpu().numpy())
        losses_per_iter.append(losses.data.cpu().numpy())

        # compute gradients
        loss.backward()

        # Track per-layer gradient norms; only the BoW path carries the
        # running max across calls (grad_norms passed back by the caller).
        if grad_norms is not None and loss_type == PPLM_BOW:
            grad_norms = [
                torch.max(grad_norms[index],
                          torch.norm_except_dim(p_.grad * window_mask, dim=1))
                # torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm_except_dim(p_.grad * window_mask, dim=1) + SMALL_CONST)
                for index, p_ in enumerate(curr_perturbation)
            ]

        # normalize gradients and descend (note the leading minus sign)
        grad = [
            -stepsize *
            (p_.grad * window_mask / grad_norms[
                index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]

        # accumulate gradient
        grad_accumulator = list(map(add, grad, grad_accumulator))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
            p_.grad.data.zero_()

        # removing past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        past = new_past

    # apply the accumulated perturbations to the past
    grad_accumulator = [
        to_var(torch.from_numpy(p_), requires_grad=True, device=device)
        for p_ in grad_accumulator
    ]
    pert_past = list(map(add, past, grad_accumulator))

    return pert_past, new_accumulated_hidden, grad_norms, losses_per_iter
312
+
313
+
314
def get_classifier(
    name: Optional[str], class_label: Union[str, int],
    device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """Load a pretrained discriminator head and resolve the target class id.

    Returns (None, None) when *name* is None. Weights come from the
    registry's "url" (downloaded and cached) or a local "path". A label
    not found in the discriminator's vocabulary falls back to its
    default class.
    """
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(
        class_size=params['class_size'],
        embed_size=params['embed_size']
    ).to(device)

    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")

    classifier.load_state_dict(
        torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    # Map the requested label (name or id) onto a valid class id.
    if isinstance(class_label, str):
        label_id = params["class_vocab"].get(
            class_label, params["default_class"])
    elif isinstance(class_label, int):
        valid_ids = set(params["class_vocab"].values())
        label_id = (class_label if class_label in valid_ids
                    else params["default_class"])
    else:
        label_id = params["default_class"]

    return classifier, label_id
354
+
355
+
356
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
        List[List[List[int]]]:
    """Encode every word of each bag-of-words file into token-id lists.

    Each entry may be a key of BAG_OF_WORDS_ARCHIVE_MAP (downloaded and
    cached) or a local file path; files contain one word per line.
    Returns one list of per-word token-id lists per bag.
    """
    all_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        filepath = (cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
                    if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP
                    else id_or_path)
        with open(filepath, "r") as f:
            words = f.read().strip().split("\n")
        encoded_words = [
            tokenizer.encode(word.strip(),
                             add_prefix_space=True,
                             add_special_tokens=False)
            for word in words
        ]
        all_indices.append(encoded_words)
    return all_indices
371
+
372
+
373
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
    """Turn per-bag token-id lists into one-hot matrices over the vocab.

    Words that encode to more than one token are dropped, so each kept
    row marks exactly one vocabulary entry. Returns None when no bag of
    words was supplied.
    """
    if bow_indices is None:
        return None

    one_hot_vectors = []
    for bag in bow_indices:
        single_token_ids = [ids for ids in bag if len(ids) <= 1]
        id_tensor = torch.tensor(single_token_ids).to(device)
        one_hot = torch.zeros(
            id_tensor.shape[0], tokenizer.vocab_size).to(device)
        one_hot.scatter_(1, id_tensor, 1)
        one_hot_vectors.append(one_hot)
    return one_hot_vectors
386
+
387
+
388
def full_text_generation(
    model,
    tokenizer,
    context=None,
    num_samples=1,
    device="cuda",
    max_time=5,
    sample=False,
    discrim=None,
    class_label=None,
    bag_of_words=None,
    length=100,
    grad_length=10000,
    stepsize=0.02,
    num_iterations=3,
    temperature=1.0,
    gm_scale=0.9,
    kl_scale=0.01,
    top_k=10,
    window_length=0,
    horizon_length=1,
    decay=False,
    gamma=1.5,
):
    """Resolve the steering attribute and run perturbed generation.

    Loads the discriminator (if any), encodes the bag(s) of words
    (";"-separated names or paths), picks the PPLM loss type from what
    was supplied, and delegates to generate_text_pplm.

    Returns (pert_gen_tok_text, last_losses) as produced by
    generate_text_pplm.

    Raises:
        Exception: if neither a bag of words nor a discriminator is given.

    # NOTE(review): `num_samples` and `temperature` are accepted but not
    # used here — batching happens via the replicated `context` in the
    # caller, and temperature is generate_text_pplm's default; confirm
    # this is intentional.
    """
    classifier, class_id = get_classifier(
        discrim,
        class_label,
        device
    )

    bow_indices = []
    if bag_of_words:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
                                               tokenizer)

    # Select the loss: BoW + discriminator, BoW only, or discriminator only.
    if bag_of_words and classifier:
        loss_type = PPLM_BOW_DISCRIM

    elif bag_of_words:
        loss_type = PPLM_BOW

    elif classifier is not None:
        loss_type = PPLM_DISCRIM

    else:
        raise Exception("Specify either a bag of words or a discriminator")

    # Unperturbed baseline generation, kept disabled for speed.
    # unpert_gen_tok_text = generate_text_pplm(
    #     model=model,
    #     tokenizer=tokenizer,
    #     context=context,
    #     device=device,
    #     length=length,
    #     perturb=False
    # )
    # if device == 'cuda':
    #     torch.cuda.empty_cache()

    # Debug trace of the effective generation settings.
    print(context, bow_indices, top_k, gm_scale, kl_scale)

    pert_gen_tok_text, last_losses = generate_text_pplm(
        model=model,
        context=context,
        tokenizer=tokenizer,
        device=device,
        max_time=max_time,
        sample=sample,
        perturb=True,
        bow_indices=bow_indices,
        classifier=classifier,
        class_label=class_id,
        loss_type=loss_type,
        length=length,
        grad_length=grad_length,
        stepsize=stepsize,
        num_iterations=num_iterations,
        temperature=temperature,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
        top_k=top_k,
        window_length=window_length,
        horizon_length=horizon_length,
        decay=decay,
        gamma=gamma,
    )

    if device == 'cuda':
        torch.cuda.empty_cache()

    return pert_gen_tok_text, last_losses
478
+
479
+
480
def generate_text_pplm(
    model,
    tokenizer,
    context=None,
    past=None,
    device="cuda",
    max_time=5,
    perturb=True,
    bow_indices=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
):
    """Autoregressively generate up to `length` tokens with PPLM steering.

    Each step: run the unperturbed model, perturb the cached past via
    perturb_past toward the attribute (BoW/discriminator), fuse perturbed
    and unperturbed next-token distributions with weight `gm_scale`,
    top-k filter, then sample (multinomial) or take the argmax.
    Generation stops early once `max_time` seconds have elapsed
    (max_time == -1 disables the limit).

    Returns (output_so_far, final_losses): the generated token tensor
    including the context, and the last iteration's per-sample losses
    (None if no perturbation ran).
    """
    output_so_far = None
    if context:
        context_t = torch.tensor(context, device=device, dtype=torch.long)
        # Ensure a leading batch dimension: (batch, seq_len).
        while len(context_t.shape) < 2:
            context_t = context_t.unsqueeze(0)
        output_so_far = context_t

    # collect one hot vectors for bags of words
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
                                                      device)

    start = time.time()

    grad_norms = None
    last = None
    losses_this_iter = None
    losses_in_time = []
    for i in trange(length, ascii=True):

        # Get past/probs for current output, except for last word
        # Note that GPT takes 2 inputs: past + current_token

        # run model forward to obtain unperturbed
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                _, past, _ = model(output_so_far[:, :-1])

        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
        unpert_last_hidden = unpert_all_hidden[-1]

        # check if we are above grad max length
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past

        else:
            # Sum of last-layer hidden states over the context (excluding
            # the newest position), used by the discriminator loss.
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, losses_this_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    classifier=classifier,
                    class_label=class_label,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    horizon_length=horizon_length,
                    window_length=window_length,
                    decay=decay,
                    gamma=gamma,
                    kl_scale=kl_scale,
                    device=device,
                )
                losses_in_time.append(losses_this_iter)
            else:
                pert_past = past

        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST
        pert_probs = F.softmax(pert_logits, dim=-1)

        # Fuse the modified model and original model
        if perturb:

            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

            # Geometric mean of perturbed and unperturbed distributions.
            pert_probs = ((pert_probs ** gm_scale) * (
                unpert_probs ** (1 - gm_scale)))  # + SMALL_CONST
            pert_probs = top_k_filter(pert_probs, k=top_k,
                                      probs=True)  # + SMALL_CONST

            # rescale
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)

        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
            pert_probs = F.softmax(pert_logits, dim=-1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)

        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = (
            last if output_so_far is None
            else torch.cat((output_so_far, last), dim=1)
        )

        if time.time() - start > max_time and max_time != -1:
            break

    final_losses = losses_this_iter[-1] if losses_this_iter else None
    return output_so_far, final_losses
617
+
618
+
619
def set_generic_model_params(discrim_weights, discrim_meta):
    """Register a user-supplied discriminator under the 'generic' key.

    Reads the JSON meta file, records the weights path in it, and stores
    the result in DISCRIMINATOR_MODELS_PARAMS.

    Raises:
        ValueError: if the weights path or the meta path is missing.
    """
    if discrim_weights is None:
        raise ValueError('When using a generic discriminator, '
                         'discrim_weights need to be specified')
    if discrim_meta is None:
        raise ValueError('When using a generic discriminator, '
                         'discrim_meta need to be specified')

    with open(discrim_meta, 'r') as meta_file:
        meta = json.load(meta_file)
    meta['path'] = discrim_weights
    DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
631
+
632
+
633
def run_model(
    model,
    tokenizer,
    device,
    raw_text,
    max_time,
    bag_of_words=None,
    discrim=None,
    discrim_weights=None,
    discrim_meta=None,
    discrim_label=-1,
    stepsize=0.02,
    length=10,
    seed=None,
    temperature=1.0,
    top_k=10,
    gm_scale=0.9,
    kl_scale=0.01,
    uncond=False,
    num_iterations=3,
    grad_length=10000,
    num_samples=1,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    use_sampling=False
):
    """Top-level PPLM entry point: steer generation from a raw prompt.

    Seeds the RNGs (when `seed` is given), optionally registers a
    generic discriminator, tokenizes the prompt replicated
    `num_samples` times as a batch, freezes the LM weights, and runs
    full_text_generation.

    Returns a list of dicts, one per sample:
    {"value": decoded text (prompt excluded), "tokens": token count,
     "loss": final attribute loss for that sample}.

    # NOTE(review): `uncond` is accepted but never read here — confirm
    # whether unconditional generation is meant to be supported.
    """
    # Debug trace of the seed in use.
    print(seed)
    if seed is not None:
        # set Random seed
        torch.manual_seed(seed)
        np.random.seed(seed)

    if discrim == 'generic':
        set_generic_model_params(discrim_weights, discrim_meta)

    # Replicate the encoded prompt so all samples generate as one batch;
    # truncation leaves room for `length` new tokens within the model's
    # 512-token budget.
    tokenized_cond_text = [tokenizer.encode(
        tokenizer.bos_token + raw_text, max_length=512 - length - 1)] * num_samples

    # Freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False

    # generate unperturbed and perturbed texts

    # full_text_generation returns:
    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time

    pert_gen_tok_text, last_losses = full_text_generation(
        model=model,
        tokenizer=tokenizer,
        context=tokenized_cond_text,
        device=device,
        max_time=max_time,
        num_samples=num_samples,
        discrim=discrim,
        class_label=discrim_label,
        bag_of_words=bag_of_words,
        length=length,
        grad_length=grad_length,
        stepsize=stepsize,
        num_iterations=num_iterations,
        temperature=temperature,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
        top_k=top_k,
        window_length=window_length,
        horizon_length=horizon_length,
        decay=decay,
        gamma=gamma,
        sample=use_sampling
    )

    generated_texts = []

    # iterate through the perturbed texts
    for sample, loss in zip(pert_gen_tok_text.tolist(), last_losses.tolist()):
        # Strip the prompt; only the newly generated suffix is returned.
        generated_part = sample[len(tokenized_cond_text[0]):]
        pert_gen_text = tokenizer.decode(generated_part)

        # keep the prefix, perturbed seq, original seq for each index
        generated_texts.append(
            {
                "value": pert_gen_text,
                "tokens": len(generated_part),
                "loss": loss
            }
        )

    return generated_texts
backend/README.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python backend
2
+
3
+ ## Setup
4
+
5
+ ```
6
+ pip install -r requirements.txt
7
+ chmod +x launch.sh
8
+ ```
9
+
10
+ ## Execution
11
+
12
+
13
+ `./launch.sh`
14
+
15
+ ## Usage
16
+
17
+ The API listens to the port `6006` and the route `autocomplete`. It listens to `POST` requests.
18
+ Query it like this: `{POST}http://<url>:6006/autocomplete`
19
+
20
+ The necessary argument is `context` which is a string of characters (ideally a sentence) which will be converted in tokens and fed to GPT-2.
21
+
22
+ The optional arguments are detailed below:
23
+
24
+ `length` is an unsigned int which sets the maximum length (in tokens) of the generated sentence __default: 100__
25
+
26
+ `n_samples` is an int `0 < n_samples <= 3` which sets the maximum number of samples generated. __default: 3__
27
+
28
+ `max_time` is an unsigned float which sets a heuristic for the maximum time spent generating sentences. It is a heuristic because it is not exact: generation can slightly overflow it. __default: infinite__
29
+
30
+ `model_size` takes `"small"` or `"medium"` as input and corresponds to the GPT model size __default: small__
31
+
32
+ `temperature` float - temperature of the model __default: 1__
33
+
34
+ `max_tokens` int - maximum amount of tokens that will be fed into the model. __default: 256__
35
+
36
+ `top_p` float - 0 < top_p < 1, nucleus sampling; only tokens with a cumulative probability of top_p will be selected for multinomial sampling __default: 0.9__
37
+
38
+ `top_k` int - Only top k tokens will be selected for multinomial sampling. __default: 256__
39
+
40
+ ## Return format
41
+
42
+ The server returns a set of sentences according to the context. Their format is:
43
+ ```
44
+ {sentences: {value: string, time: number}[], time: number}
45
+ ```
46
+
47
+ Example:
48
+
49
+ With POST parameters as:
50
+
51
+ ```json
52
+ {
53
+ "context": "That man is just another",
54
+ "samples": 3
55
+ }
56
+ ```
57
+
58
+ The response is as follows:
59
+
60
+ ```json
61
+ {
62
+ "sentences": [
63
+ {"value": " handicapped working man.", "time": 0.15415167808532715},
64
+ {"value": " guy, doing everything his manly nature requires.", "time": 0.2581148147583008},
65
+ {"value": " guy, Mohr said.", "time": 0.17547011375427246}
66
+ ],
67
+ "time": 0.264873743057251
68
+ }
69
+ ```
backend/Utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def forward(model_name, model, input_ids, past, device='cpu'):
    """Dispatch one forward pass according to the model family.

    GPT-2/CTRL reuse the cached `past` (feeding only the newest token);
    XLNet needs an appended dummy target plus permutation and target
    masks; Transfo-XL passes its memory as `mems`; anything else gets a
    plain call on the full input.
    """
    if "gpt2" in model_name or "ctrl" in model_name:
        if past is None:
            return model(input_ids)
        return model(input_ids[:, -1], past=past)

    if "xlnet" in model_name:
        # Append a dummy token whose content the masks will hide.
        dummy = torch.zeros(
            (input_ids.shape[0], 1), dtype=torch.long, device=device)
        input_ids = torch.cat((input_ids, dummy), dim=1)
        batch, seq_len = input_ids.shape[0], input_ids.shape[1]

        # The appended position must not attend to itself.
        perm_mask = torch.zeros(
            (batch, seq_len, seq_len), dtype=torch.float, device=device)
        perm_mask[:, :, -1] = 1.0

        # Predict only the appended position.
        target_mapping = torch.zeros(
            (batch, 1, seq_len), dtype=torch.float, device=device)
        target_mapping[:, 0, -1] = 1.0

        return model(input_ids, perm_mask=perm_mask,
                     target_mapping=target_mapping)

    if "transfo-xl" in model_name:
        return model(input_ids, mems=past)

    return model(input_ids)
33
+
34
+
35
def create_context(model_name, tokenizer, initial_text="", padding_text=None, max_tokens=512):
    """Prepare the prompt tokens for a given model family.

    Returns (context_tokens, eot_token, dot_token): the encoded prompt
    truncated to the last `max_tokens` tokens, the model's end-of-text
    token id (None when unknown), and the token id of ".".
    """
    is_gpt2 = "gpt2" in model_name
    # GPT-2 cannot start from an empty prompt: seed it with its EOT marker.
    if is_gpt2 and not len(initial_text):
        initial_text = "<|endoftext|>"
    # XLNet / Transfo-XL get the caller-supplied padding text prepended.
    if "xlnet" in model_name or "transfo-xl" in model_name:
        initial_text = padding_text + initial_text

    if "transfo-xl" in model_name:
        # Halve the token budget for Transfo-XL.
        max_tokens = int(max_tokens / 2)

    context_tokens = tokenizer.encode(initial_text)[-max_tokens:]

    if is_gpt2:
        eot_token = tokenizer.encoder["<|endoftext|>"]
        if not context_tokens:
            context_tokens = [tokenizer.encoder["<|endoftext|>"]]
    elif "xlnet" in model_name:
        eot_token = tokenizer.convert_tokens_to_ids('<eop>')
    else:
        eot_token = None
    dot_token = tokenizer.encode(".")[-1]

    return context_tokens, eot_token, dot_token
57
+
backend/id/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ /node_modules
2
+ /dist
backend/id/id.ts ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import * as Koa from "koa";
2
+ import * as Router from "koa-router";
3
+
4
+ let id = 0;
5
+
6
+ const app = new Koa();
7
+ const router = new Router();
8
+
9
+ router.get("/*", async ctx => {
10
+ ctx.body = id++;
11
+ });
12
+
13
+ app.use(router.routes());
14
+
15
+ app.listen(3000);
16
+
17
+ console.log("Server running on port 3000");
backend/id/package.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dependencies": {
3
+ "@types/koa": "^2.0.48",
4
+ "@types/koa-router": "^7.0.40",
5
+ "koa": "^2.7.0",
6
+ "koa-router": "^7.4.0",
7
+ "typescript": "^3.5.1"
8
+ },
9
+ "scripts": {
10
+ "start": "tsc && node dist/id.js"
11
+ }
12
+ }
backend/id/tsconfig.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compilerOptions": {
3
+ "target": "esnext",
4
+ "module": "commonjs",
5
+ "outDir": "dist/",
6
+ "strictNullChecks": true,
7
+ "strict": true,
8
+ "lib": ["esnext", "dom", "es6", "es2016", "es2017", "es2018"]
9
+ }
10
+ }
backend/id/yarn.lock ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
2
+ # yarn lockfile v1
3
+
4
+
5
+ "@types/accepts@*":
6
+ version "1.3.5"
7
+ resolved "https://registry.yarnpkg.com/@types/accepts/-/accepts-1.3.5.tgz#c34bec115cfc746e04fe5a059df4ce7e7b391575"
8
+ integrity sha512-jOdnI/3qTpHABjM5cx1Hc0sKsPoYCp+DP/GJRGtDlPd7fiV9oXGGIcjW/ZOxLIvjGz8MA+uMZI9metHlgqbgwQ==
9
+ dependencies:
10
+ "@types/node" "*"
11
+
12
+ "@types/body-parser@*":
13
+ version "1.17.0"
14
+ resolved "https://registry.yarnpkg.com/@types/body-parser/-/body-parser-1.17.0.tgz#9f5c9d9bd04bb54be32d5eb9fc0d8c974e6cf58c"
15
+ integrity sha512-a2+YeUjPkztKJu5aIF2yArYFQQp8d51wZ7DavSHjFuY1mqVgidGyzEQ41JIVNy82fXj8yPgy2vJmfIywgESW6w==
16
+ dependencies:
17
+ "@types/connect" "*"
18
+ "@types/node" "*"
19
+
20
+ "@types/connect@*":
21
+ version "3.4.32"
22
+ resolved "https://registry.yarnpkg.com/@types/connect/-/connect-3.4.32.tgz#aa0e9616b9435ccad02bc52b5b454ffc2c70ba28"
23
+ integrity sha512-4r8qa0quOvh7lGD0pre62CAb1oni1OO6ecJLGCezTmhQ8Fz50Arx9RUszryR8KlgK6avuSXvviL6yWyViQABOg==
24
+ dependencies:
25
+ "@types/node" "*"
26
+
27
+ "@types/cookies@*":
28
+ version "0.7.2"
29
+ resolved "https://registry.yarnpkg.com/@types/cookies/-/cookies-0.7.2.tgz#5e0560d46ed9998082dce799af1058dd6a49780a"
30
+ integrity sha512-jnihWgshWystcJKrz8C9hV+Ot9lqOUyAh2RF+o3BEo6K6AS2l4zYCb9GYaBuZ3C6Il59uIGqpE3HvCun4KKeJA==
31
+ dependencies:
32
+ "@types/connect" "*"
33
+ "@types/express" "*"
34
+ "@types/keygrip" "*"
35
+ "@types/node" "*"
36
+
37
+ "@types/express-serve-static-core@*":
38
+ version "4.16.7"
39
+ resolved "https://registry.yarnpkg.com/@types/express-serve-static-core/-/express-serve-static-core-4.16.7.tgz#50ba6f8a691c08a3dd9fa7fba25ef3133d298049"
40
+ integrity sha512-847KvL8Q1y3TtFLRTXcVakErLJQgdpFSaq+k043xefz9raEf0C7HalpSY7OW5PyjCnY8P7bPW5t/Co9qqp+USg==
41
+ dependencies:
42
+ "@types/node" "*"
43
+ "@types/range-parser" "*"
44
+
45
+ "@types/express@*":
46
+ version "4.17.0"
47
+ resolved "https://registry.yarnpkg.com/@types/express/-/express-4.17.0.tgz#49eaedb209582a86f12ed9b725160f12d04ef287"
48
+ integrity sha512-CjaMu57cjgjuZbh9DpkloeGxV45CnMGlVd+XpG7Gm9QgVrd7KFq+X4HY0vM+2v0bczS48Wg7bvnMY5TN+Xmcfw==
49
+ dependencies:
50
+ "@types/body-parser" "*"
51
+ "@types/express-serve-static-core" "*"
52
+ "@types/serve-static" "*"
53
+
54
+ "@types/http-assert@*":
55
+ version "1.4.0"
56
+ resolved "https://registry.yarnpkg.com/@types/http-assert/-/http-assert-1.4.0.tgz#41d173466e396e99a14d75f7160cc997f2f9ed8b"
57
+ integrity sha512-TZDqvFW4nQwL9DVSNJIJu4lPLttKgzRF58COa7Vs42Ki/MrhIqUbeIw0MWn4kGLiZLXB7oCBibm7nkSjPkzfKQ==
58
+
59
+ "@types/keygrip@*":
60
+ version "1.0.1"
61
+ resolved "https://registry.yarnpkg.com/@types/keygrip/-/keygrip-1.0.1.tgz#ff540462d2fb4d0a88441ceaf27d287b01c3d878"
62
+ integrity sha1-/1QEYtL7TQqIRBzq8n0oewHD2Hg=
63
+
64
+ "@types/koa-compose@*":
65
+ version "3.2.4"
66
+ resolved "https://registry.yarnpkg.com/@types/koa-compose/-/koa-compose-3.2.4.tgz#76a461634a59c3e13449831708bb9b355fb1548e"
67
+ integrity sha512-ioou0rxkuWL+yBQYsHUQAzRTfVxAg8Y2VfMftU+Y3RA03/MzuFL0x/M2sXXj3PkfnENbHsjeHR1aMdezLYpTeA==
68
+ dependencies:
69
+ "@types/koa" "*"
70
+
71
+ "@types/koa-router@^7.0.40":
72
+ version "7.0.40"
73
+ resolved "https://registry.yarnpkg.com/@types/koa-router/-/koa-router-7.0.40.tgz#9654dbc43375a0380c44c49c4504b4dbfc3e4e6a"
74
+ integrity sha512-YK4+WGXch6Ig9PreZ9jlHZb2onm0S1szGw0oQxWvPhoyjSHo1Tq+CpjxMmthEUIQUc9KznOGgehFarOx8XwsFw==
75
+ dependencies:
76
+ "@types/koa" "*"
77
+
78
+ "@types/koa@*", "@types/koa@^2.0.48":
79
+ version "2.0.48"
80
+ resolved "https://registry.yarnpkg.com/@types/koa/-/koa-2.0.48.tgz#29162783029d3e5df8b58c55f6bf0d35f78fc39f"
81
+ integrity sha512-CiIUYhHlOFJhSCTmsFoFkV2t9ij1JwW26nt0W9XZoWTvmAw6zTE0+k3IAoGICtjzIfhZpZcO323NHmI1LGmdDw==
82
+ dependencies:
83
+ "@types/accepts" "*"
84
+ "@types/cookies" "*"
85
+ "@types/http-assert" "*"
86
+ "@types/keygrip" "*"
87
+ "@types/koa-compose" "*"
88
+ "@types/node" "*"
89
+
90
+ "@types/mime@*":
91
+ version "2.0.1"
92
+ resolved "https://registry.yarnpkg.com/@types/mime/-/mime-2.0.1.tgz#dc488842312a7f075149312905b5e3c0b054c79d"
93
+ integrity sha512-FwI9gX75FgVBJ7ywgnq/P7tw+/o1GUbtP0KzbtusLigAOgIgNISRK0ZPl4qertvXSIE8YbsVJueQ90cDt9YYyw==
94
+
95
+ "@types/node@*":
96
+ version "12.0.7"
97
+ resolved "https://registry.yarnpkg.com/@types/node/-/node-12.0.7.tgz#4f2563bad652b2acb1722d7e7aae2b0ff62d192c"
98
+ integrity sha512-1YKeT4JitGgE4SOzyB9eMwO0nGVNkNEsm9qlIt1Lqm/tG2QEiSMTD4kS3aO6L+w5SClLVxALmIBESK6Mk5wX0A==
99
+
100
+ "@types/range-parser@*":
101
+ version "1.2.3"
102
+ resolved "https://registry.yarnpkg.com/@types/range-parser/-/range-parser-1.2.3.tgz#7ee330ba7caafb98090bece86a5ee44115904c2c"
103
+ integrity sha512-ewFXqrQHlFsgc09MK5jP5iR7vumV/BYayNC6PgJO2LPe8vrnNFyjQjSppfEngITi0qvfKtzFvgKymGheFM9UOA==
104
+
105
+ "@types/serve-static@*":
106
+ version "1.13.2"
107
+ resolved "https://registry.yarnpkg.com/@types/serve-static/-/serve-static-1.13.2.tgz#f5ac4d7a6420a99a6a45af4719f4dcd8cd907a48"
108
+ integrity sha512-/BZ4QRLpH/bNYgZgwhKEh+5AsboDBcUdlBYgzoLX0fpj3Y2gp6EApyOlM3bK53wQS/OE1SrdSYBAbux2D1528Q==
109
+ dependencies:
110
+ "@types/express-serve-static-core" "*"
111
+ "@types/mime" "*"
112
+
113
+ accepts@^1.3.5:
114
+ version "1.3.7"
115
+ resolved "https://registry.yarnpkg.com/accepts/-/accepts-1.3.7.tgz#531bc726517a3b2b41f850021c6cc15eaab507cd"
116
+ integrity sha512-Il80Qs2WjYlJIBNzNkK6KYqlVMTbZLXgHx2oT0pU/fjRHyEp+PEfEPY0R3WCwAGVOtauxh1hOxNgIf5bv7dQpA==
117
+ dependencies:
118
+ mime-types "~2.1.24"
119
+ negotiator "0.6.2"
120
+
121
+ any-promise@^1.1.0:
122
+ version "1.3.0"
123
+ resolved "https://registry.yarnpkg.com/any-promise/-/any-promise-1.3.0.tgz#abc6afeedcea52e809cdc0376aed3ce39635d17f"
124
+ integrity sha1-q8av7tzqUugJzcA3au0845Y10X8=
125
+
126
+ cache-content-type@^1.0.0:
127
+ version "1.0.1"
128
+ resolved "https://registry.yarnpkg.com/cache-content-type/-/cache-content-type-1.0.1.tgz#035cde2b08ee2129f4a8315ea8f00a00dba1453c"
129
+ integrity sha512-IKufZ1o4Ut42YUrZSo8+qnMTrFuKkvyoLXUywKz9GJ5BrhOFGhLdkx9sG4KAnVvbY6kEcSFjLQul+DVmBm2bgA==
130
+ dependencies:
131
+ mime-types "^2.1.18"
132
+ ylru "^1.2.0"
133
+
134
+ co@^4.6.0:
135
+ version "4.6.0"
136
+ resolved "https://registry.yarnpkg.com/co/-/co-4.6.0.tgz#6ea6bdf3d853ae54ccb8e47bfa0bf3f9031fb184"
137
+ integrity sha1-bqa989hTrlTMuOR7+gvz+QMfsYQ=
138
+
139
+ content-disposition@~0.5.2:
140
+ version "0.5.3"
141
+ resolved "https://registry.yarnpkg.com/content-disposition/-/content-disposition-0.5.3.tgz#e130caf7e7279087c5616c2007d0485698984fbd"
142
+ integrity sha512-ExO0774ikEObIAEV9kDo50o+79VCUdEB6n6lzKgGwupcVeRlhrj3qGAfwq8G6uBJjkqLrhT0qEYFcWng8z1z0g==
143
+ dependencies:
144
+ safe-buffer "5.1.2"
145
+
146
+ content-type@^1.0.4:
147
+ version "1.0.4"
148
+ resolved "https://registry.yarnpkg.com/content-type/-/content-type-1.0.4.tgz#e138cc75e040c727b1966fe5e5f8c9aee256fe3b"
149
+ integrity sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA==
150
+
151
+ cookies@~0.7.1:
152
+ version "0.7.3"
153
+ resolved "https://registry.yarnpkg.com/cookies/-/cookies-0.7.3.tgz#7912ce21fbf2e8c2da70cf1c3f351aecf59dadfa"
154
+ integrity sha512-+gixgxYSgQLTaTIilDHAdlNPZDENDQernEMiIcZpYYP14zgHsCt4Ce1FEjFtcp6GefhozebB6orvhAAWx/IS0A==
155
+ dependencies:
156
+ depd "~1.1.2"
157
+ keygrip "~1.0.3"
158
+
159
+ debug@^3.1.0:
160
+ version "3.2.6"
161
+ resolved "https://registry.yarnpkg.com/debug/-/debug-3.2.6.tgz#e83d17de16d8a7efb7717edbe5fb10135eee629b"
162
+ integrity sha512-mel+jf7nrtEl5Pn1Qx46zARXKDpBbvzezse7p7LqINmdoIk8PYP5SySaxEmYv6TZ0JyEKA1hsCId6DIhgITtWQ==
163
+ dependencies:
164
+ ms "^2.1.1"
165
+
166
+ debug@~3.1.0:
167
+ version "3.1.0"
168
+ resolved "https://registry.yarnpkg.com/debug/-/debug-3.1.0.tgz#5bb5a0672628b64149566ba16819e61518c67261"
169
+ integrity sha512-OX8XqP7/1a9cqkxYw2yXss15f26NKWBpDXQd0/uK/KPqdQhxbPa994hnzjcE2VqQpDslf55723cKPUOGSmMY3g==
170
+ dependencies:
171
+ ms "2.0.0"
172
+
173
+ deep-equal@~1.0.1:
174
+ version "1.0.1"
175
+ resolved "https://registry.yarnpkg.com/deep-equal/-/deep-equal-1.0.1.tgz#f5d260292b660e084eff4cdbc9f08ad3247448b5"
176
+ integrity sha1-9dJgKStmDghO/0zbyfCK0yR0SLU=
177
+
178
+ delegates@^1.0.0:
179
+ version "1.0.0"
180
+ resolved "https://registry.yarnpkg.com/delegates/-/delegates-1.0.0.tgz#84c6e159b81904fdca59a0ef44cd870d31250f9a"
181
+ integrity sha1-hMbhWbgZBP3KWaDvRM2HDTElD5o=
182
+
183
+ depd@^1.1.2, depd@~1.1.2:
184
+ version "1.1.2"
185
+ resolved "https://registry.yarnpkg.com/depd/-/depd-1.1.2.tgz#9bcd52e14c097763e749b274c4346ed2e560b5a9"
186
+ integrity sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak=
187
+
188
+ destroy@^1.0.4:
189
+ version "1.0.4"
190
+ resolved "https://registry.yarnpkg.com/destroy/-/destroy-1.0.4.tgz#978857442c44749e4206613e37946205826abd80"
191
+ integrity sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA=
192
+
193
+ ee-first@1.1.1:
194
+ version "1.1.1"
195
+ resolved "https://registry.yarnpkg.com/ee-first/-/ee-first-1.1.1.tgz#590c61156b0ae2f4f0255732a158b266bc56b21d"
196
+ integrity sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0=
197
+
198
+ error-inject@^1.0.0:
199
+ version "1.0.0"
200
+ resolved "https://registry.yarnpkg.com/error-inject/-/error-inject-1.0.0.tgz#e2b3d91b54aed672f309d950d154850fa11d4f37"
201
+ integrity sha1-4rPZG1Su1nLzCdlQ0VSFD6EdTzc=
202
+
203
+ escape-html@^1.0.3:
204
+ version "1.0.3"
205
+ resolved "https://registry.yarnpkg.com/escape-html/-/escape-html-1.0.3.tgz#0258eae4d3d0c0974de1c169188ef0051d1d1988"
206
+ integrity sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg=
207
+
208
+ fresh@~0.5.2:
209
+ version "0.5.2"
210
+ resolved "https://registry.yarnpkg.com/fresh/-/fresh-0.5.2.tgz#3d8cadd90d976569fa835ab1f8e4b23a105605a7"
211
+ integrity sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac=
212
+
213
+ http-assert@^1.3.0:
214
+ version "1.4.1"
215
+ resolved "https://registry.yarnpkg.com/http-assert/-/http-assert-1.4.1.tgz#c5f725d677aa7e873ef736199b89686cceb37878"
216
+ integrity sha512-rdw7q6GTlibqVVbXr0CKelfV5iY8G2HqEUkhSk297BMbSpSL8crXC+9rjKoMcZZEsksX30le6f/4ul4E28gegw==
217
+ dependencies:
218
+ deep-equal "~1.0.1"
219
+ http-errors "~1.7.2"
220
+
221
+ http-errors@^1.3.1, http-errors@^1.6.3, http-errors@~1.7.2:
222
+ version "1.7.2"
223
+ resolved "https://registry.yarnpkg.com/http-errors/-/http-errors-1.7.2.tgz#4f5029cf13239f31036e5b2e55292bcfbcc85c8f"
224
+ integrity sha512-uUQBt3H/cSIVfch6i1EuPNy/YsRSOUBXTVfZ+yR7Zjez3qjBz6i9+i4zjNaoqcoFVI4lQJ5plg63TvGfRSDCRg==
225
+ dependencies:
226
+ depd "~1.1.2"
227
+ inherits "2.0.3"
228
+ setprototypeof "1.1.1"
229
+ statuses ">= 1.5.0 < 2"
230
+ toidentifier "1.0.0"
231
+
232
+ inherits@2.0.3:
233
+ version "2.0.3"
234
+ resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.3.tgz#633c2c83e3da42a502f52466022480f4208261de"
235
+ integrity sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=
236
+
237
+ is-generator-function@^1.0.7:
238
+ version "1.0.7"
239
+ resolved "https://registry.yarnpkg.com/is-generator-function/-/is-generator-function-1.0.7.tgz#d2132e529bb0000a7f80794d4bdf5cd5e5813522"
240
+ integrity sha512-YZc5EwyO4f2kWCax7oegfuSr9mFz1ZvieNYBEjmukLxgXfBUbxAWGVF7GZf0zidYtoBl3WvC07YK0wT76a+Rtw==
241
+
242
+ isarray@0.0.1:
243
+ version "0.0.1"
244
+ resolved "https://registry.yarnpkg.com/isarray/-/isarray-0.0.1.tgz#8a18acfca9a8f4177e09abfc6038939b05d1eedf"
245
+ integrity sha1-ihis/Kmo9Bd+Cav8YDiTmwXR7t8=
246
+
247
+ keygrip@~1.0.3:
248
+ version "1.0.3"
249
+ resolved "https://registry.yarnpkg.com/keygrip/-/keygrip-1.0.3.tgz#399d709f0aed2bab0a059e0cdd3a5023a053e1dc"
250
+ integrity sha512-/PpesirAIfaklxUzp4Yb7xBper9MwP6hNRA6BGGUFCgbJ+BM5CKBtsoxinNXkLHAr+GXS1/lSlF2rP7cv5Fl+g==
251
+
252
+ koa-compose@^3.0.0:
253
+ version "3.2.1"
254
+ resolved "https://registry.yarnpkg.com/koa-compose/-/koa-compose-3.2.1.tgz#a85ccb40b7d986d8e5a345b3a1ace8eabcf54de7"
255
+ integrity sha1-qFzLQLfZhtjlo0Wzoazo6rz1Tec=
256
+ dependencies:
257
+ any-promise "^1.1.0"
258
+
259
+ koa-compose@^4.1.0:
260
+ version "4.1.0"
261
+ resolved "https://registry.yarnpkg.com/koa-compose/-/koa-compose-4.1.0.tgz#507306b9371901db41121c812e923d0d67d3e877"
262
+ integrity sha512-8ODW8TrDuMYvXRwra/Kh7/rJo9BtOfPc6qO8eAfC80CnCvSjSl0bkRM24X6/XBBEyj0v1nRUQ1LyOy3dbqOWXw==
263
+
264
+ koa-convert@^1.2.0:
265
+ version "1.2.0"
266
+ resolved "https://registry.yarnpkg.com/koa-convert/-/koa-convert-1.2.0.tgz#da40875df49de0539098d1700b50820cebcd21d0"
267
+ integrity sha1-2kCHXfSd4FOQmNFwC1CCDOvNIdA=
268
+ dependencies:
269
+ co "^4.6.0"
270
+ koa-compose "^3.0.0"
271
+
272
+ koa-is-json@^1.0.0:
273
+ version "1.0.0"
274
+ resolved "https://registry.yarnpkg.com/koa-is-json/-/koa-is-json-1.0.0.tgz#273c07edcdcb8df6a2c1ab7d59ee76491451ec14"
275
+ integrity sha1-JzwH7c3Ljfaiwat9We52SRRR7BQ=
276
+
277
+ koa-router@^7.4.0:
278
+ version "7.4.0"
279
+ resolved "https://registry.yarnpkg.com/koa-router/-/koa-router-7.4.0.tgz#aee1f7adc02d5cb31d7d67465c9eacc825e8c5e0"
280
+ integrity sha512-IWhaDXeAnfDBEpWS6hkGdZ1ablgr6Q6pGdXCyK38RbzuH4LkUOpPqPw+3f8l8aTDrQmBQ7xJc0bs2yV4dzcO+g==
281
+ dependencies:
282
+ debug "^3.1.0"
283
+ http-errors "^1.3.1"
284
+ koa-compose "^3.0.0"
285
+ methods "^1.0.1"
286
+ path-to-regexp "^1.1.1"
287
+ urijs "^1.19.0"
288
+
289
+ koa@^2.7.0:
290
+ version "2.7.0"
291
+ resolved "https://registry.yarnpkg.com/koa/-/koa-2.7.0.tgz#7e00843506942b9d82c6cc33749f657c6e5e7adf"
292
+ integrity sha512-7ojD05s2Q+hFudF8tDLZ1CpCdVZw8JQELWSkcfG9bdtoTDzMmkRF6BQBU7JzIzCCOY3xd3tftiy/loHBUYaY2Q==
293
+ dependencies:
294
+ accepts "^1.3.5"
295
+ cache-content-type "^1.0.0"
296
+ content-disposition "~0.5.2"
297
+ content-type "^1.0.4"
298
+ cookies "~0.7.1"
299
+ debug "~3.1.0"
300
+ delegates "^1.0.0"
301
+ depd "^1.1.2"
302
+ destroy "^1.0.4"
303
+ error-inject "^1.0.0"
304
+ escape-html "^1.0.3"
305
+ fresh "~0.5.2"
306
+ http-assert "^1.3.0"
307
+ http-errors "^1.6.3"
308
+ is-generator-function "^1.0.7"
309
+ koa-compose "^4.1.0"
310
+ koa-convert "^1.2.0"
311
+ koa-is-json "^1.0.0"
312
+ on-finished "^2.3.0"
313
+ only "~0.0.2"
314
+ parseurl "^1.3.2"
315
+ statuses "^1.5.0"
316
+ type-is "^1.6.16"
317
+ vary "^1.1.2"
318
+
319
+ media-typer@0.3.0:
320
+ version "0.3.0"
321
+ resolved "https://registry.yarnpkg.com/media-typer/-/media-typer-0.3.0.tgz#8710d7af0aa626f8fffa1ce00168545263255748"
322
+ integrity sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g=
323
+
324
+ methods@^1.0.1:
325
+ version "1.1.2"
326
+ resolved "https://registry.yarnpkg.com/methods/-/methods-1.1.2.tgz#5529a4d67654134edcc5266656835b0f851afcee"
327
+ integrity sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4=
328
+
329
+ mime-db@1.40.0:
330
+ version "1.40.0"
331
+ resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.40.0.tgz#a65057e998db090f732a68f6c276d387d4126c32"
332
+ integrity sha512-jYdeOMPy9vnxEqFRRo6ZvTZ8d9oPb+k18PKoYNYUe2stVEBPPwsln/qWzdbmaIvnhZ9v2P+CuecK+fpUfsV2mA==
333
+
334
+ mime-types@^2.1.18, mime-types@~2.1.24:
335
+ version "2.1.24"
336
+ resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.24.tgz#b6f8d0b3e951efb77dedeca194cff6d16f676f81"
337
+ integrity sha512-WaFHS3MCl5fapm3oLxU4eYDw77IQM2ACcxQ9RIxfaC3ooc6PFuBMGZZsYpvoXS5D5QTWPieo1jjLdAm3TBP3cQ==
338
+ dependencies:
339
+ mime-db "1.40.0"
340
+
341
+ ms@2.0.0:
342
+ version "2.0.0"
343
+ resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8"
344
+ integrity sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=
345
+
346
+ ms@^2.1.1:
347
+ version "2.1.2"
348
+ resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009"
349
+ integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==
350
+
351
+ negotiator@0.6.2:
352
+ version "0.6.2"
353
+ resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.2.tgz#feacf7ccf525a77ae9634436a64883ffeca346fb"
354
+ integrity sha512-hZXc7K2e+PgeI1eDBe/10Ard4ekbfrrqG8Ep+8Jmf4JID2bNg7NvCPOZN+kfF574pFQI7mum2AUqDidoKqcTOw==
355
+
356
+ on-finished@^2.3.0:
357
+ version "2.3.0"
358
+ resolved "https://registry.yarnpkg.com/on-finished/-/on-finished-2.3.0.tgz#20f1336481b083cd75337992a16971aa2d906947"
359
+ integrity sha1-IPEzZIGwg811M3mSoWlxqi2QaUc=
360
+ dependencies:
361
+ ee-first "1.1.1"
362
+
363
+ only@~0.0.2:
364
+ version "0.0.2"
365
+ resolved "https://registry.yarnpkg.com/only/-/only-0.0.2.tgz#2afde84d03e50b9a8edc444e30610a70295edfb4"
366
+ integrity sha1-Kv3oTQPlC5qO3EROMGEKcCle37Q=
367
+
368
+ parseurl@^1.3.2:
369
+ version "1.3.3"
370
+ resolved "https://registry.yarnpkg.com/parseurl/-/parseurl-1.3.3.tgz#9da19e7bee8d12dff0513ed5b76957793bc2e8d4"
371
+ integrity sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==
372
+
373
+ path-to-regexp@^1.1.1:
374
+ version "1.7.0"
375
+ resolved "https://registry.yarnpkg.com/path-to-regexp/-/path-to-regexp-1.7.0.tgz#59fde0f435badacba103a84e9d3bc64e96b9937d"
376
+ integrity sha1-Wf3g9DW62suhA6hOnTvGTpa5k30=
377
+ dependencies:
378
+ isarray "0.0.1"
379
+
380
+ safe-buffer@5.1.2:
381
+ version "5.1.2"
382
+ resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.1.2.tgz#991ec69d296e0313747d59bdfd2b745c35f8828d"
383
+ integrity sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==
384
+
385
+ setprototypeof@1.1.1:
386
+ version "1.1.1"
387
+ resolved "https://registry.yarnpkg.com/setprototypeof/-/setprototypeof-1.1.1.tgz#7e95acb24aa92f5885e0abef5ba131330d4ae683"
388
+ integrity sha512-JvdAWfbXeIGaZ9cILp38HntZSFSo3mWg6xGcJJsd+d4aRMOqauag1C63dJfDw7OaMYwEbHMOxEZ1lqVRYP2OAw==
389
+
390
+ "statuses@>= 1.5.0 < 2", statuses@^1.5.0:
391
+ version "1.5.0"
392
+ resolved "https://registry.yarnpkg.com/statuses/-/statuses-1.5.0.tgz#161c7dac177659fd9811f43771fa99381478628c"
393
+ integrity sha1-Fhx9rBd2Wf2YEfQ3cfqZOBR4Yow=
394
+
395
+ toidentifier@1.0.0:
396
+ version "1.0.0"
397
+ resolved "https://registry.yarnpkg.com/toidentifier/-/toidentifier-1.0.0.tgz#7e1be3470f1e77948bc43d94a3c8f4d7752ba553"
398
+ integrity sha512-yaOH/Pk/VEhBWWTlhI+qXxDFXlejDGcQipMlyxda9nthulaxLZUNcUqFxokp0vcYnvteJln5FNQDRrxj3YcbVw==
399
+
400
+ type-is@^1.6.16:
401
+ version "1.6.18"
402
+ resolved "https://registry.yarnpkg.com/type-is/-/type-is-1.6.18.tgz#4e552cd05df09467dcbc4ef739de89f2cf37c131"
403
+ integrity sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==
404
+ dependencies:
405
+ media-typer "0.3.0"
406
+ mime-types "~2.1.24"
407
+
408
+ typescript@^3.5.1:
409
+ version "3.5.1"
410
+ resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.5.1.tgz#ba72a6a600b2158139c5dd8850f700e231464202"
411
+ integrity sha512-64HkdiRv1yYZsSe4xC1WVgamNigVYjlssIoaH2HcZF0+ijsk5YK2g0G34w9wJkze8+5ow4STd22AynfO6ZYYLw==
412
+
413
+ urijs@^1.19.0:
414
+ version "1.19.1"
415
+ resolved "https://registry.yarnpkg.com/urijs/-/urijs-1.19.1.tgz#5b0ff530c0cbde8386f6342235ba5ca6e995d25a"
416
+ integrity sha512-xVrGVi94ueCJNrBSTjWqjvtgvl3cyOTThp2zaMaFNGp3F542TR6sM3f2o8RqZl+AwteClSVmoCyt0ka4RjQOQg==
417
+
418
+ vary@^1.1.2:
419
+ version "1.1.2"
420
+ resolved "https://registry.yarnpkg.com/vary/-/vary-1.1.2.tgz#2299f02c6ded30d4a5961b0b9f74524a18f634fc"
421
+ integrity sha1-IpnwLG3tMNSllhsLn3RSShj2NPw=
422
+
423
+ ylru@^1.2.0:
424
+ version "1.2.1"
425
+ resolved "https://registry.yarnpkg.com/ylru/-/ylru-1.2.1.tgz#f576b63341547989c1de7ba288760923b27fe84f"
426
+ integrity sha512-faQrqNMzcPCHGVC2aaOINk13K+aaBDUPjGWl0teOXywElLjyVAB6Oe2jj62jHYtwsU49jXhScYbvPENK+6zAvQ==
backend/install.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ sudo apt install jq -y
4
+ pip install -r requirements.txt
5
+ cd id
6
+ npm install
backend/launch.sh ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ pgrep -f gunicorn | xargs kill -9
3
+ kill $(lsof -t -i:3000)
4
+
5
+ cd id
6
+ npm run start &
7
+ sleep 5
8
+ cd -
9
+
10
+ if [[ "$1" == "" ]]; then
11
+ echo "JSON file argument not supplied. Exiting." 1>&2
12
+ exit 1
13
+ fi
14
+
15
+ # Number of GPUs
16
+ N_GPU=$(nvidia-smi -L | wc -l)
17
+
18
+ export FILE=$1
19
+ export GPU_PER_WORKER=`cat "$FILE" | jq -r .gpu_per_worker`
20
+
21
+ # Are there enough GPUs ?
22
+ if [[ $(($N_GPU / $GPU_PER_WORKER)) -eq 0 ]]; then
23
+ echo "Not enough GPUs to run this." 1>&2
24
+ exit 1
25
+ fi
26
+
27
+ N_WORKERS=$(($N_GPU / $GPU_PER_WORKER))
28
+
29
+ echo "File $FILE"
30
+ echo "Available GPUs $N_GPU"
31
+ echo "GPUs per worker $GPU_PER_WORKER"
32
+ echo "Total workers $N_WORKERS"
33
+
34
+ function sys_exit ()
35
+ {
36
+ echo "Ctrl-C caught...performing clean up"
37
+ echo "Cleaning up the servers."
38
+ echo $INST1
39
+ kill -9 $INST1
40
+ exit 2
41
+
42
+ }
43
+
44
+ trap "sys_exit" INT
45
+
46
+ echo "Running server with" ${N_WORKERS} "workers."
47
+ gunicorn --statsd-host=localhost:8125 -w ${N_WORKERS} API --bind=0.0.0.0:6006 --statsd-prefix=transformer-autocomplete -t 600 &
48
+ INST1=$!
49
+
50
+ while true; do sleep 1000; done
backend/machine_configurations/neuralgenv2.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "models_to_load": [
3
+ "gpt2/small",
4
+ "gpt2/medium",
5
+ "gpt2/large",
6
+ "gpt2/arxiv-nlp",
7
+
8
+ "gpt",
9
+ "xlnet",
10
+ "distilgpt2/small"
11
+ ],
12
+ "gpu_per_worker": 1
13
+ }
backend/machine_configurations/transformer-autocomplete.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "models_to_load": [
3
+ "ctrl",
4
+ "gpt2/xl"
5
+ ],
6
+ "gpu_per_worker": 2,
7
+ "cached_models": {
8
+ "gpt2/xl": "/datadrive/transformer-autocomplete/backend/gpt2-xl-local",
9
+ "ctrl": "/datadrive/transformer-autocomplete/backend/ctrl-local"
10
+ }
11
+ }
backend/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ falcon
2
+ gunicorn
3
+ torch
4
+ transformers
backend/run_pplm_discrim_train.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ # This code is licensed under a non-commercial license.
5
+
6
+ import argparse
7
+ import csv
8
+ import json
9
+ import math
10
+ import time
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torch.optim
16
+ import torch.optim as optim
17
+ import torch.utils.data as data
18
+ from nltk.tokenize.treebank import TreebankWordDetokenizer
19
+ from torchtext import data as torchtext_data
20
+ from torchtext import datasets
21
+ from tqdm import tqdm, trange
22
+
23
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
24
+
25
+ torch.manual_seed(0)
26
+ np.random.seed(0)
27
+ EPSILON = 1e-10
28
+ device = "cpu"
29
+ example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
30
+ max_length_seq = 100
31
+
32
+
33
+ class ClassificationHead(torch.nn.Module):
34
+ """Classification Head for transformer encoders"""
35
+
36
+ def __init__(self, class_size, embed_size):
37
+ super(ClassificationHead, self).__init__()
38
+ self.class_size = class_size
39
+ self.embed_size = embed_size
40
+ # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
41
+ # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
42
+ self.mlp = torch.nn.Linear(embed_size, class_size)
43
+
44
+ def forward(self, hidden_state):
45
+ # hidden_state = F.relu(self.mlp1(hidden_state))
46
+ # hidden_state = self.mlp2(hidden_state)
47
+ logits = self.mlp(hidden_state)
48
+ return logits
49
+
50
+
51
+ class Discriminator(torch.nn.Module):
52
+ """Transformer encoder followed by a Classification Head"""
53
+
54
+ def __init__(
55
+ self,
56
+ class_size,
57
+ pretrained_model="gpt2-medium",
58
+ cached_mode=False
59
+ ):
60
+ super(Discriminator, self).__init__()
61
+ self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
62
+ self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
63
+ self.embed_size = self.encoder.transformer.config.hidden_size
64
+ self.classifier_head = ClassificationHead(
65
+ class_size=class_size,
66
+ embed_size=self.embed_size
67
+ )
68
+ self.cached_mode = cached_mode
69
+
70
+ def get_classifier(self):
71
+ return self.classifier_head
72
+
73
+ def train_custom(self):
74
+ for param in self.encoder.parameters():
75
+ param.requires_grad = False
76
+ self.classifier_head.train()
77
+
78
+ def avg_representation(self, x):
79
+ mask = x.ne(0).unsqueeze(2).repeat(
80
+ 1, 1, self.embed_size
81
+ ).float().to(device).detach()
82
+ hidden, _ = self.encoder.transformer(x)
83
+ masked_hidden = hidden * mask
84
+ avg_hidden = torch.sum(masked_hidden, dim=1) / (
85
+ torch.sum(mask, dim=1).detach() + EPSILON
86
+ )
87
+ return avg_hidden
88
+
89
+ def forward(self, x):
90
+ if self.cached_mode:
91
+ avg_hidden = x.to(device)
92
+ else:
93
+ avg_hidden = self.avg_representation(x.to(device))
94
+
95
+ logits = self.classifier_head(avg_hidden)
96
+ probs = F.log_softmax(logits, dim=-1)
97
+
98
+ return probs
99
+
100
+
101
+ class Dataset(data.Dataset):
102
+ def __init__(self, X, y):
103
+ """Reads source and target sequences from txt files."""
104
+ self.X = X
105
+ self.y = y
106
+
107
+ def __len__(self):
108
+ return len(self.X)
109
+
110
+ def __getitem__(self, index):
111
+ """Returns one data pair (source and target)."""
112
+ data = {}
113
+ data["X"] = self.X[index]
114
+ data["y"] = self.y[index]
115
+ return data
116
+
117
+
118
+ def collate_fn(data):
119
+ def pad_sequences(sequences):
120
+ lengths = [len(seq) for seq in sequences]
121
+
122
+ padded_sequences = torch.zeros(
123
+ len(sequences),
124
+ max(lengths)
125
+ ).long() # padding value = 0
126
+
127
+ for i, seq in enumerate(sequences):
128
+ end = lengths[i]
129
+ padded_sequences[i, :end] = seq[:end]
130
+
131
+ return padded_sequences, lengths
132
+
133
+ item_info = {}
134
+ for key in data[0].keys():
135
+ item_info[key] = [d[key] for d in data]
136
+
137
+ x_batch, _ = pad_sequences(item_info["X"])
138
+ y_batch = torch.tensor(item_info["y"], dtype=torch.long)
139
+
140
+ return x_batch, y_batch
141
+
142
+
143
+ def cached_collate_fn(data):
144
+ item_info = {}
145
+ for key in data[0].keys():
146
+ item_info[key] = [d[key] for d in data]
147
+
148
+ x_batch = torch.cat(item_info["X"], 0)
149
+ y_batch = torch.tensor(item_info["y"], dtype=torch.long)
150
+
151
+ return x_batch, y_batch
152
+
153
+
154
+ def train_epoch(data_loader, discriminator, optimizer,
155
+ epoch=0, log_interval=10):
156
+ samples_so_far = 0
157
+ discriminator.train_custom()
158
+ for batch_idx, (input_t, target_t) in enumerate(data_loader):
159
+ input_t, target_t = input_t.to(device), target_t.to(device)
160
+
161
+ optimizer.zero_grad()
162
+
163
+ output_t = discriminator(input_t)
164
+ loss = F.nll_loss(output_t, target_t)
165
+ loss.backward(retain_graph=True)
166
+ optimizer.step()
167
+
168
+ samples_so_far += len(input_t)
169
+
170
+ if batch_idx % log_interval == 0:
171
+ print(
172
+ "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
173
+ epoch + 1,
174
+ samples_so_far, len(data_loader.dataset),
175
+ 100 * samples_so_far / len(data_loader.dataset), loss.item()
176
+ )
177
+ )
178
+
179
+
180
+ def evaluate_performance(data_loader, discriminator):
181
+ discriminator.eval()
182
+ test_loss = 0
183
+ correct = 0
184
+ with torch.no_grad():
185
+ for input_t, target_t in data_loader:
186
+ input_t, target_t = input_t.to(device), target_t.to(device)
187
+ output_t = discriminator(input_t)
188
+ # sum up batch loss
189
+ test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
190
+ # get the index of the max log-probability
191
+ pred_t = output_t.argmax(dim=1, keepdim=True)
192
+ correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
193
+
194
+ test_loss /= len(data_loader.dataset)
195
+
196
+ print(
197
+ "Performance on test set: "
198
+ "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
199
+ test_loss, correct, len(data_loader.dataset),
200
+ 100. * correct / len(data_loader.dataset)
201
+ )
202
+ )
203
+
204
+
205
+ def predict(input_sentence, model, classes, cached=False):
206
+ input_t = model.tokenizer.encode(input_sentence)
207
+ input_t = torch.tensor([input_t], dtype=torch.long, device=device)
208
+ if cached:
209
+ input_t = model.avg_representation(input_t)
210
+
211
+ log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
212
+ print("Input sentence:", input_sentence)
213
+ print("Predictions:", ", ".join(
214
+ "{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
215
+ zip(classes, log_probs)
216
+ ))
217
+
218
+
219
+ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
220
+ data_loader = torch.utils.data.DataLoader(dataset=dataset,
221
+ batch_size=batch_size,
222
+ collate_fn=collate_fn)
223
+
224
+ xs = []
225
+ ys = []
226
+ for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
227
+ with torch.no_grad():
228
+ x = x.to(device)
229
+ avg_rep = discriminator.avg_representation(x).cpu().detach()
230
+ avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
231
+ xs += avg_rep_list
232
+ ys += y.cpu().numpy().tolist()
233
+
234
+ data_loader = torch.utils.data.DataLoader(
235
+ dataset=Dataset(xs, ys),
236
+ batch_size=batch_size,
237
+ shuffle=shuffle,
238
+ collate_fn=cached_collate_fn)
239
+
240
+ return data_loader
241
+
242
+
243
+ def train_discriminator(
244
+ dataset, dataset_fp=None, pretrained_model="gpt2-medium",
245
+ epochs=10, batch_size=64, log_interval=10,
246
+ save_model=False, cached=False, no_cuda=False):
247
+ global device
248
+ device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
249
+
250
+ print("Preprocessing {} dataset...".format(dataset))
251
+ start = time.time()
252
+
253
+ if dataset == "SST":
254
+ idx2class = ["positive", "negative", "very positive", "very negative",
255
+ "neutral"]
256
+ class2idx = {c: i for i, c in enumerate(idx2class)}
257
+
258
+ discriminator = Discriminator(
259
+ class_size=len(idx2class),
260
+ pretrained_model=pretrained_model,
261
+ cached_mode=cached
262
+ ).to(device)
263
+
264
+ text = torchtext_data.Field()
265
+ label = torchtext_data.Field(sequential=False)
266
+ train_data, val_data, test_data = datasets.SST.splits(
267
+ text,
268
+ label,
269
+ fine_grained=True,
270
+ train_subtrees=True,
271
+ )
272
+
273
+ x = []
274
+ y = []
275
+ for i in trange(len(train_data), ascii=True):
276
+ seq = TreebankWordDetokenizer().detokenize(
277
+ vars(train_data[i])["text"]
278
+ )
279
+ seq = discriminator.tokenizer.encode(seq)
280
+ seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
281
+ x.append(seq)
282
+ y.append(class2idx[vars(train_data[i])["label"]])
283
+ train_dataset = Dataset(x, y)
284
+
285
+ test_x = []
286
+ test_y = []
287
+ for i in trange(len(test_data), ascii=True):
288
+ seq = TreebankWordDetokenizer().detokenize(
289
+ vars(test_data[i])["text"]
290
+ )
291
+ seq = discriminator.tokenizer.encode(seq)
292
+ seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
293
+ test_x.append(seq)
294
+ test_y.append(class2idx[vars(test_data[i])["label"]])
295
+ test_dataset = Dataset(test_x, test_y)
296
+
297
+ discriminator_meta = {
298
+ "class_size": len(idx2class),
299
+ "embed_size": discriminator.embed_size,
300
+ "pretrained_model": pretrained_model,
301
+ "class_vocab": class2idx,
302
+ "default_class": 2,
303
+ }
304
+
305
+ elif dataset == "clickbait":
306
+ idx2class = ["non_clickbait", "clickbait"]
307
+ class2idx = {c: i for i, c in enumerate(idx2class)}
308
+
309
+ discriminator = Discriminator(
310
+ class_size=len(idx2class),
311
+ pretrained_model=pretrained_model,
312
+ cached_mode=cached
313
+ ).to(device)
314
+
315
+ with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
316
+ data = []
317
+ for i, line in enumerate(f):
318
+ try:
319
+ data.append(eval(line))
320
+ except:
321
+ print("Error evaluating line {}: {}".format(
322
+ i, line
323
+ ))
324
+ continue
325
+ x = []
326
+ y = []
327
+ with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
328
+ for i, line in enumerate(tqdm(f, ascii=True)):
329
+ try:
330
+ d = eval(line)
331
+ seq = discriminator.tokenizer.encode(d["text"])
332
+
333
+ if len(seq) < max_length_seq:
334
+ seq = torch.tensor(
335
+ [50256] + seq, device=device, dtype=torch.long
336
+ )
337
+ else:
338
+ print("Line {} is longer than maximum length {}".format(
339
+ i, max_length_seq
340
+ ))
341
+ continue
342
+ x.append(seq)
343
+ y.append(d["label"])
344
+ except:
345
+ print("Error evaluating / tokenizing"
346
+ " line {}, skipping it".format(i))
347
+ pass
348
+
349
+ full_dataset = Dataset(x, y)
350
+ train_size = int(0.9 * len(full_dataset))
351
+ test_size = len(full_dataset) - train_size
352
+ train_dataset, test_dataset = torch.utils.data.random_split(
353
+ full_dataset, [train_size, test_size]
354
+ )
355
+
356
+ discriminator_meta = {
357
+ "class_size": len(idx2class),
358
+ "embed_size": discriminator.embed_size,
359
+ "pretrained_model": pretrained_model,
360
+ "class_vocab": class2idx,
361
+ "default_class": 1,
362
+ }
363
+
364
+ elif dataset == "toxic":
365
+ idx2class = ["non_toxic", "toxic"]
366
+ class2idx = {c: i for i, c in enumerate(idx2class)}
367
+
368
+ discriminator = Discriminator(
369
+ class_size=len(idx2class),
370
+ pretrained_model=pretrained_model,
371
+ cached_mode=cached
372
+ ).to(device)
373
+
374
+ x = []
375
+ y = []
376
+ with open("datasets/toxic/toxic_train.txt") as f:
377
+ for i, line in enumerate(tqdm(f, ascii=True)):
378
+ try:
379
+ d = eval(line)
380
+ seq = discriminator.tokenizer.encode(d["text"])
381
+
382
+ if len(seq) < max_length_seq:
383
+ seq = torch.tensor(
384
+ [50256] + seq, device=device, dtype=torch.long
385
+ )
386
+ else:
387
+ print("Line {} is longer than maximum length {}".format(
388
+ i, max_length_seq
389
+ ))
390
+ continue
391
+ x.append(seq)
392
+ y.append(int(np.sum(d["label"]) > 0))
393
+ except:
394
+ print("Error evaluating / tokenizing"
395
+ " line {}, skipping it".format(i))
396
+ pass
397
+
398
+ full_dataset = Dataset(x, y)
399
+ train_size = int(0.9 * len(full_dataset))
400
+ test_size = len(full_dataset) - train_size
401
+ train_dataset, test_dataset = torch.utils.data.random_split(
402
+ full_dataset, [train_size, test_size]
403
+ )
404
+
405
+ discriminator_meta = {
406
+ "class_size": len(idx2class),
407
+ "embed_size": discriminator.embed_size,
408
+ "pretrained_model": pretrained_model,
409
+ "class_vocab": class2idx,
410
+ "default_class": 0,
411
+ }
412
+
413
+ else: # if dataset == "generic":
414
+ # This assumes the input dataset is a TSV with the following structure:
415
+ # class \t text
416
+
417
+ if dataset_fp is None:
418
+ raise ValueError("When generic dataset is selected, "
419
+ "dataset_fp needs to be specified aswell.")
420
+
421
+ classes = set()
422
+ with open(dataset_fp) as f:
423
+ csv_reader = csv.reader(f, delimiter="\t")
424
+ for row in tqdm(csv_reader, ascii=True):
425
+ if row:
426
+ classes.add(row[0])
427
+
428
+ idx2class = sorted(classes)
429
+ class2idx = {c: i for i, c in enumerate(idx2class)}
430
+
431
+ discriminator = Discriminator(
432
+ class_size=len(idx2class),
433
+ pretrained_model=pretrained_model,
434
+ cached_mode=cached
435
+ ).to(device)
436
+
437
+ x = []
438
+ y = []
439
+ with open(dataset_fp) as f:
440
+ csv_reader = csv.reader(f, delimiter="\t")
441
+ for i, row in enumerate(tqdm(csv_reader, ascii=True)):
442
+ if row:
443
+ label = row[0]
444
+ text = row[1]
445
+
446
+ try:
447
+ seq = discriminator.tokenizer.encode(text)
448
+ if (len(seq) < max_length_seq):
449
+ seq = torch.tensor(
450
+ [50256] + seq,
451
+ device=device,
452
+ dtype=torch.long
453
+ )
454
+
455
+ else:
456
+ print(
457
+ "Line {} is longer than maximum length {}".format(
458
+ i, max_length_seq
459
+ ))
460
+ continue
461
+
462
+ x.append(seq)
463
+ y.append(class2idx[label])
464
+
465
+ except:
466
+ print("Error tokenizing line {}, skipping it".format(i))
467
+ pass
468
+
469
+ full_dataset = Dataset(x, y)
470
+ train_size = int(0.9 * len(full_dataset))
471
+ test_size = len(full_dataset) - train_size
472
+ train_dataset, test_dataset = torch.utils.data.random_split(
473
+ full_dataset,
474
+ [train_size, test_size]
475
+ )
476
+
477
+ discriminator_meta = {
478
+ "class_size": len(idx2class),
479
+ "embed_size": discriminator.embed_size,
480
+ "pretrained_model": pretrained_model,
481
+ "class_vocab": class2idx,
482
+ "default_class": 0,
483
+ }
484
+
485
+ end = time.time()
486
+ print("Preprocessed {} data points".format(
487
+ len(train_dataset) + len(test_dataset))
488
+ )
489
+ print("Data preprocessing took: {:.3f}s".format(end - start))
490
+
491
+ if cached:
492
+ print("Building representation cache...")
493
+
494
+ start = time.time()
495
+
496
+ train_loader = get_cached_data_loader(
497
+ train_dataset, batch_size, discriminator, shuffle=True
498
+ )
499
+
500
+ test_loader = get_cached_data_loader(
501
+ test_dataset, batch_size, discriminator
502
+ )
503
+
504
+ end = time.time()
505
+ print("Building representation cache took: {:.3f}s".format(end - start))
506
+
507
+ else:
508
+ train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
509
+ batch_size=batch_size,
510
+ shuffle=True,
511
+ collate_fn=collate_fn)
512
+ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
513
+ batch_size=batch_size,
514
+ collate_fn=collate_fn)
515
+
516
+ if save_model:
517
+ with open("{}_classifier_head_meta.json".format(dataset),
518
+ "w") as meta_file:
519
+ json.dump(discriminator_meta, meta_file)
520
+
521
+ optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
522
+
523
+ for epoch in range(epochs):
524
+ start = time.time()
525
+ print("\nEpoch", epoch + 1)
526
+
527
+ train_epoch(
528
+ discriminator=discriminator,
529
+ data_loader=train_loader,
530
+ optimizer=optimizer,
531
+ epoch=epoch,
532
+ log_interval=log_interval
533
+ )
534
+ evaluate_performance(
535
+ data_loader=test_loader,
536
+ discriminator=discriminator
537
+ )
538
+
539
+ end = time.time()
540
+ print("Epoch took: {:.3f}s".format(end - start))
541
+
542
+ print("\nExample prediction")
543
+ predict(example_sentence, discriminator, idx2class, cached)
544
+
545
+ if save_model:
546
+ # torch.save(discriminator.state_dict(),
547
+ # "{}_discriminator_{}.pt".format(
548
+ # args.dataset, epoch + 1
549
+ # ))
550
+ torch.save(discriminator.get_classifier().state_dict(),
551
+ "{}_classifier_head_epoch_{}.pt".format(dataset,
552
+ epoch + 1))
553
+
554
+
555
+ if __name__ == "__main__":
556
+ parser = argparse.ArgumentParser(
557
+ description="Train a discriminator on top of GPT-2 representations")
558
+ parser.add_argument("--dataset", type=str, default="SST",
559
+ choices=("SST", "clickbait", "toxic", "generic"),
560
+ help="dataset to train the discriminator on."
561
+ "In case of generic, the dataset is expected"
562
+ "to be a TSBV file with structure: class \\t text")
563
+ parser.add_argument("--dataset_fp", type=str, default="",
564
+ help="File path of the dataset to use. "
565
+ "Needed only in case of generic datadset")
566
+ parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
567
+ help="Pretrained model to use as encoder")
568
+ parser.add_argument("--epochs", type=int, default=10, metavar="N",
569
+ help="Number of training epochs")
570
+ parser.add_argument("--batch_size", type=int, default=64, metavar="N",
571
+ help="input batch size for training (default: 64)")
572
+ parser.add_argument("--log_interval", type=int, default=10, metavar="N",
573
+ help="how many batches to wait before logging training status")
574
+ parser.add_argument("--save_model", action="store_true",
575
+ help="whether to save the model")
576
+ parser.add_argument("--cached", action="store_true",
577
+ help="whether to cache the input representations")
578
+ parser.add_argument("--no_cuda", action="store_true",
579
+ help="use to turn off cuda")
580
+ args = parser.parse_args()
581
+
582
+ train_discriminator(**(vars(args)))
entrypoint.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+
3
+ defined_envs=$(printf '${%s} ' $(awk "END { for (name in ENVIRON) { print ( name ~ /NGINX_/ ) ? name : \"\" } }" < /dev/null ))
4
+
5
+ envsubst "$defined_envs" < nginx.conf > /etc/nginx/nginx.conf
6
+
7
+ nginx && node server/dist/server.js
front/.vscode/settings.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ // Configure glob patterns for excluding files and folders in searches. Inherits all glob patterns from the files.exclude setting.
3
+ "search.exclude": {
4
+ "dist": true,
5
+ "build": true,
6
+ }
7
+ }
front/assets/Icon-info.svg ADDED
front/assets/Salesforce_logo.svg ADDED
front/assets/Uber_logo.svg ADDED
front/assets/cross-collab.svg ADDED
front/assets/github-buttons.js ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ * github-buttons v2.2.10
3
+ * (c) 2019 なつき
4
+ * @license BSD-2-Clause
5
+ */
6
+ /**
7
+ * Julien: just modified to add a `transform: scale(1.5);` on the .widget
8
+ */
9
+ !function(){"use strict";var e=window.document,t=e.location,o=window.encodeURIComponent,r=window.decodeURIComponent,n=window.Math,a=window.HTMLElement,i=window.XMLHttpRequest,l="https://unpkg.com/github-buttons@2.2.10/dist/buttons.html",c=i&&i.prototype&&"withCredentials"in i.prototype,d=c&&a&&a.prototype.attachShadow&&!a.prototype.attachShadow.prototype,s=function(e,t,o){e.addEventListener?e.addEventListener(t,o):e.attachEvent("on"+t,o)},u=function(e,t,o){e.removeEventListener?e.removeEventListener(t,o):e.detachEvent("on"+t,o)},h=function(e,t,o){var r=function(n){return u(e,t,r),o(n)};s(e,t,r)},f=function(e,t,o){var r=function(n){if(t.test(e.readyState))return u(e,"readystatechange",r),o(n)};s(e,"readystatechange",r)},p=function(e){return function(t,o,r){var n=e.createElement(t);if(o)for(var a in o){var i=o[a];null!=i&&(null!=n[a]?n[a]=i:n.setAttribute(a,i))}if(r)for(var l=0,c=r.length;l<c;l++){var d=r[l];n.appendChild("string"==typeof d?e.createTextNode(d):d)}return n}},g=p(e),b=function(e){var t;return function(){t||(t=1,e.apply(this,arguments))}},m="body{margin:0}a{color:#24292e;text-decoration:none;outline:0}.octicon{display:inline-block;vertical-align:text-top;fill:currentColor}.widget{ transform: scale(1.5); display:inline-block;overflow:hidden;font-family:-apple-system, BlinkMacSystemFont, \"Segoe UI\", Helvetica, Arial, sans-serif;font-size:0;white-space:nowrap;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}.btn,.social-count{display:inline-block;height:14px;padding:2px 5px;font-size:11px;font-weight:600;line-height:14px;vertical-align:bottom;cursor:pointer;border:1px solid #c5c9cc;border-radius:0.25em}.btn{background-color:#eff3f6;background-image:-webkit-linear-gradient(top, #fafbfc, #eff3f6 90%);background-image:-moz-linear-gradient(top, #fafbfc, #eff3f6 90%);background-image:linear-gradient(180deg, #fafbfc, #eff3f6 90%);background-position:-1px -1px;background-repeat:repeat-x;background-size:110% 
110%;border-color:rgba(27,31,35,0.2);-ms-filter:\"progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFFAFBFC', endColorstr='#FFEEF2F5')\";*filter:progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFFAFBFC', endColorstr='#FFEEF2F5')}.btn:active{background-color:#e9ecef;background-image:none;border-color:#a5a9ac;border-color:rgba(27,31,35,0.35);box-shadow:inset 0 0.15em 0.3em rgba(27,31,35,0.15)}.btn:focus,.btn:hover{background-color:#e6ebf1;background-image:-webkit-linear-gradient(top, #f0f3f6, #e6ebf1 90%);background-image:-moz-linear-gradient(top, #f0f3f6, #e6ebf1 90%);background-image:linear-gradient(180deg, #f0f3f6, #e6ebf1 90%);border-color:#a5a9ac;border-color:rgba(27,31,35,0.35);-ms-filter:\"progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFF0F3F6', endColorstr='#FFE5EAF0')\";*filter:progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFF0F3F6', endColorstr='#FFE5EAF0')}.social-count{position:relative;margin-left:5px;background-color:#fff}.social-count:focus,.social-count:hover{color:#0366d6}.social-count b,.social-count i{position:absolute;top:50%;left:0;display:block;width:0;height:0;margin:-4px 0 0 -4px;border:solid transparent;border-width:4px 4px 4px 0;_line-height:0;_border-top-color:red !important;_border-bottom-color:red !important;_border-left-color:red !important;_filter:chroma(color=red)}.social-count b{border-right-color:#c5c9cc}.social-count i{margin-left:-3px;border-right-color:#fff}.lg .btn,.lg .social-count{height:16px;padding:5px 10px;font-size:12px;line-height:16px}.lg .social-count{margin-left:6px}.lg .social-count b,.lg .social-count i{margin:-5px 0 0 -5px;border-width:5px 5px 5px 0}.lg .social-count i{margin-left:-4px}\n",v={"mark-github":{width:16,height:16,path:'<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 
2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"/>'},eye:{width:16,height:16,path:'<path fill-rule="evenodd" d="M8.06 2C3 2 0 8 0 8s3 6 8.06 6C13 14 16 8 16 8s-3-6-7.94-6zM8 12c-2.2 0-4-1.78-4-4 0-2.2 1.8-4 4-4 2.22 0 4 1.8 4 4 0 2.22-1.78 4-4 4zm2-4c0 1.11-.89 2-2 2-1.11 0-2-.89-2-2 0-1.11.89-2 2-2 1.11 0 2 .89 2 2z"/>'},star:{width:14,height:16,path:'<path fill-rule="evenodd" d="M14 6l-4.9-.64L7 1 4.9 5.36 0 6l3.6 3.26L2.67 14 7 11.67 11.33 14l-.93-4.74L14 6z"/>'},"repo-forked":{width:10,height:16,path:'<path fill-rule="evenodd" d="M8 1a1.993 1.993 0 0 0-1 3.72V6L5 8 3 6V4.72A1.993 1.993 0 0 0 2 1a1.993 1.993 0 0 0-1 3.72V6.5l3 3v1.78A1.993 1.993 0 0 0 5 15a1.993 1.993 0 0 0 1-3.72V9.5l3-3V4.72A1.993 1.993 0 0 0 8 1zM2 4.2C1.34 4.2.8 3.65.8 3c0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2zm3 10c-.66 0-1.2-.55-1.2-1.2 0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2zm3-10c-.66 0-1.2-.55-1.2-1.2 0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2z"/>'},"issue-opened":{width:14,height:16,path:'<path fill-rule="evenodd" d="M7 2.3c3.14 0 5.7 2.56 5.7 5.7s-2.56 5.7-5.7 5.7A5.71 5.71 0 0 1 1.3 8c0-3.14 2.56-5.7 5.7-5.7zM7 1C3.14 1 0 4.14 0 8s3.14 7 7 7 7-3.14 7-7-3.14-7-7-7zm1 3H6v5h2V4zm0 6H6v2h2v-2z"/>'},"cloud-download":{width:16,height:16,path:'<path fill-rule="evenodd" d="M9 12h2l-3 3-3-3h2V7h2v5zm3-8c0-.44-.91-3-4.5-3C5.08 1 3 2.92 3 5 1.02 5 0 6.52 0 8c0 1.53 1 3 3 3h3V9.7H3C1.38 9.7 1.3 8.28 1.3 8c0-.17.05-1.7 1.7-1.7h1.3V5c0-1.39 1.56-2.7 3.2-2.7 2.55 0 3.13 1.55 3.2 1.8v1.2H12c.81 0 2.7.22 2.7 2.2 0 2.09-2.25 2.2-2.7 2.2h-2V11h2c2.08 0 4-1.16 4-3.5C16 5.06 14.08 4 12 4z"/>'}},w={},x=function(e,t,o){var 
r=p(e.ownerDocument),n=e.appendChild(r("style",{type:"text/css"}));n.styleSheet?n.styleSheet.cssText=m:n.appendChild(e.ownerDocument.createTextNode(m));var a,l,d=r("a",{className:"btn",href:t.href,target:"_blank",innerHTML:(a=t["data-icon"],l=/^large$/i.test(t["data-size"])?16:14,a=(""+a).toLowerCase().replace(/^octicon-/,""),{}.hasOwnProperty.call(v,a)||(a="mark-github"),'<svg version="1.1" width="'+l*v[a].width/v[a].height+'" height="'+l+'" viewBox="0 0 '+v[a].width+" "+v[a].height+'" class="octicon octicon-'+a+'" aria-hidden="true">'+v[a].path+"</svg>"),"aria-label":t["aria-label"]||void 0},[" ",r("span",{},[t["data-text"]||""])]);/\.github\.com$/.test("."+d.hostname)?/^https?:\/\/((gist\.)?github\.com\/[^\/?#]+\/[^\/?#]+\/archive\/|github\.com\/[^\/?#]+\/[^\/?#]+\/releases\/download\/|codeload\.github\.com\/)/.test(d.href)&&(d.target="_top"):(d.href="#",d.target="_self");var u,h,g,x,y=e.appendChild(r("div",{className:"widget"+(/^large$/i.test(t["data-size"])?" lg":"")},[d]));/^(true|1)$/i.test(t["data-show-count"])&&"github.com"===d.hostname&&(u=d.pathname.replace(/^(?!\/)/,"/").match(/^\/([^\/?#]+)(?:\/([^\/?#]+)(?:\/(?:(subscription)|(fork)|(issues)|([^\/?#]+)))?)?(?:[\/?#]|$)/))&&!u[6]?(u[2]?(h="/repos/"+u[1]+"/"+u[2],u[3]?(x="subscribers_count",g="watchers"):u[4]?(x="forks_count",g="network"):u[5]?(x="open_issues_count",g="issues"):(x="stargazers_count",g="stargazers")):(h="/users/"+u[1],g=x="followers"),function(e,t){var o=w[e]||(w[e]=[]);if(!(o.push(t)>1)){var r=b(function(){for(delete w[e];t=o.shift();)t.apply(null,arguments)});if(c){var n=new i;s(n,"abort",r),s(n,"error",r),s(n,"load",function(){var e;try{e=JSON.parse(n.responseText)}catch(e){return void r(e)}r(200!==n.status,e)}),n.open("GET",e),n.send()}else{var a=this||window;a._=function(e){a._=null,r(200!==e.meta.status,e.data)};var 
l=p(a.document)("script",{async:!0,src:e+(/\?/.test(e)?"&":"?")+"callback=_"}),d=function(){a._&&a._({meta:{}})};s(l,"load",d),s(l,"error",d),l.readyState&&f(l,/de|m/,d),a.document.getElementsByTagName("head")[0].appendChild(l)}}}.call(this,"https://api.github.com"+h,function(e,t){if(!e){var n=t[x];y.appendChild(r("a",{className:"social-count",href:t.html_url+"/"+g,target:"_blank","aria-label":n+" "+x.replace(/_count$/,"").replace("_"," ").slice(0,n<2?-1:void 0)+" on GitHub"},[r("b"),r("i"),r("span",{},[(""+n).replace(/\B(?=(\d{3})+(?!\d))/g,",")])]))}o&&o(y)})):o&&o(y)},y=window.devicePixelRatio||1,C=function(e){return(y>1?n.ceil(n.round(e*y)/y*2)/2:n.ceil(e))||0},F=function(e,t){e.style.width=t[0]+"px",e.style.height=t[1]+"px"},k=function(t,r){if(null!=t&&null!=r)if(t.getAttribute&&(t=function(e){for(var t={href:e.href,title:e.title,"aria-label":e.getAttribute("aria-label")},o=["icon","text","size","show-count"],r=0,n=o.length;r<n;r++){var a="data-"+o[r];t[a]=e.getAttribute(a)}return null==t["data-text"]&&(t["data-text"]=e.textContent||e.innerText),t}(t)),d){var a=g("span",{title:t.title||void 0});x(a.attachShadow({mode:"closed"}),t,function(){r(a)})}else{var i=g("iframe",{src:"javascript:0",title:t.title||void 0,allowtransparency:!0,scrolling:"no",frameBorder:0});F(i,[0,0]),i.style.border="none";var c=function(){var a,d=i.contentWindow;try{a=d.document.body}catch(t){return void e.body.appendChild(i.parentNode.removeChild(i))}u(i,"load",c),x.call(d,a,t,function(e){var a=function(e){var t=e.offsetWidth,o=e.offsetHeight;if(e.getBoundingClientRect){var r=e.getBoundingClientRect();t=n.max(t,C(r.width)),o=n.max(o,C(r.height))}return[t,o]}(e);i.parentNode.removeChild(i),h(i,"load",function(){F(i,a)}),i.src=l+"#"+(i.name=function(e){var t=[];for(var r in e){var n=e[r];null!=n&&t.push(o(r)+"="+o(n))}return t.join("&")}(t)),r(i)})};s(i,"load",c),e.body.appendChild(i)}};t.protocol+"//"+t.host+t.pathname===l?x(e.body,function(e){for(var 
t={},o=e.split("&"),n=0,a=o.length;n<a;n++){var i=o[n];if(""!==i){var l=i.split("=");t[r(l[0])]=null!=l[1]?r(l.slice(1).join("=")):void 0}}return t}(window.name||t.hash.replace(/^#/,""))):function(t){if(/m/.test(e.readyState)||!/g/.test(e.readyState)&&!e.documentElement.doScroll)setTimeout(t);else if(e.addEventListener){var o=b(t);h(e,"DOMContentLoaded",o),h(window,"load",o)}else f(e,/m/,t)}(function(){for(var t=e.querySelectorAll?e.querySelectorAll("a.github-button"):function(){for(var t=[],o=e.getElementsByTagName("a"),r=0,n=o.length;r<n;r++)~(" "+o[r].className+" ").replace(/[ \t\n\f\r]+/g," ").indexOf(" github-button ")&&t.push(o[r]);return t}(),o=0,r=t.length;o<r;o++)!function(e){k(e,function(t){e.parentNode.replaceChild(t,e)})}(t[o])})}();
front/assets/huggingface_logo.svg ADDED
front/assets/icon-back.svg ADDED
front/assets/icon-publish.svg ADDED
front/assets/iconmonstr-download-14.svg ADDED
front/assets/iconmonstr-media-control-55.svg ADDED
front/assets/iconmonstr-share-11-purple.svg ADDED
front/assets/iconmonstr-share-11.svg ADDED
front/assets/oval.svg ADDED
front/assets/tail-spin.svg ADDED
front/assets/thumbnail-large-distilgpt2.png ADDED
front/assets/thumbnail-large-pplm.png ADDED
front/assets/thumbnail-large.png ADDED
front/assets/unicorn-tweaked.svg ADDED
front/favicon.ico ADDED
front/js-src/Api.ts ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import { c } from './lib/Log';


/**
 * Shape of the JSON payload returned by the backend's autocomplete endpoint:
 * a list of candidate continuations with per-sentence timing, plus the
 * total request time.
 */
interface AutocompleteOutput {
	sentences: {
		value: string;
		time: number;
	}[];
	time: number;
}

/**
 * Thin HTTP client for the demo.
 *
 * Two distinct backends are addressed:
 * - `ENDPOINT` (the ML backend) for autocomplete generation;
 * - the serving host itself (relative paths) for document edit/duplicate.
 *
 * Use the `Api.shared` singleton rather than constructing instances.
 */
export class Api {

	// Generation backend base URL; commented-out entries are alternative
	// dev/staging hosts kept for quick switching.
	private static ENDPOINT =
		// `http://coconut-proxy.huggingface.test`
		// `http://coconuthf.eastus.cloudapp.azure.com:6006`
		// "http://localhost:6006"
		`https://transformer.huggingface.co`
	;
	static shared = new Api();

	/** Join a relative path onto the backend endpoint. */
	private path(p: string): string {
		return `${Api.ENDPOINT}/${p}`;
	}

	/**
	 * POST the generation request to `autocomplete/<model_size>`.
	 * All sampling/PPLM knobs are optional; the backend applies its own
	 * defaults for anything omitted (undefined keys are dropped by
	 * JSON.stringify).
	 */
	private async postAutocomplete(
		params: {
			context: string;
			model_size?: string; /// 'small' | 'medium',
			top_p?: number; /// float between 0 and 1
			temperature?: number; /// float between 0 and 100
			step_size?: number;
			kl_scale?: number;
			gm_scale?: number;
			num_iterations?: number;
			gen_length?: number;
			max_time?: number; /// <- if we want to limit the response time. (in sec)
			bow_or_discrim?: string;
			use_sampling?: boolean;
		}
	): Promise<AutocompleteOutput> {

		const path = this.path(`autocomplete/${params.model_size || ""}`);

		const response = await fetch(path, {
			method: 'POST',
			headers: { 'Content-Type': 'application/json' },
			body: JSON.stringify(params),
		});
		// NOTE(review): no response.ok check — a backend error surfaces as a
		// JSON parse failure here; confirm callers handle the rejection.
		return await response.json() as AutocompleteOutput;
	}

	/**
	 * Demo-specific helpers
	 */
	async postWithSettings(
		params: {
			context: string;
		}
	): Promise<AutocompleteOutput> {
		/// Retrieve all settings params then launch the request.
		/// Settings are scraped from the DOM (`.decoder-settings` panel),
		/// not from a JS model: each slider mirrors its value into a
		/// `.js-val` element, which is read back here.
		const model_size =
			document.querySelector('.decoder-settings .setting.model_size .js-val')!.textContent
			|| undefined;

		// Read a slider's displayed value; undefined when the control is
		// absent from the page (e.g. a model without that setting).
		const parseSliderVal = (sel: string): number | undefined => {
			const x = document.querySelector(sel);
			if (x && x.textContent) {
				return Number(x.textContent);
			}
			return undefined;
		};

		const top_p = parseSliderVal('.decoder-settings .setting.top_p .js-val');
		const temperature = parseSliderVal('.decoder-settings .setting.temperature .js-val');
		const step_size = parseSliderVal('.decoder-settings .setting.step_size .js-val');
		const kl_scale = parseSliderVal('.decoder-settings .setting.kl_scale .js-val');
		const gm_scale = parseSliderVal('.decoder-settings .setting.gm_scale .js-val');
		const num_iterations = parseSliderVal('.decoder-settings .setting.num_iterations .js-val');
		const gen_length = parseSliderVal('.decoder-settings .setting.gen_length .js-val');
		const max_time = parseSliderVal('.decoder-settings .setting.max_time .js-val');

		// `|| {}` keeps the reads safe when the radio/checkbox is not on the
		// page; the property then reads as undefined and is omitted from the
		// JSON body.
		const bow_or_discrim = (
			document.querySelector<HTMLInputElement>('.decoder-settings input[name=bow_or_discrim]:checked') || {}
		).value;
		const use_sampling = (
			document.querySelector<HTMLInputElement>('.decoder-settings input[name=use_sampling]') || {}
		).checked;

		return this.postAutocomplete({
			...params,
			model_size,
			top_p,
			temperature,
			step_size,
			kl_scale,
			gm_scale,
			num_iterations,
			gen_length,
			max_time,
			bow_or_discrim,
			use_sampling,
		});
	}

	/**
	 * Edit AJAX endpoint
	 *
	 * Contrary to the autocomplete endpoint,
	 * this is on server,
	 * not on backend.
	 *
	 * Persists the current document. The document identity is read from the
	 * `window.doc` global injected by the server-rendered page.
	 * @returns true when the server responded 2xx.
	 */
	async postEdit(body: any): Promise<boolean> {
		const doc = (<any>window).doc as { [index: string]: string };
		if (!doc || !doc.longId) {
			throw new Error(`invalid doc`);
		}

		const path = `/edit/${doc.model}/${doc.longId}/${doc.shortId}`;

		const response = await fetch(path, {
			method: 'POST',
			headers: { 'Content-Type': 'application/json' },
			body: JSON.stringify(body),
		});
		return response.ok;
	}

	/**
	 * Duplicate AJAX endpoint
	 *
	 * Contrary to the autocomplete endpoint,
	 * this is on server,
	 * not on backend.
	 *
	 * @returns the URL of the newly created copy, as plain text from the
	 * server.
	 */
	async postDuplicate(): Promise<string> {
		const doc = (<any>window).doc as { [index: string]: string };
		if (!doc || !doc.shortId) {
			throw new Error(`invalid doc`);
		}

		const path = `/duplicate/${doc.shortId}`;
		const response = await fetch(path, {
			method: 'POST',
			headers: { 'Content-Type': 'application/json' },
		});
		const url = await response.text();
		c.log('[new url]', url);

		return url;
	}
}
front/js-src/Mention.ts ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
/** One entry of the completion dropdown: `id` and the text to insert. */
interface Datum {
	id: string;
	value: string;
}


/**
 * Quill module implementing the completion dropdown ("mention" list).
 *
 * Adapted from the quill-mention plugin: the original @-trigger flow is
 * bypassed and the list is opened manually via `trigger()` with
 * server-provided completions (see the "HF Helpers for manual trigger"
 * section). Keyboard navigation (up/down/enter/escape) and mouse
 * selection are handled here; selecting an item inserts its text at the
 * recorded cursor position.
 */
export class Mention {
	// Key codes handled by the dropdown's keyboard bindings.
	static Keys = {
		TAB: 9,
		ENTER: 13,
		ESCAPE: 27,
		UP: 38,
		DOWN: 40,
	};
	// NaN is the only value not equal to itself — self-contained isNaN.
	static numberIsNaN = (x: any) => x !== x;
	private isOpen = false;
	/**
	 * index of currently selected item.
	 */
	private itemIndex = 0;
	// Position of the denotation char; unused in the manual-trigger flow.
	private mentionCharPos: number | undefined = undefined;
	// Editor index where the selected completion will be inserted.
	private cursorPos: number | undefined = undefined;
	// Data currently rendered in the list.
	private values = [] as Datum[];
	// Ignore mouseenter-driven highlight while keyboard-scrolling the list.
	private suspendMouseEnter = false;
	// Options kept from the upstream plugin; several (source,
	// mentionDenotationChars, isolateCharacter, ...) are inert in the
	// manual-trigger usage but retained for API compatibility.
	private options = {
		source: (searchTerm: string, renderList: Function, mentionChar: string) => {},
		renderItem: (item: Datum, searchTerm: string) => {
			return `${item.value}`;
		},
		onSelect: (item: DOMStringMap, insertItem: (item: DOMStringMap) => void) => {
			insertItem(item);
		},
		mentionDenotationChars: ['@'],
		showDenotationChar: true,
		allowedChars: /^[a-zA-Z0-9_]*$/,
		minChars: 0,
		maxChars: 31,
		offsetTop: 2,
		offsetLeft: 0,
		/**
		 * Whether or not the denotation character(s) should be isolated. For example, to avoid mentioning in an email.
		 */
		isolateCharacter: false,
		fixMentionsToQuill: false,
		defaultMenuOrientation: 'bottom',
		dataAttributes: ['id', 'value', 'denotationChar', 'link', 'target'],
		linkTarget: '_blank',
		onOpen: () => true,
		onClose: () => true,
		// Style options
		listItemClass: 'ql-mention-list-item',
		mentionContainerClass: 'ql-mention-list-container',
		mentionListClass: 'ql-mention-list',
	};
	/// HTML elements
	private mentionContainer = document.createElement('div');
	private mentionList = document.createElement('ul');


	/**
	 * Wire the dropdown into a Quill instance: build the (hidden) container,
	 * subscribe to editor events and register keyboard bindings.
	 */
	constructor(
		private quill: Quill,
	) {
		this.mentionContainer.className = this.options.mentionContainerClass;
		this.mentionContainer.style.cssText = 'display: none; position: absolute;';
		this.mentionContainer.onmousemove = this.onContainerMouseMove.bind(this);

		if (this.options.fixMentionsToQuill) {
			this.mentionContainer.style.width = 'auto';
		}

		this.mentionList.className = this.options.mentionListClass;
		this.mentionContainer.appendChild(this.mentionList);

		this.quill.container.appendChild(this.mentionContainer);

		quill.on('text-change', this.onTextChange.bind(this));
		quill.on('selection-change', this.onSelectionChange.bind(this));

		// ENTER must run BEFORE Quill's default newline binding, otherwise a
		// newline is inserted instead of selecting the highlighted item.
		quill.keyboard.addBinding({
			key: Mention.Keys.ENTER,
		}, this.selectHandler.bind(this));
		quill.keyboard.bindings[Mention.Keys.ENTER].unshift(
			quill.keyboard.bindings[Mention.Keys.ENTER].pop()
		);
		/// ^^ place it at beginning of bindings.

		quill.keyboard.addBinding({
			key: Mention.Keys.ESCAPE,
		}, this.escapeHandler.bind(this));

		quill.keyboard.addBinding({
			key: Mention.Keys.UP,
		}, this.upHandler.bind(this));

		quill.keyboard.addBinding({
			key: Mention.Keys.DOWN,
		}, this.downHandler.bind(this));

		document.addEventListener("keypress", e => {
			/// Quick’n’dirty hack.
			/// After each keypress (deferred so Quill has applied the edit),
			/// re-record the cursor and strip formatting from the char just
			/// typed — presumably to stop the bold completion format from
			/// bleeding into typed text. TODO confirm intent.
			if (! this.quill.hasFocus()) {
				return ;
			}
			setTimeout(() => {
				this.setCursorPos();
				this.quill.removeFormat(this.cursorPos! - 1, 1, 'silent');
			}, 0);
		});
	}

	/** ENTER: select the highlighted item; returns false to swallow the key. */
	selectHandler() {
		if (this.isOpen) {
			this.selectItem();
			return false;
		}
		return true;
	}

	/** ESCAPE: close the list; returns false to swallow the key. */
	escapeHandler() {
		if (this.isOpen) {
			this.hideMentionList();
			return false;
		}
		return true;
	}

	/** UP arrow: move highlight up while the list is open. */
	upHandler() {
		if (this.isOpen) {
			this.prevItem();
			return false;
		}
		return true;
	}

	/** DOWN arrow: move highlight down while the list is open. */
	downHandler() {
		if (this.isOpen) {
			this.nextItem();
			return false;
		}
		return true;
	}

	/** Show the container (positioned first while invisible to avoid flicker). */
	showMentionList() {
		this.mentionContainer.style.visibility = 'hidden';
		this.mentionContainer.style.display = '';
		this.setMentionContainerPosition();
		this.setIsOpen(true);
	}

	/** Hide the container and fire the onClose callback (via setIsOpen). */
	hideMentionList() {
		this.mentionContainer.style.display = 'none';
		this.setIsOpen(false);
	}


	/**
	 * Apply the 'selected' class to the item at `itemIndex` and, optionally,
	 * scroll the container so that the item is fully visible.
	 */
	private highlightItem(scrollItemInView = true) {
		const childNodes = Array.from(this.mentionList.childNodes) as HTMLLIElement[];
		for (const node of childNodes) {
			node.classList.remove('selected');
		}
		childNodes[this.itemIndex].classList.add('selected');

		if (scrollItemInView) {
			const itemHeight = childNodes[this.itemIndex].offsetHeight;
			const itemPos = this.itemIndex * itemHeight;
			const containerTop = this.mentionContainer.scrollTop;
			const containerBottom = containerTop + this.mentionContainer.offsetHeight;

			if (itemPos < containerTop) {
				// Scroll up if the item is above the top of the container
				this.mentionContainer.scrollTop = itemPos;
			} else if (itemPos > (containerBottom - itemHeight)) {
				// scroll down if any part of the element is below the bottom of the container
				this.mentionContainer.scrollTop += (itemPos - containerBottom) + itemHeight;
			}
		}
	}

	/**
	 * Read the dataset of the highlighted <li>. When a `link` is present the
	 * value is rewritten to anchor markup.
	 * NOTE(review): the generated `<a>` tag is never closed here — verify the
	 * consumer (insertItem / blot rendering) tolerates or completes it.
	 */
	private getItemData(): DOMStringMap {
		const node = this.mentionList.childNodes[this.itemIndex] as HTMLElement;
		const { link } = node.dataset;
		const itemTarget = node.dataset.target;
		if (link !== undefined) {
			node.dataset.value = `<a href="${link}" target=${itemTarget || this.options.linkTarget}>${node.dataset.value}`;
		}
		return node.dataset;
	}

	/** Mouse moved inside the container: re-enable hover highlighting. */
	onContainerMouseMove() {
		this.suspendMouseEnter = false;
	}

	/** Insert the highlighted item (through the onSelect hook) and close. */
	selectItem() {
		const data = this.getItemData();
		this.options.onSelect(data, (asyncData) => {
			this.insertItem(asyncData);
		});
		this.hideMentionList();
	}

	/**
	 * Insert the selected completion text into the editor at `cursorPos`,
	 * then move the caret to the end of the inserted text.
	 */
	insertItem(data: DOMStringMap) {
		const render = data;
		if (render === null) {
			return ;
		}
		if (!this.options.showDenotationChar) {
			render.denotationChar = '';
		}
		if (this.cursorPos === undefined) {
			throw new Error(`Invalid this.cursorPos`);
		}
		if (!render.value) {
			throw new Error(`Didn't receive value from server.`);
		}

		// NOTE(review): insertText is called as (index, text, format, value);
		// 'bold' is the format and Quill.sources.USER lands in the value
		// slot (truthy, so bold is applied) rather than the source slot —
		// confirm against the Quill API whether this is intentional.
		this.quill.insertText(this.cursorPos, render.value, 'bold', Quill.sources.USER);
		this.quill.setSelection(this.cursorPos + render.value.length, 0);
		this.setCursorPos();
		this.hideMentionList();
	}

	/** Hover: highlight the hovered item (unless keyboard nav suspended it). */
	onItemMouseEnter(e: MouseEvent) {
		if (this.suspendMouseEnter) {
			return ;
		}
		const index = Number(
			(e.target as HTMLLIElement).dataset.index
		);
		if (! Mention.numberIsNaN(index) && index !== this.itemIndex) {
			this.itemIndex = index;
			this.highlightItem(false);
		}
	}

	/** Click: highlight then select the clicked item. */
	onItemClick(e: MouseEvent) {
		e.stopImmediatePropagation();
		e.preventDefault();
		this.itemIndex = Number(
			(e.currentTarget as HTMLElement).dataset.index
		);
		this.highlightItem();
		this.selectItem();
	}

	/**
	 * Copy whitelisted Datum fields onto the <li>'s dataset
	 * (per options.dataAttributes); anything else is removed.
	 */
	private attachDataValues(element: HTMLLIElement, data: Datum): HTMLLIElement {
		for (const [key, value] of Object.entries(data)) {
			if (this.options.dataAttributes.includes(key)) {
				element.dataset[key] = value;
			} else {
				delete element.dataset[key];
			}
		}
		return element;
	}

	/**
	 * (Re)build the dropdown's <li> items from `data` and show the list;
	 * an empty `data` hides it. Newlines in values are displayed as "↵".
	 */
	renderList(mentionChar: string, data: Datum[], searchTerm: string = "") {
		if (data.length > 0) {
			this.values = data;
			this.mentionList.innerHTML = '';

			for (const [i, datum] of data.entries()) {
				const li = document.createElement('li');
				li.className = this.options.listItemClass;
				li.dataset.index = `${i}`;
				// li.innerHTML = this.options.renderItem(datum, searchTerm);
				li.innerText = datum.value.replace(/\n/g, "↵");
				/// ^^ innerText (not the renderItem hook) so server text is
				/// never interpreted as HTML.
				li.onmouseenter = this.onItemMouseEnter.bind(this);
				li.dataset.denotationChar = mentionChar;
				li.onclick = this.onItemClick.bind(this);
				this.mentionList.appendChild(
					this.attachDataValues(li, datum)
				);
			}
			this.itemIndex = 0;
			this.highlightItem();
			this.showMentionList();
		} else {
			this.hideMentionList();
		}
	}

	/** Move highlight to the next item, wrapping at the end of the list. */
	nextItem() {
		this.itemIndex = (this.itemIndex + 1) % this.values.length;
		this.suspendMouseEnter = true;
		this.highlightItem();
	}

	/** Move highlight to the previous item, wrapping at the start. */
	prevItem() {
		this.itemIndex = ((this.itemIndex + this.values.length) - 1) % this.values.length;
		this.suspendMouseEnter = true;
		this.highlightItem();
	}

	/** Whether `s` matches the allowedChars pattern (upstream plugin logic). */
	private hasValidChars(s: string) {
		return this.options.allowedChars.test(s);
	}

	/** True when a list placed at `topPos` would overflow the viewport bottom. */
	private containerBottomIsNotVisible(topPos: number, containerPos: ClientRect | DOMRect) {
		const mentionContainerBottom = topPos + this.mentionContainer.offsetHeight + containerPos.top;
		return mentionContainerBottom > window.pageYOffset + window.innerHeight;
	}

	/** True when a list placed at `leftPos` would overflow the viewport right. */
	private containerRightIsNotVisible(leftPos: number, containerPos: ClientRect | DOMRect) {
		if (this.options.fixMentionsToQuill) {
			return false;
		}
		const rightPos = leftPos + this.mentionContainer.offsetWidth + containerPos.left;
		const browserWidth = window.pageXOffset + document.documentElement.clientWidth;
		return rightPos > browserWidth;
	}

	/** Track open state and fire onOpen/onClose only on actual transitions. */
	private setIsOpen(isOpen: boolean) {
		if (this.isOpen !== isOpen) {
			if (isOpen) {
				this.options.onOpen();
			} else {
				this.options.onClose();
			}
			this.isOpen = isOpen;
		}
	}

	/**
	 * Compute and apply the dropdown's absolute position relative to the
	 * editor, flipping above/below (and clamping horizontally) so the list
	 * stays inside the viewport.
	 */
	private setMentionContainerPosition() {
		const containerPos = this.quill.container.getBoundingClientRect();
		/// vv Here we always trigger from the cursor.
		if (this.cursorPos === undefined) {
			throw new Error(`Invalid this.cursorPos`);
		}
		const mentionCharPos = this.quill.getBounds(this.cursorPos);
		const containerHeight = this.mentionContainer.offsetHeight;

		let topPos = this.options.offsetTop;
		let leftPos = this.options.offsetLeft;

		// handle horizontal positioning
		if (this.options.fixMentionsToQuill) {
			const rightPos = 0;
			this.mentionContainer.style.right = `${rightPos}px`;
		} else {
			leftPos += mentionCharPos.left;
		}

		if (this.containerRightIsNotVisible(leftPos, containerPos)) {
			const containerWidth = this.mentionContainer.offsetWidth + this.options.offsetLeft;
			const quillWidth = containerPos.width;
			leftPos = quillWidth - containerWidth;
		}

		// handle vertical positioning
		if (this.options.defaultMenuOrientation === 'top') {
			// Attempt to align the mention container with the top of the quill editor
			if (this.options.fixMentionsToQuill) {
				topPos = -1 * (containerHeight + this.options.offsetTop);
			} else {
				topPos = mentionCharPos.top - (containerHeight + this.options.offsetTop);
			}

			// default to bottom if the top is not visible
			if (topPos + containerPos.top <= 0) {
				let overMentionCharPos = this.options.offsetTop;

				if (this.options.fixMentionsToQuill) {
					overMentionCharPos += containerPos.height;
				} else {
					overMentionCharPos += mentionCharPos.bottom;
				}

				topPos = overMentionCharPos;
			}
		} else {
			// Attempt to align the mention container with the bottom of the quill editor
			if (this.options.fixMentionsToQuill) {
				topPos += containerPos.height;
			} else {
				topPos += mentionCharPos.bottom;
			}

			// default to the top if the bottom is not visible
			if (this.containerBottomIsNotVisible(topPos, containerPos)) {
				let overMentionCharPos = this.options.offsetTop * -1;

				if (!this.options.fixMentionsToQuill) {
					overMentionCharPos += mentionCharPos.top;
				}

				topPos = overMentionCharPos - containerHeight;
			}
		}

		this.mentionContainer.style.top = `${topPos}px`;
		this.mentionContainer.style.left = `${leftPos}px`;
		this.mentionContainer.style.visibility = 'visible';
	}


	/**
	 * HF Helpers for manual trigger
	 */
	/** Record the current selection index as the insertion point. */
	setCursorPos() {
		const range = this.quill.getSelection();
		if (range) {
			this.cursorPos = range.index;
		} else {
			this.quill.setSelection(this.quill.getLength(), 0);
			/// ^^ place cursor at the end of input by default.
			this.cursorPos = this.quill.getLength();
		}
	}
	/** Last recorded insertion point (call setCursorPos first). */
	getCursorPos(): number {
		return this.cursorPos!;
	}
	/** Open the dropdown with the given completion strings (id == value). */
	trigger(values: string[]) {
		this.renderList("", values.map(x => {
			return { id: x, value: x };
		}), "");
	}

	/** Any editor change invalidates the open list. */
	onSomethingChange() {
		/// We trigger manually so here we can _probably_ just always close.
		this.hideMentionList();
	}

	/** Quill 'text-change' handler: close the list on user edits. */
	onTextChange(delta: Delta, oldDelta: Delta, source: Sources) {
		if (source === 'user') {
			this.onSomethingChange();
		}
	}

	/** Quill 'selection-change' handler: close the list on any move/selection. */
	onSelectionChange(range: RangeStatic) {
		if (range && range.length === 0) {
			this.onSomethingChange();
		} else {
			this.hideMentionList();
		}
	}
}


Quill.register('modules/mention', Mention);
front/js-src/controller.ts ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Api } from './Api';
2
+ import { Mention } from './Mention';
3
+ import { c } from './lib/Log';
4
+ import { Utils } from './lib/Utils';
5
+ import { VanillaTilt } from './vanilla-tilt';
6
+ import { ShareScreenshotModal, SavePublishModal } from './modals';
7
+
8
+ /// We experimented with a couple of different build systems
9
+ /// to integrate Quill (for instance module-then-postprocessing
10
+ /// like in `web3d`) but none worked really well so we just
11
+ /// hotlink the js and basically copy/paste the @types/quill
12
+ /// declaration here.
13
+ /// Update: we now use rollup (for html2canvas), but quill is
14
+ /// still a pain so it's still not in the same bundle.
15
+
16
const DEBUG = false;
/// ^^ when debugging the quill integration, add the quill.snow.css to layout.hbs
/// <link href="/front/node_modules/quill/dist/quill.snow.css" rel="stylesheet">
/// <link href="/front/node_modules/quill/dist/quill.core.css" rel="stylesheet">
/// We tried doing it programmatically here but it's a bit slow.
if (DEBUG) {
    document.head.insertAdjacentHTML(
        'beforeend',
        `<link href="/front/node_modules/quill/dist/quill.snow.css" rel="stylesheet">`
    );
    /// ^^ add css to debug. Do it as early as possible.
}

/// The three page types served by the backend; detected from body classes below.
enum Page {
    app, landing, model
}
/// App-wide singleton: page type, DOM handles, and the DOMContentLoaded router.
/// NOTE: the `document.querySelector` calls run at module-evaluation time, so
/// this script must be loaded after the corresponding markup.
const App = {
    page:
        (document.body.classList.contains('app')) ? Page.app
        : (document.body.classList.contains('landing')) ? Page.landing
        : Page.model
    ,
    // Whether the current document can be edited (server sets data-editable).
    editable: document.body.dataset.editable === 'true',
    header: {
        shuffleBtn: document.querySelector('header .js-shuffle') as HTMLAnchorElement,
        triggerBtn: document.querySelector('header .js-trigger') as HTMLAnchorElement,
        mainInfoBtn: document.querySelector('header .title .info') as HTMLImageElement,
        // vv These three are nullable: not present on every page variant.
        shareBtn: document.querySelector<HTMLAnchorElement>('header .js-share'),
        saveBtn: document.querySelector<HTMLAnchorElement>('header .js-save'),
        duplicateBtn: document.querySelector<HTMLAnchorElement>('header .js-duplicate'),
    },
    shareScreenBtn: document.querySelector('.page-container .js-share') as HTMLAnchorElement,
    loaderEditor: document.querySelector('.page-container .js-loader') as HTMLImageElement,
    sliders: Array.from(
        document.querySelectorAll('.decoder-settings input.slider')
    ) as HTMLInputElement[],
    INITIAL_CONTENT: {} as Delta,
    /**
     * Helper function to more cleanly route different page types.
     * Runs `callback` on DOMContentLoaded, but only when `p` matches the
     * detected page type.
     */
    onLoad: (p: Page, callback: () => void) => {
        if (p === App.page) {
            document.addEventListener('DOMContentLoaded', () => {
                callback();
            });
        }
    },
};
64
+
65
+ const PROMPTS = [
66
+ `Before boarding your rocket to Mars, remember to pack these items`,
67
+ `In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.`,
68
+ `Legolas and Gimli advanced on the orcs, raising their weapons with a harrowing war cry.`,
69
+ `Today, scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth`,
70
+ `
71
+ Thor: The Tesseract belongs on Asgard, no human is a match for it.
72
+ Tony turns to leave, but Steve stops him.
73
+ Steve: You're not going alone!
74
+ Tony: You gonna stop me?
75
+ `.replace(/\t/g, "").trim().concat("\n"),
76
+ ];
77
+
78
+
79
+
80
+
81
+ App.onLoad(Page.app, () => {
82
+ const modalScreenshot = new ShareScreenshotModal;
83
+
84
+ const opts: QuillOptionsStatic = DEBUG
85
+ ? {
86
+ theme: 'snow',
87
+ modules: {
88
+ mention: {},
89
+ },
90
+ }
91
+ : {
92
+ theme: undefined,
93
+ // formats: [],
94
+ modules: {
95
+ toolbar: [],
96
+ mention: {},
97
+ },
98
+ }
99
+ ;
100
+ if (! App.editable) {
101
+ opts.readOnly = true;
102
+ }
103
+ const quill = new Quill('div.editor', opts);
104
+ const mention = quill.getModule('mention') as Mention;
105
+ (<any>window).quill = quill;
106
+ const QUILL_C = (<any>window).QUILL_C;
107
+ if (QUILL_C) {
108
+ quill.setContents(QUILL_C);
109
+ }
110
+
111
+
112
+
113
+ quill.container.appendChild(App.loaderEditor);
114
+ quill.container.appendChild(App.shareScreenBtn);
115
+
116
+ //
117
+ // div.editor .ql-container <-- quill.container
118
+ // +--------------------------------+
119
+ // | div.ql-editor contenteditable | <-- quill.root
120
+ // | +----------------------------+ |
121
+ // | | | |
122
+ // | | | |
123
+ // | +----------------------------+ |
124
+ // +--------------------------------+
125
+ //
126
+
127
+ quill.keyboard.addBinding({ key: Mention.Keys.TAB }, () => {
128
+ triggerAutocomplete();
129
+ });
130
+ quill.keyboard.bindings[Mention.Keys.TAB].unshift(
131
+ quill.keyboard.bindings[Mention.Keys.TAB].pop()
132
+ );
133
+ /// ^^ important.
134
+ /// ^^ place it at beginning of bindings.
135
+
136
+
137
    /**
     * Run one autocomplete round:
     *  1. show the inline loader next to the caret,
     *  2. POST the text before the caret to the backend,
     *  3. feed the returned sentences to the mention module as suggestions.
     */
    const triggerAutocomplete = async () => {
        /// vv position loader
        mention.setCursorPos();
        const cursorBbox = quill.getBounds(mention.getCursorPos());
        App.loaderEditor.style.top = `${cursorBbox.top - 4}px`;
        App.loaderEditor.style.left = `${cursorBbox.left + 4}px`;
        App.loaderEditor.classList.remove('hide');

        /// vv Launch api request with everything before the caret as context.
        const text = quill.getText(0, mention.getCursorPos());
        // ^^ That is so much simpler that what we used to do
        // when we were embbedding objects like in `quill-mention`.
        c.debug(
            `%c[About to launch autocomplete for]`,
            `color: green;`,
            text,
        );
        const o = await Api.shared.postWithSettings({ context: text });
        App.loaderEditor.classList.add('hide');

        /// vv Trigger mention module with the returned completions.
        for (const x of o.sentences) {
            c.log(x.value);
        }
        mention.trigger(
            o.sentences.map(x => x.value)
        );
    };
165
+
166
+
167
+ App.header.duplicateBtn?.addEventListener('click', async (e) => {
168
+ e.preventDefault();
169
+ const url = await Api.shared.postDuplicate();
170
+ window.location.href = url;
171
+ });
172
+
173
+
174
+ if (! App.editable) {
175
+ return ;
176
+ }
177
+ /**
178
+ * vv Below is only in editable mode.
179
+ */
180
+
181
+ const modalSave = new SavePublishModal(quill);
182
+
183
+ App.header.shuffleBtn.addEventListener('click', (e) => {
184
+ e.preventDefault();
185
+ quill.setText(
186
+ Utils.randomItem(PROMPTS)
187
+ );
188
+ quill.setSelection(quill.getLength(), 0);
189
+ /// ^^ github.com/quilljs/quill/issues/2635
190
+ triggerAutocomplete();
191
+ });
192
+ App.header.triggerBtn.addEventListener('click', (e) => {
193
+ e.preventDefault();
194
+ triggerAutocomplete();
195
+ });
196
+ App.header.shareBtn?.addEventListener('click', async (e) => {
197
+ e.preventDefault();
198
+ const text = `Write With Transformer via @huggingface`;
199
+ window.open(`https://twitter.com/share?url=${ encodeURIComponent(window.location.href) }&text=${ encodeURIComponent(text) }`);
200
+ });
201
+ App.header.saveBtn?.addEventListener('click', (e) => {
202
+ e.preventDefault();
203
+ mention.hideMentionList();
204
+ modalSave.show();
205
+ });
206
+
207
+ App.shareScreenBtn.addEventListener('click', async (e) => {
208
+ e.preventDefault();
209
+ mention.hideMentionList();
210
+ modalScreenshot.show();
211
+ });
212
+ quill.on('text-change', () => {
213
+ App.shareScreenBtn.classList.remove('hide'); /// <- we use a fadeout effect.
214
+ const hasTextFromAI = quill.getContents()
215
+ .ops
216
+ .some(op => op.attributes && op.attributes.bold === true)
217
+ ;
218
+ App.shareScreenBtn.classList.toggle('fadeout', ! hasTextFromAI);
219
+ });
220
+ document.addEventListener('click', (e) => {
221
+ /// Handle clicks on links inside the editor.
222
+ if (! (
223
+ e.target instanceof HTMLAnchorElement
224
+ && e.target.closest('div.ql-editor') !== null
225
+ )) {
226
+ return ;
227
+ }
228
+ /// Ok, let's do this.
229
+ e.preventDefault();
230
+ e.stopPropagation();
231
+ const href = e.target.getAttribute('href'); /// <- caution, get the original string.
232
+ c.debug(`[click]`, href);
233
+ if (href === '#js-shuffle') {
234
+ App.header.shuffleBtn.click();
235
+ } else {
236
+ window.open(e.target.href);
237
+ }
238
+ });
239
+ document.addEventListener("scroll", e => {
240
+ const trigger = document.getElementsByClassName("js-trigger")[0] as HTMLAnchorElement;
241
+ if (scrollY > 100) {
242
+ trigger.style.position = "fixed";
243
+ trigger.style.top = "10px";
244
+ trigger.style.border = "1px solid blue";
245
+ trigger.style.backgroundColor = "white";
246
+ trigger.style.borderRadius = "100px";
247
+ trigger.style.padding = "5px";
248
+ trigger.style.zIndex = "1";
249
+ trigger.style.left = "50%";
250
+ trigger.style.transform = "translateX(-50%)";
251
+ } else {
252
+ trigger.style.position = "relative";
253
+ trigger.style.top = "auto";
254
+ trigger.style.border = "none";
255
+ trigger.style.backgroundColor = "white";
256
+ trigger.style.borderRadius = "0";
257
+ trigger.style.padding = "0";
258
+ trigger.style.zIndex = "1";
259
+ trigger.style.left = "auto"
260
+ }
261
+ });
262
+
263
+ /**
264
+ * Settings
265
+ */
266
+ const handleSliderChange = (slider: HTMLInputElement) => {
267
+ const div = slider.parentNode as HTMLDivElement;
268
+ const spanVal = div.querySelector('.js-val') as HTMLSpanElement;
269
+ const value = Number.isInteger(slider.valueAsNumber)
270
+ ? slider.valueAsNumber
271
+ : Number(slider.valueAsNumber.toFixed(2))
272
+ ;
273
+ const valueKey = `value-${value}`;
274
+ if (slider.dataset[valueKey]) {
275
+ spanVal.innerText = slider.dataset[valueKey]!;
276
+ } else {
277
+ spanVal.innerText = value.toString();
278
+ }
279
+ const min = Number(slider.getAttribute('min'));
280
+ const max = Number(slider.getAttribute('max'));
281
+ if (value < min + (max - min) / 3) {
282
+ spanVal.className = "js-val green";
283
+ } else if (value < min + 2 * (max - min) / 3) {
284
+ spanVal.className = "js-val orange";
285
+ } else {
286
+ spanVal.className = "js-val red";
287
+ }
288
+ const isInverted = slider.classList.contains('js-inverted');
289
+ if (isInverted) {
290
+ if (spanVal.classList.contains('green')) {
291
+ spanVal.classList.remove('green');
292
+ spanVal.classList.add('red');
293
+ } else if (spanVal.classList.contains('red')) {
294
+ spanVal.classList.remove('red');
295
+ spanVal.classList.add('green');
296
+ }
297
+ }
298
+ };
299
+ for (const slider of App.sliders) {
300
+ handleSliderChange(slider);
301
+ slider.addEventListener('input', () => {
302
+ handleSliderChange(slider);
303
+ });
304
+ }
305
+ });
306
+
307
+
308
+
309
App.onLoad(Page.landing, () => {
    /**
     * VanillaTilt
     */
    /// Landing page only: attach the 3D tilt + glare hover effect to every
    /// element marked with a `data-tilt` attribute.
    VanillaTilt.init(document.querySelectorAll("[data-tilt]"), {
        glare: true,
        scale: 1.06,
        'max-glare': 0.3,
        speed: 400,
    });
});
front/js-src/lib/Log.ts ADDED
@@ -0,0 +1 @@
 
 
1
/// Terse alias for `console`, used as the logger throughout the front-end.
export const c = console;
front/js-src/lib/Utils.ts ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ export class Utils {
3
+ private static escapeMap = {
4
+ /// From underscore.js
5
+ '&': '&amp;',
6
+ '<': '&lt;',
7
+ '>': '&gt;',
8
+ '"': '&quot;',
9
+ "'": '&#x27;',
10
+ '`': '&#x60;'
11
+ };
12
+
13
+ /**
14
+ * Escape a message's content for insertion into html.
15
+ */
16
+ static escape(s: string): string {
17
+ let x = s;
18
+ for (const [k, v] of Object.entries(this.escapeMap)) {
19
+ x = x.replace(new RegExp(k, 'g'), v);
20
+ }
21
+ return x.replace(/\n/g, '<br>');
22
+ }
23
+
24
+ /**
25
+ * Opposite of escape.
26
+ */
27
+ static unescape(s: string): string {
28
+ let x = s.replace(/<br>/g, '\n');
29
+ for (const [k, v] of Object.entries(this.escapeMap)) {
30
+ x = x.replace(new RegExp(v, 'g'), k);
31
+ }
32
+ return x;
33
+ }
34
+
35
+ /**
36
+ * "Real" modulo (always >= 0), not remainder.
37
+ */
38
+ static mod(a: number, n: number): number {
39
+ return ((a % n) + n) % n;
40
+ }
41
+
42
+ /**
43
+ * Noop object with arbitrary number of nested attributes that are also noop.
44
+ */
45
+ static deepNoop() {
46
+ const noop = new Proxy(() => {}, {
47
+ get: () => noop
48
+ });
49
+ return noop;
50
+ }
51
+
52
+ /**
53
+ * Capitalize
54
+ */
55
+ static capitalize(s: string): string {
56
+ return s.charAt(0).toUpperCase() + s.slice(1);
57
+ }
58
+
59
+ /**
60
+ * Returns a promise that will resolve after the specified time
61
+ * @param ms Number of ms to wait
62
+ */
63
+ static delay(ms: number): Promise<void> {
64
+ return new Promise((resolve, reject) => {
65
+ setTimeout(() => resolve(), ms);
66
+ });
67
+ }
68
+
69
+ /**
70
+ * Random element from array
71
+ */
72
+ static randomItem<T>(arr: T[]): T {
73
+ return arr[Math.floor(Math.random()*arr.length)];
74
+ }
75
+ }
76
+
front/js-src/modals.ts ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Utils } from './lib/Utils';
2
+ import html2canvas from 'html2canvas';
3
+ import { c } from './lib/Log';
4
+ import { Api } from './Api';
5
+
6
/**
 * Base class for the app's modal dialogs.
 * Owns the overlay element, the optional close button, the spinner, and the
 * fade-in/out choreography; subclasses implement the three `perform*` hooks.
 */
abstract class Modal {
    protected div: HTMLDivElement;                // overlay root: `div.modal.<className>`
    protected doneBtn: HTMLAnchorElement | null;  // optional `.js-close` button
    protected loader: HTMLImageElement;           // spinner shown while `performShow` runs
    constructor(className: string) {
        this.div = document.querySelector(`div.modal.${className}`) as HTMLDivElement;
        this.doneBtn = this.div.querySelector<HTMLAnchorElement>('.js-close');
        this.loader = this.div.querySelector('.js-loader') as HTMLImageElement;

        this.doneBtn?.addEventListener('click', (e) => {
            e.preventDefault();
            this.hide();
        });
        /// Clicking the backdrop itself (not the dialog content) dismisses the modal.
        this.div.addEventListener('click', (e) => {
            if (e.target === this.div) {
                c.debug(`modal:background.click`);
                this.hide();
            }
        });
    }
    /**
     * Hooks: Implement those to perform the actual work done on show and hide.
     */
    abstract performBeforeShow(): Promise<void>;
    abstract performShow(): Promise<void>;
    abstract performHide(): Promise<void>;
    /// Fade in (~100ms), run `performShow`, then hide the spinner.
    async show() {
        await this.performBeforeShow();
        this.div.classList.add('fadeout');
        this.div.classList.remove('hide');
        await Utils.delay(100);
        this.div.classList.remove('fadeout');
        await this.performShow();
        this.loader.classList.add('hide');
    }
    /// Fade out (~200ms) before fully hiding, then run `performHide`.
    async hide() {
        this.div.classList.add('fadeout');
        await Utils.delay(200);
        this.div.classList.add('hide');
        this.div.classList.remove('fadeout');
        await this.performHide();
    }
}
49
+
50
/**
 * Modal that renders a shareable screenshot of the editor page via
 * html2canvas and displays it as a data-URL image.
 */
export class ShareScreenshotModal extends Modal {
    private imResult = this.div.querySelector('.js-result') as HTMLImageElement;

    constructor() {
        super(`share-screenshot`);
    }
    async performBeforeShow() {
        this.loader.classList.remove('hide');
    }
    async performShow() {
        await Utils.delay(800); /// <- for good ux
        const el = document.querySelector('div.page-inner') as HTMLDivElement;
        const canvas = await html2canvas(el, {
            logging: false, /// <- inoperant in our version of html2canvas.
            /// Style the cloned DOM (not the live page) before rasterizing.
            onclone: (doc) => {
                const clonedEl = doc.querySelector('div.page-inner') as HTMLDivElement;
                clonedEl.classList.add('html2canvas');
                const watermark = doc.querySelector('div.watermark') as HTMLDivElement;
                watermark.style.visibility = `visible`;
            }
        });
        this.imResult.src = canvas.toDataURL();
    }
    /// Drop the (potentially large) data-URL when the modal closes.
    async performHide() {
        this.imResult.src = "";
    }
}
77
+
78
/**
 * Modal that saves/publishes the current document:
 * collects a title, POSTs the Quill delta to the backend, and redirects
 * to the document's edit url on success.
 */
export class SavePublishModal extends Modal {
    private saveBtn = this.div.querySelector('.js-save') as HTMLAnchorElement;
    private form = this.div.querySelector('form') as HTMLFormElement;
    constructor(
        private quill: Quill
    ) {
        super(`save-publish`);

        /// vv Url fields auto-select.
        const urlInputs = Array.from(
            this.div.querySelectorAll('.doc-url')
        ) as HTMLInputElement[];
        for (const x of urlInputs) {
            x.addEventListener('focus', () => {
                x.select();
            });
        }

        this.saveBtn.addEventListener('click', (e) => {
            e.preventDefault();
            if (! this.form.reportValidity()) {
                /// Form is invalid.
                return ;
            }
            this.save();
        });
        /// Pressing Enter inside the form routes through the same save path.
        this.form.addEventListener('submit', (e) => {
            e.preventDefault();
            this.saveBtn.click();
        });
    }
    async performBeforeShow() {}
    async performShow() {}
    async performHide() {}
    /// POST { title, contents } to the backend; redirect on success.
    async save() {
        this.loader.classList.remove('hide');

        const inputTitle = this.div.querySelector('.doc-title') as HTMLInputElement;
        const title = inputTitle.value;
        const contents = this.quill.getContents();
        c.log(JSON.stringify({ title, contents }));

        const success = await Api.shared.postEdit({ title, contents });
        await Utils.delay(800); /// <- for good ux

        if (success) {
            this.loader.classList.add('hide');
            this.hide();
            /// For now we always redirect to the edit url here:
            /// vv
            const inputEditUrl = this.div.querySelector('.doc-edit-url') as HTMLInputElement;
            window.location.href = inputEditUrl.value;
        } else {
            window.alert(`did not manage to save`);
        }
    }
}
front/js-src/quill.d.ts ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ // import { Blot } from "../node_modules/parchment/dist/src/blot/abstract/blot";
3
+ interface Blot {}
4
+ interface Delta {
5
+ ops: DeltaOperation[];
6
+ }
7
+
8
+ /**
9
+ * A stricter type definition would be:
10
+ *
11
+ * type DeltaOperation ({ insert: any } | { delete: number } | { retain: number }) & OptionalAttributes;
12
+ *
13
+ * But this would break a lot of existing code as it would require manual discrimination of the union types.
14
+ */
15
+ type DeltaOperation = { insert?: any, delete?: number, retain?: number } & OptionalAttributes;
16
+ type Sources = "api" | "user" | "silent";
17
+
18
+ interface Key {
19
+ key: string | number;
20
+ shortKey?: boolean;
21
+ }
22
+
23
+ interface StringMap {
24
+ [key: string]: any;
25
+ }
26
+
27
+ interface OptionalAttributes {
28
+ attributes?: StringMap;
29
+ }
30
+
31
+ type TextChangeHandler = (delta: Delta, oldContents: Delta, source: Sources) => any;
32
+ type SelectionChangeHandler = (range: RangeStatic, oldRange: RangeStatic, source: Sources) => any;
33
+ type EditorChangeHandler = ((name: "text-change", delta: Delta, oldContents: Delta, source: Sources) => any)
34
+ | ((name: "selection-change", range: RangeStatic, oldRange: RangeStatic, source: Sources) => any);
35
+
36
+ interface KeyboardStatic {
37
+ addBinding(key: Key, callback: (range: RangeStatic, context: any) => void): void;
38
+ addBinding(key: Key, context: any, callback: (range: RangeStatic, context: any) => void): void;
39
+ bindings: { [index: number]: any[] };
40
+ }
41
+
42
+ interface ClipboardStatic {
43
+ convert(html?: string): Delta;
44
+ addMatcher(selectorOrNodeType: string|number, callback: (node: any, delta: Delta) => Delta): void;
45
+ dangerouslyPasteHTML(html: string, source?: Sources): void;
46
+ dangerouslyPasteHTML(index: number, html: string, source?: Sources): void;
47
+ }
48
+
49
+ interface QuillOptionsStatic {
50
+ debug?: string | boolean;
51
+ modules?: StringMap;
52
+ placeholder?: string;
53
+ readOnly?: boolean;
54
+ theme?: string;
55
+ formats?: string[];
56
+ bounds?: HTMLElement | string;
57
+ scrollingContainer?: HTMLElement | string;
58
+ strict?: boolean;
59
+ }
60
+
61
+ interface BoundsStatic {
62
+ bottom: number;
63
+ left: number;
64
+ right: number;
65
+ top: number;
66
+ height: number;
67
+ width: number;
68
+ }
69
+
70
declare interface RangeStatic {
    index: number;
    length: number;
}

/// NOTE(review): interface + class sharing one name rely on TypeScript
/// declaration merging; the `implements RangeStatic` clause is therefore
/// redundant but harmless. Kept as-is to mirror the upstream @types/quill
/// declarations this file was copied from.
declare class RangeStatic implements RangeStatic {
    constructor();
    index: number;
    length: number;
}
80
+
81
+ interface EventEmitter {
82
+ on(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
83
+ on(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
84
+ on(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
85
+ once(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
86
+ once(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
87
+ once(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
88
+ off(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
89
+ off(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
90
+ off(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
91
+ }
92
+
93
+ declare class Quill {
94
+ /**
95
+ * @private Internal API
96
+ */
97
+ root: HTMLDivElement;
98
+ container: HTMLElement; /// <- used by quill-mention
99
+ clipboard: ClipboardStatic;
100
+ scroll: Blot;
101
+ keyboard: KeyboardStatic;
102
+ constructor(container: string | Element, options?: QuillOptionsStatic);
103
+ deleteText(index: number, length: number, source?: Sources): Delta;
104
+ disable(): void;
105
+ enable(enabled?: boolean): void;
106
+ getContents(index?: number, length?: number): Delta;
107
+ getLength(): number;
108
+ getText(index?: number, length?: number): string;
109
+ insertEmbed(index: number, type: string, value: any, source?: Sources): Delta;
110
+ insertText(index: number, text: string, source?: Sources): Delta;
111
+ insertText(index: number, text: string, format: string, value: any, source?: Sources): Delta;
112
+ insertText(index: number, text: string, formats: StringMap, source?: Sources): Delta;
113
+ /**
114
+ * @deprecated Remove in 2.0. Use clipboard.dangerouslyPasteHTML(index: number, html: string, source: Sources)
115
+ */
116
+ pasteHTML(index: number, html: string, source?: Sources): string;
117
+ /**
118
+ * @deprecated Remove in 2.0. Use clipboard.dangerouslyPasteHTML(html: string, source: Sources): void;
119
+ */
120
+ pasteHTML(html: string, source?: Sources): string;
121
+ setContents(delta: Delta, source?: Sources): Delta;
122
+ setText(text: string, source?: Sources): Delta;
123
+ update(source?: Sources): void;
124
+ updateContents(delta: Delta, source?: Sources): Delta;
125
+
126
+ format(name: string, value: any, source?: Sources): Delta;
127
+ formatLine(index: number, length: number, source?: Sources): Delta;
128
+ formatLine(index: number, length: number, format: string, value: any, source?: Sources): Delta;
129
+ formatLine(index: number, length: number, formats: StringMap, source?: Sources): Delta;
130
+ formatText(index: number, length: number, source?: Sources): Delta;
131
+ formatText(index: number, length: number, format: string, value: any, source?: Sources): Delta;
132
+ formatText(index: number, length: number, formats: StringMap, source?: Sources): Delta;
133
+ formatText(range: RangeStatic, format: string, value: any, source?: Sources): Delta;
134
+ formatText(range: RangeStatic, formats: StringMap, source?: Sources): Delta;
135
+ getFormat(range?: RangeStatic): StringMap;
136
+ getFormat(index: number, length?: number): StringMap;
137
+ removeFormat(index: number, length: number, source?: Sources): Delta;
138
+
139
+ blur(): void;
140
+ focus(): void;
141
+ getBounds(index: number, length?: number): BoundsStatic;
142
+ getSelection(focus: true): RangeStatic;
143
+ getSelection(focus?: false): RangeStatic | null;
144
+ hasFocus(): boolean;
145
+ setSelection(index: number, length: number, source?: Sources): void;
146
+ setSelection(range: RangeStatic, source?: Sources): void;
147
+
148
+ // static methods: debug, import, register, find
149
+ static debug(level: string|boolean): void;
150
+ static import(path: string): any;
151
+ static register(path: string, def: any, suppressWarning?: boolean): void;
152
+ static register(defs: StringMap, suppressWarning?: boolean): void;
153
+ static find(domNode: Node, bubble?: boolean): Quill | any;
154
+
155
+ addContainer(classNameOrDomNode: string|Node, refNode?: Node): any;
156
+ getModule(name: string): any;
157
+
158
+ // Blot interface is not exported on Parchment
159
+ getIndex(blot: any): number;
160
+ getLeaf(index: number): any;
161
+ getLine(index: number): [any, number];
162
+ getLines(index?: number, length?: number): any[];
163
+ getLines(range: RangeStatic): any[];
164
+
165
+ // EventEmitter methods
166
+ on(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
167
+ on(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
168
+ on(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
169
+ once(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
170
+ once(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
171
+ once(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
172
+ off(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
173
+ off(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
174
+ off(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
175
+
176
+ static sources: {
177
+ API: 'api',
178
+ SILENT: 'silent',
179
+ USER: 'user',
180
+ };
181
+ }
front/js-src/vanilla-tilt.ts ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export namespace VanillaTilt {
2
+ /**
3
+ * Options which configures the tilting
4
+ */
5
+ export interface TiltOptions {
6
+ /**
7
+ * Reverse the tilt direction
8
+ */
9
+ reverse?: boolean;
10
+ /**
11
+ * Max tilt rotation (degrees)
12
+ */
13
+ max?: number;
14
+ /**
15
+ * Transform perspective, the lower the more extreme the tilt gets.
16
+ */
17
+ perspective?: number;
18
+ /**
19
+ * 2 = 200%, 1.5 = 150%, etc..
20
+ */
21
+ scale?: number;
22
+ /**
23
+ * Speed of the enter/exit transition
24
+ */
25
+ speed?: number;
26
+ /**
27
+ * Set a transition on enter/exit.
28
+ */
29
+ transition?: boolean;
30
+ /**
31
+ * What axis should be disabled. Can be X or Y.
32
+ */
33
+ axis?: null | "x" | "y";
34
+ /**
35
+ * If the tilt effect has to be reset on exit.
36
+ */
37
+ reset?: boolean;
38
+ /**
39
+ * Easing on enter/exit.
40
+ */
41
+ easing?: string;
42
+ /**
43
+ * Added (@julien-c)
44
+ */
45
+ glare?: boolean;
46
+ 'max-glare'?: number;
47
+ }
48
+
49
+ export interface TiltValues {
50
+ /**
51
+ * The current tilt on the X axis
52
+ */
53
+ tiltX: number;
54
+ /**
55
+ * The current tilt on the Y axis
56
+ */
57
+ tiltY: number;
58
+ /**
59
+ * The current percentage on the X axis
60
+ */
61
+ percentageX: number;
62
+ /**
63
+ * The current percentage on the Y axis
64
+ */
65
+ percentageY: number;
66
+ }
67
+
68
+ export interface HTMLVanillaTiltElement extends HTMLElement {
69
+ vanillaTilt: VanillaTilt
70
+ }
71
+ }
72
+
73
+
74
+ export class VanillaTilt {
75
+ width: number | null;
76
+ height: number | null;
77
+ left: number | null;
78
+ top: number | null;
79
+ element: VanillaTilt.HTMLVanillaTiltElement;
80
+ settings: VanillaTilt.TiltOptions;
81
+ reverse : -1 | 1;
82
+ glare: boolean;
83
+ glarePrerender: boolean;
84
+ transitionTimeout: number | null;
85
+ updateCall: number | null;
86
+ glareElementWrapper: HTMLElement;
87
+ glareElement: HTMLElement;
88
+ updateBind: () => void;
89
+ resetBind: () => void;
90
+ onMouseEnterBind: (e: Event) => void;
91
+ onMouseMoveBind: (e: Event) => void;
92
+ onMouseLeaveBind: (e: Event) => void;
93
+ event: MouseEvent;
94
+
95
    /**
     * @param element  DOM node to tilt (throws if not a Node)
     * @param settings user options, merged over the defaults by `extendSettings`
     */
    constructor(element, settings: VanillaTilt.TiltOptions = {}) {
        if (!(element instanceof Node)) {
            throw ("Can't initialize VanillaTilt because " + element + " is not a Node.");
        }

        /// Geometry cache, filled lazily on mouseenter by `updateElementPosition`.
        this.width = null;
        this.height = null;
        this.left = null;
        this.top = null;
        this.transitionTimeout = null;
        this.updateCall = null;

        /// Pre-bound callbacks so listeners can be removed later.
        this.updateBind = this.update.bind(this);
        this.resetBind = this.reset.bind(this);

        this.element = element as VanillaTilt.HTMLVanillaTiltElement;
        this.settings = this.extendSettings(settings);

        /// -1 flips the sign of the computed tilt angles.
        this.reverse = this.settings.reverse ? -1 : 1;

        this.glare = this.isSettingTrue(this.settings.glare);
        this.glarePrerender = this.isSettingTrue(this.settings["glare-prerender"]);

        if (this.glare) {
            this.prepareGlare();
        }

        this.addEventListeners();
    }
124
+
125
+ isSettingTrue(setting) {
126
+ return setting === "" || setting === true || setting === 1;
127
+ }
128
+
129
    /// Wire up the hover listeners; the resize listener is only attached when
    /// the glare layer needs to track element size.
    addEventListeners() {
        this.onMouseEnterBind = this.onMouseEnter.bind(this);
        this.onMouseMoveBind = this.onMouseMove.bind(this);
        this.onMouseLeaveBind = this.onMouseLeave.bind(this);
        this.onWindowResizeBind = this.onWindowResizeBind.bind(this);
        // NOTE(review): the line above binds `onWindowResizeBind` to ITSELF —
        // upstream vanilla-tilt binds `this.onWindowResize` here. None of the
        // class fields visible in this file declare `onWindowResizeBind` or
        // `onWindowResize`, so if they are not defined later in the file this
        // throws at construction time — confirm against the full source.

        this.element.addEventListener("mouseenter", this.onMouseEnterBind);
        this.element.addEventListener("mousemove", this.onMouseMoveBind);
        this.element.addEventListener("mouseleave", this.onMouseLeaveBind);
        if (this.glare) {
            window.addEventListener("resize", this.onWindowResizeBind);
        }
    }
142
+
143
+
144
    /// Hover start: refresh the geometry cache, hint the compositor, and
    /// enable the enter transition.
    onMouseEnter(event) {
        this.updateElementPosition();
        (<any>this.element.style).willChange = "transform";
        this.setTransition();
    }
149
+
150
+ onMouseMove(event) {
151
+ if (this.updateCall !== null) {
152
+ cancelAnimationFrame(this.updateCall);
153
+ }
154
+
155
+ this.event = event;
156
+ this.updateCall = requestAnimationFrame(this.updateBind);
157
+ }
158
+
159
    /// Hover end: re-enable the transition and, when `reset` is configured,
    /// animate back to the neutral (untilted) state on the next frame.
    onMouseLeave(event) {
        this.setTransition();

        if (this.settings.reset) {
            requestAnimationFrame(this.resetBind);
        }
    }
166
+
167
    /// Snap the element (and glare layer) back to the neutral, untilted state.
    reset() {
        /// Synthesize a mouse event at the element's center.
        /// NOTE(review): this sets pageX/pageY while `getValues` reads
        /// clientX/clientY — harmless here because the transform is written
        /// directly below, but confirm nothing routes this event through
        /// `update()`/`getValues()`.
        this.event = {
            pageX: this.left! + this.width! / 2,
            pageY: this.top! + this.height! / 2
        } as MouseEvent;

        this.element.style.transform = "perspective(" + this.settings.perspective + "px) " +
            "rotateX(0deg) " +
            "rotateY(0deg) " +
            "scale3d(1, 1, 1)"
        ;

        if (this.glare) {
            this.glareElement.style.transform = 'rotate(180deg) translate(-50%, -50%)';
            this.glareElement.style.opacity = '0';
        }
    }
184
+
185
+ getValues() {
186
+ let x = (this.event.clientX - this.left!) / this.width!;
187
+ let y = (this.event.clientY - this.top!) / this.height!;
188
+
189
+ x = Math.min(Math.max(x, 0), 1);
190
+ y = Math.min(Math.max(y, 0), 1);
191
+
192
+ let tiltX = (this.reverse * (this.settings.max! / 2 - x * this.settings.max!)).toFixed(2);
193
+ let tiltY = (this.reverse * (y * this.settings.max! - this.settings.max! / 2)).toFixed(2);
194
+ let angle = Math.atan2(this.event.clientX - (this.left! + this.width! / 2), -(this.event.clientY - (this.top! + this.height! / 2))) * (180 / Math.PI);
195
+
196
+ return {
197
+ tiltX: tiltX,
198
+ tiltY: tiltY,
199
+ percentageX: x * 100,
200
+ percentageY: y * 100,
201
+ angle: angle
202
+ };
203
+ }
204
+
205
+ updateElementPosition() {
206
+ let rect = this.element.getBoundingClientRect();
207
+
208
+ this.width = this.element.offsetWidth;
209
+ this.height = this.element.offsetHeight;
210
+ this.left = rect.left;
211
+ this.top = rect.top;
212
+ }
213
+
214
+ update() {
215
+ const values = this.getValues();
216
+
217
+ this.element.style.transform = [
218
+ "perspective(" + this.settings.perspective + "px) ",
219
+ "rotateX(" + (this.settings.axis === "x" ? 0 : values.tiltY) + "deg) ",
220
+ "rotateY(" + (this.settings.axis === "y" ? 0 : values.tiltX) + "deg) ",
221
+ "scale3d(" + this.settings.scale + ", " + this.settings.scale + ", " + this.settings.scale + ")",
222
+ ].join(" ");
223
+
224
+ if (this.glare) {
225
+ this.glareElement.style.transform = `rotate(${values.angle}deg) translate(-50%, -50%)`;
226
+ this.glareElement.style.opacity = `${values.percentageY * this.settings["max-glare"]! / 100}`;
227
+ }
228
+
229
+ this.element.dispatchEvent(new CustomEvent("tiltChange", {
230
+ "detail": values
231
+ }));
232
+
233
+ this.updateCall = null;
234
+ }
235
+
236
+ /**
237
+ * Appends the glare element (if glarePrerender equals false)
238
+ * and sets the default style
239
+ */
240
+ prepareGlare() {
241
+ // If option pre-render is enabled we assume all html/css is present for an optimal glare effect.
242
+ if (!this.glarePrerender) {
243
+ // Create glare element
244
+ const jsTiltGlare = document.createElement("div");
245
+ jsTiltGlare.classList.add("js-tilt-glare");
246
+
247
+ const jsTiltGlareInner = document.createElement("div");
248
+ jsTiltGlareInner.classList.add("js-tilt-glare-inner");
249
+
250
+ jsTiltGlare.appendChild(jsTiltGlareInner);
251
+ this.element.appendChild(jsTiltGlare);
252
+ }
253
+
254
+ this.glareElementWrapper = this.element.querySelector(".js-tilt-glare") as HTMLElement;
255
+ this.glareElement = this.element.querySelector(".js-tilt-glare-inner") as HTMLElement;
256
+
257
+ if (this.glarePrerender) {
258
+ return ;
259
+ }
260
+
261
+ Object.assign(this.glareElementWrapper.style, {
262
+ "position": "absolute",
263
+ "top": "0",
264
+ "left": "0",
265
+ "width": "100%",
266
+ "height": "100%",
267
+ "overflow": "hidden",
268
+ 'pointer-events': 'none',
269
+ });
270
+
271
+ Object.assign(this.glareElement.style, {
272
+ 'position': 'absolute',
273
+ 'top': '50%',
274
+ 'left': '50%',
275
+ 'pointer-events': 'none',
276
+ 'background-image': `linear-gradient(0deg, rgba(255,255,255,0) 0%, rgba(255,255,255,1) 100%)`,
277
+ 'width': `${this.element.offsetWidth * 2}px`,
278
+ 'height': `${this.element.offsetWidth * 2}px`,
279
+ 'transform': 'rotate(180deg) translate(-50%, -50%)',
280
+ 'transform-origin': '0% 0%',
281
+ 'opacity': '0',
282
+ });
283
+ }
284
+
285
+ updateGlareSize() {
286
+ Object.assign(this.glareElement.style, {
287
+ 'width': `${this.element.offsetWidth * 2}`,
288
+ 'height': `${this.element.offsetWidth * 2}`,
289
+ });
290
+ }
291
+
292
  // Window-resize handler; note that addEventListeners() replaces this
  // prototype method on the instance with a `this`-bound copy of itself
  // before registering it.
  onWindowResizeBind() {
    this.updateGlareSize();
  }
295
+
296
+ setTransition() {
297
+ if (this.transitionTimeout) {
298
+ clearTimeout(this.transitionTimeout);
299
+ }
300
+ // this.element.style.transition = `${this.settings.speed}ms ${this.settings.easing}`;
301
+ /// From openai:
302
+ this.element.style.transition = `transform .4s cubic-bezier(0,0,.2,1)`;
303
+ if (this.glare) {
304
+ this.glareElement.style.transition = `opacity ${this.settings.speed}ms ${this.settings.easing}`;
305
+ }
306
+
307
+ this.transitionTimeout = setTimeout(() => {
308
+ this.element.style.transition = "";
309
+ if (this.glare) {
310
+ this.glareElement.style.transition = "";
311
+ }
312
+ }, this.settings.speed);
313
+
314
+ }
315
+
316
+ extendSettings(settings) {
317
+ let defaultSettings = {
318
+ reverse: false,
319
+ max: 35,
320
+ perspective: 1000,
321
+ easing: "cubic-bezier(.03,.98,.52,.99)",
322
+ scale: "1",
323
+ speed: "300",
324
+ transition: true,
325
+ axis: null,
326
+ glare: false,
327
+ "max-glare": 1,
328
+ "glare-prerender": false,
329
+ reset: true,
330
+ };
331
+
332
+ let newSettings = {};
333
+ for (var property in defaultSettings) {
334
+ if (property in settings) {
335
+ newSettings[property] = settings[property];
336
+ } else if (this.element.hasAttribute("data-tilt-" + property)) {
337
+ let attribute = this.element.getAttribute("data-tilt-" + property);
338
+ try {
339
+ newSettings[property] = JSON.parse(<any>attribute);
340
+ } catch (e) {
341
+ newSettings[property] = attribute;
342
+ }
343
+ } else {
344
+ newSettings[property] = defaultSettings[property];
345
+ }
346
+ }
347
+
348
+ return newSettings;
349
+ }
350
+
351
+ static init(elements, settings: VanillaTilt.TiltOptions = {}) {
352
+ if (elements instanceof Node) {
353
+ elements = [elements];
354
+ }
355
+
356
+ if (elements instanceof NodeList) {
357
+ elements = [].slice.call(elements);
358
+ }
359
+
360
+ if (!(elements instanceof Array)) {
361
+ return ;
362
+ }
363
+
364
+ elements.forEach((element) => {
365
+ if (!("vanillaTilt" in element)) {
366
+ element.vanillaTilt = new VanillaTilt(element, settings);
367
+ }
368
+ });
369
+ }
370
+ }
371
+
front/less/mixins/bfc.less ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
// Utility class: `overflow: hidden` establishes a new block formatting
// context — presumably used to contain floats/margins; confirm against
// the templates that apply `.bfc`.
.bfc {
  overflow: hidden;
}