hgrif committed
Commit a3cb5cc • 1 Parent(s): 4085796

Split in app & library
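
For orientation, this is the layout after the split (all of the names below appear in the diffs that follow):

    app.py                      # Streamlit front end; now imports from the library
    requirements.txt            # installs the local package in editable mode (-e .)
    setup.cfg                   # pytest, flake8 and bumpversion configuration
    setup.py                    # packaging for the new rhyme_with_ai library
    src/rhyme_with_ai/
        __init__.py             # empty package marker
        rhyme.py                # query_rhyme_words + Datamuse / rijmwoordenboek backends
        rhyme_generator.py      # RhymeGenerator: the masked-LM mutation loop
        token_weighter.py       # TokenWeighter: 0/1 mask over the vocabulary
        utils.py                # color_new_words, find_last_word, pairwise, sanitize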
app.py CHANGED
@@ -1,17 +1,10 @@
  import copy
- import functools
- import itertools
- import logging
- import random
- import string
- from typing import List, Optional
 
- import requests
- import numpy as np
- import tensorflow as tf
  import streamlit as st
- from gazpacho import Soup, get
- from transformers import AutoTokenizer, TFAutoModelForMaskedLM
+ from rhyme_with_ai.rhyme import query_rhyme_words
+ from rhyme_with_ai.rhyme_generator import RhymeGenerator
+ from rhyme_with_ai.utils import color_new_words, sanitize
+ from transformers import TFAutoModelForMaskedLM, AutoTokenizer
 
 
  DEFAULT_QUERY = "Machines will take over the world soon"
@@ -100,300 +93,6 @@ def display_output(status_text, query, current_sentences, previous_sentences):
          query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
      )
 
- class TokenWeighter:
-     def __init__(self, tokenizer):
-         self.tokenizer_ = tokenizer
-         self.proba = self.get_token_proba()
-
-     def get_token_proba(self):
-         valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
-         return valid_token_mask
-
-     def _filter_short_partial(self, vocab):
-         valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
-         is_valid = np.zeros(len(vocab.keys()))
-         is_valid[valid_token_ids] = 1
-         return is_valid
-
-
- class RhymeGenerator:
-     def __init__(
-         self,
-         model: TFAutoModelForMaskedLM,
-         tokenizer: AutoTokenizer,
-         token_weighter: TokenWeighter = None,
-     ):
-         """Generate rhymes.
-
-         Parameters
-         ----------
-         model : Model for masked language modelling
-         tokenizer : Tokenizer for model
-         token_weighter : Class that weighs tokens
-         """
-
-         self.model = model
-         self.tokenizer = tokenizer
-         if token_weighter is None:
-             token_weighter = TokenWeighter(tokenizer)
-         self.token_weighter = token_weighter
-         self._logger = logging.getLogger(__name__)
-
-         self.tokenized_rhymes_ = None
-         self.position_probas_ = None
-
-         # Easy access.
-         self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
-         self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
-         self.mask_token_id = self.tokenizer.mask_token_id
-
-     def start(self, query: str, rhyme_words: List[str]) -> None:
-         """Start the sentence generator.
-
-         Parameters
-         ----------
-         query : Seed sentence
-         rhyme_words : Rhyme words for next sentence
-         """
-         # TODO: What if no content?
-         self._logger.info("Got sentence %s", query)
-         tokenized_rhymes = [
-             self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
-         ]
-         # Make same length.
-         self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
-             tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
-         )
-         p = self.tokenized_rhymes_ == self.tokenizer.mask_token_id
-         self.position_probas_ = p / p.sum(1).reshape(-1, 1)
-
-     def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
-         """Initialize the rhymes.
-
-         * Tokenize input
-         * Append a comma if the sentence does not end in one (might give better
-           predictions, as it shows the two sentence parts are related)
-         * Make the second line as long as the original
-         * Add a period
-
-         Parameters
-         ----------
-         query : First line
-         rhyme_word : Last word for second line
-
-         Returns
-         -------
-         Tokenized rhyme lines
-         """
-
-         query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
-         rhyme_word_token_ids = self.tokenizer.encode(
-             rhyme_word, add_special_tokens=False
-         )
-
-         if query_token_ids[-1] != self.comma_token_id:
-             query_token_ids.append(self.comma_token_id)
-
-         magic_correction = len(rhyme_word_token_ids) + 1  # 1 for the comma
-         return (
-             query_token_ids
-             + [self.tokenizer.mask_token_id] * (len(query_token_ids) - magic_correction)
-             + rhyme_word_token_ids
-             + [self.period_token_id]
-         )
-
-     def mutate(self):
-         """Mutate the current rhymes.
-
-         Returns
-         -------
-         Mutated rhymes
-         """
-         self.tokenized_rhymes_ = self._mutate(
-             self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
-         )
-
-         rhymes = []
-         for i in range(len(self.tokenized_rhymes_)):
-             rhymes.append(
-                 self.tokenizer.convert_tokens_to_string(
-                     self.tokenizer.convert_ids_to_tokens(
-                         self.tokenized_rhymes_[i], skip_special_tokens=True
-                     )
-                 )
-             )
-         return rhymes
-
-     def _mutate(
-         self,
-         tokenized_rhymes: np.ndarray,
-         position_probas: np.ndarray,
-         token_id_probas: np.ndarray,
-     ) -> np.ndarray:
-         replacements = []
-         for i in range(tokenized_rhymes.shape[0]):
-             mask_idx, masked_token_ids = self._mask_token(
-                 tokenized_rhymes[i], position_probas[i]
-             )
-             tokenized_rhymes[i] = masked_token_ids
-             replacements.append(mask_idx)
-
-         predictions = self._predict_masked_tokens(tokenized_rhymes)
-
-         for i, token_ids in enumerate(tokenized_rhymes):
-             replace_ix = replacements[i]
-             token_ids[replace_ix] = self._draw_replacement(
-                 predictions[i], token_id_probas, replace_ix
-             )
-             tokenized_rhymes[i] = token_ids
-
-         return tokenized_rhymes
-
-     def _mask_token(self, token_ids, position_probas):
-         """Mask line and return index to update."""
-         token_ids = self._mask_repeats(token_ids, position_probas)
-         ix = self._locate_mask(token_ids, position_probas)
-         token_ids[ix] = self.mask_token_id
-         return ix, token_ids
-
-     def _locate_mask(self, token_ids, position_probas):
-         """Return the index of an existing mask, or of a random token to mask."""
-         if self.mask_token_id in token_ids:
-             # Masks already present; just return the last one.
-             # We used to return the first, but that gave worse predictions.
-             return np.where(token_ids == self.tokenizer.mask_token_id)[0][-1]
-         return np.random.choice(range(len(position_probas)), p=position_probas)
-
-     def _mask_repeats(self, token_ids, position_probas):
-         """Repeated tokens are generally of lower quality."""
-         repeats = [
-             ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
-         ]
-         for ii in repeats:
-             if position_probas[ii] > 0:
-                 token_ids[ii] = self.mask_token_id
-             if position_probas[ii + 1] > 0:
-                 token_ids[ii + 1] = self.mask_token_id
-         return token_ids
-
-     def _predict_masked_tokens(self, tokenized_rhymes):
-         return self.model(tf.constant(tokenized_rhymes))[0]
-
-     def _draw_replacement(self, predictions, token_probas, replace_ix):
-         """Get probabilities, weigh them and draw."""
-         # TODO (HG): Can't we softmax when calling the model?
-         probas = tf.nn.softmax(predictions[replace_ix]).numpy() * token_probas
-         probas /= probas.sum()
-         return np.random.choice(range(len(probas)), p=probas)
-
-
- def query_rhyme_words(
-     sentence: str, n_rhymes: int, language: str = "english"
- ) -> List[str]:
-     """Return a list of rhyme words for a sentence.
-
-     Parameters
-     ----------
-     sentence : Sentence that may end with punctuation
-     n_rhymes : Maximum number of rhymes to return
-     language : Either "english" or "dutch"
-
-     Returns
-     -------
-     List[str] -- List of words that rhyme with the final word
-     """
-     last_word = find_last_word(sentence)
-     if language == "english":
-         return query_datamuse_api(last_word, n_rhymes)
-     elif language == "dutch":
-         return mick_rijmwoordenboek(last_word, n_rhymes)
-     else:
-         raise NotImplementedError(
-             f"Unsupported language ({language}), expected 'english' or 'dutch'."
-         )
-
-
- def query_datamuse_api(word: str, n_rhymes: Optional[int] = None) -> List[str]:
-     """Query the Datamuse API.
-
-     Parameters
-     ----------
-     word : Word to rhyme with
-     n_rhymes : Max rhymes to return
-
-     Returns
-     -------
-     Rhyme words
-     """
-     out = requests.get(
-         "https://api.datamuse.com/words", params={"rel_rhy": word}
-     ).json()
-     words = [_["word"] for _ in out]
-     if n_rhymes is None:
-         return words
-     return words[:n_rhymes]
-
-
- @functools.lru_cache(maxsize=128, typed=False)
- def mick_rijmwoordenboek(word: str, n_words: int):
-     url = f"https://rijmwoordenboek.nl/rijm/{word}"
-     html = get(url)
-     soup = Soup(html)
-
-     results = soup.find("div", {"id": "rhymeResultsWords"}).html.split("<br />")
-
-     # Clean up.
-     results = [r.replace("\n", "").replace(" ", "") for r in results]
-
-     # Filter out HTML fragments and empty strings.
-     results = [r for r in results if ("<" not in r) and (len(r) > 0)]
-
-     return random.sample(results, min(len(results), n_words))
-
-
- def color_new_words(new: str, old: str, color: str = "#eefa66") -> str:
-     """Color new words in strings with a span."""
-
-     def find_diff(new_, old_):
-         return [ii for ii, (n, o) in enumerate(zip(new_, old_)) if n != o]
-
-     new_words = new.split()
-     old_words = old.split()
-     forward = find_diff(new_words, old_words)
-     backward = find_diff(new_words[::-1], old_words[::-1])
-
-     if not forward or not backward:
-         # No difference.
-         return new
-
-     start, end = forward[0], len(new_words) - backward[0]
-     return (
-         " ".join(new_words[:start])
-         + " "
-         + f'<span style="background-color: {color}">'
-         + " ".join(new_words[start:end])
-         + "</span>"
-         + " "
-         + " ".join(new_words[end:])
-     )
-
-
- def find_last_word(s):
-     """Find the last word in a string."""
-     # Note: will break on \n, \r, etc.
-     alpha_only_sentence = "".join([c for c in s if (c.isalpha() or (c == " "))]).strip()
-     return alpha_only_sentence.split()[-1]
-
-
- def pairwise(iterable):
-     """s -> (s0, s1), (s1, s2), (s2, s3), ..."""
-     # https://stackoverflow.com/questions/5434891/iterate-a-list-as-pair-current-next-in-python
-     a, b = itertools.tee(iterable)
-     next(b, None)
-     return zip(a, b)
-
-
- def sanitize(s):
-     """Remove punctuation from a string."""
-     return s.translate(str.maketrans("", "", string.punctuation))
-
 
  if __name__ == "__main__":
      main()
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gazpacho
  numpy
  requests
  tensorflow
- transformers
+ transformers
+ -e .
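
The new `-e .` line installs the split-out package in editable mode, so `app.py` can import `rhyme_with_ai` while the source stays under `src/`. A minimal smoke test of the split, assuming the install succeeded and network access is available (the query and count are illustrative):

    from rhyme_with_ai.rhyme import query_rhyme_words

    # English rhymes are fetched live from the Datamuse API.
    print(query_rhyme_words("Machines will take over the world soon", n_rhymes=3))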
setup.cfg ADDED
@@ -0,0 +1,17 @@
+ [aliases]
+ test = pytest
+
+ [flake8]
+ max-line-length = 88
+
+ [tool:pytest]
+ addopts = --cov=src --cov-report=xml:test-coverage.xml --nunitxml test-output.xml -vv
+
+ [bumpversion]
+ current_version = 0.1
+ commit = True
+ tag = True
+
+ [bumpversion:file:setup.py]
+ search = version='{current_version}'
+ replace = version='{new_version}'
setup.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ from setuptools import setup, find_packages
+
+ with open("README.md") as readme_file:
+     readme = readme_file.read()
+
+ requirements = [
+     "gazpacho",  # used at runtime by rhyme_with_ai.rhyme
+     "numpy",
+     "pandas",
+     "requests",
+     "tensorflow",
+     "transformers",
+ ]
+
+ extra_requirements = {
+     "dev": [
+         "black",
+         "bump2version",
+         "coverage",
+         "twine",
+         "pre-commit",
+         "pylint",
+         "pytest",
+     ]
+ }
+
+ setup_requirements = ["pytest-runner"]
+
+ test_requirements = ["pytest", "pytest-cov", "pytest-nunit"]
+
+ BUILD_ID = os.environ.get("BUILD_BUILDID", "0")
+
+ setup(
+     author="Rens Dimmendaal & Henk Griffioen",
+     author_email="rensdimmendaal@godatadriven.com",
+     classifiers=[
+         "Development Status :: 2 - Pre-Alpha",
+         "Intended Audience :: Developers",
+         "License :: OSI Approved :: MIT License",
+         "Natural Language :: English",
+         "Programming Language :: Python :: 3.7",
+     ],
+     description="Generate text",
+     install_requires=requirements,
+     extras_require=extra_requirements,
+     long_description=readme,
+     include_package_data=True,
+     keywords="rhyme",
+     name="rhyme_with_ai",
+     packages=find_packages(where="src"),  # packages live under src/ (see package_dir)
+     package_dir={"": "src"},
+     setup_requires=setup_requirements,
+     test_suite="tests",
+     tests_require=test_requirements,
+     version="0.1" + "." + BUILD_ID,
+     zip_safe=False,
+ )
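
The version string is stitched together from a CI build number: `BUILD_BUILDID` is the predefined build identifier on Azure Pipelines agents. A quick sketch of what the expression evaluates to (the CI value shown is illustrative):

    import os

    # Set by the CI agent; absent locally, so the "0" fallback applies.
    build_id = os.environ.get("BUILD_BUILDID", "0")
    print("0.1" + "." + build_id)  # "0.1.0" locally, e.g. "0.1.1234" in CI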
src/rhyme_with_ai/__init__.py ADDED
File without changes (new empty file, the package marker)
src/rhyme_with_ai/rhyme.py ADDED
@@ -0,0 +1,67 @@
+ import functools
+ import random
+ from typing import List, Optional
+
+ import requests
+ from gazpacho import Soup, get
+
+ from rhyme_with_ai.utils import find_last_word
+
+
+ def query_rhyme_words(
+     sentence: str, n_rhymes: int, language: str = "english"
+ ) -> List[str]:
+     """Return a list of rhyme words for a sentence.
+
+     Parameters
+     ----------
+     sentence : Sentence that may end with punctuation
+     n_rhymes : Maximum number of rhymes to return
+     language : Either "english" or "dutch"
+
+     Returns
+     -------
+     List[str] -- List of words that rhyme with the final word
+     """
+     last_word = find_last_word(sentence)
+     if language == "english":
+         return query_datamuse_api(last_word, n_rhymes)
+     elif language == "dutch":
+         return mick_rijmwoordenboek(last_word, n_rhymes)
+     else:
+         raise NotImplementedError(
+             f"Unsupported language ({language}), expected 'english' or 'dutch'."
+         )
+
+
+ def query_datamuse_api(word: str, n_rhymes: Optional[int] = None) -> List[str]:
+     """Query the Datamuse API.
+
+     Parameters
+     ----------
+     word : Word to rhyme with
+     n_rhymes : Max rhymes to return
+
+     Returns
+     -------
+     Rhyme words
+     """
+     out = requests.get(
+         "https://api.datamuse.com/words", params={"rel_rhy": word}
+     ).json()
+     words = [_["word"] for _ in out]
+     if n_rhymes is None:
+         return words
+     return words[:n_rhymes]
+
+
+ @functools.lru_cache(maxsize=128, typed=False)
+ def mick_rijmwoordenboek(word: str, n_words: int):
+     url = f"https://rijmwoordenboek.nl/rijm/{word}"
+     html = get(url)
+     soup = Soup(html)
+
+     results = soup.find("div", {"id": "rhymeResultsWords"}).html.split("<br />")
+
+     # Clean up.
+     results = [r.replace("\n", "").replace(" ", "") for r in results]
+
+     # Filter out HTML fragments and empty strings.
+     results = [r for r in results if ("<" not in r) and (len(r) > 0)]
+
+     return random.sample(results, min(len(results), n_words))
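
A hedged usage sketch for the two backends above; both make live HTTP calls (Datamuse for English, rijmwoordenboek.nl for Dutch), so the outputs shown are only indicative:

    from rhyme_with_ai.rhyme import query_datamuse_api, query_rhyme_words

    print(query_datamuse_api("soon", n_rhymes=5))  # e.g. ['moon', 'noon', ...]
    print(query_rhyme_words("Dit is een zin", 3, language="dutch"))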
src/rhyme_with_ai/rhyme_generator.py ADDED
@@ -0,0 +1,181 @@
+ import logging
+ from typing import List
+
+ import numpy as np
+ import tensorflow as tf
+ from transformers import TFAutoModelForMaskedLM, AutoTokenizer
+
+ from rhyme_with_ai.utils import pairwise
+ from rhyme_with_ai.token_weighter import TokenWeighter
+
+
+ class RhymeGenerator:
+     def __init__(
+         self,
+         model: TFAutoModelForMaskedLM,
+         tokenizer: AutoTokenizer,
+         token_weighter: TokenWeighter = None,
+     ):
+         """Generate rhymes.
+
+         Parameters
+         ----------
+         model : Model for masked language modelling
+         tokenizer : Tokenizer for model
+         token_weighter : Class that weighs tokens
+         """
+
+         self.model = model
+         self.tokenizer = tokenizer
+         if token_weighter is None:
+             token_weighter = TokenWeighter(tokenizer)
+         self.token_weighter = token_weighter
+         self._logger = logging.getLogger(__name__)
+
+         self.tokenized_rhymes_ = None
+         self.position_probas_ = None
+
+         # Easy access.
+         self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
+         self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
+         self.mask_token_id = self.tokenizer.mask_token_id
+
+     def start(self, query: str, rhyme_words: List[str]) -> None:
+         """Start the sentence generator.
+
+         Parameters
+         ----------
+         query : Seed sentence
+         rhyme_words : Rhyme words for next sentence
+         """
+         # TODO: What if no content?
+         self._logger.info("Got sentence %s", query)
+         tokenized_rhymes = [
+             self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
+         ]
+         # Make same length.
+         self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
+             tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
+         )
+         p = self.tokenized_rhymes_ == self.tokenizer.mask_token_id
+         self.position_probas_ = p / p.sum(1).reshape(-1, 1)
+
+     def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
+         """Initialize the rhymes.
+
+         * Tokenize input
+         * Append a comma if the sentence does not end in one (might give better
+           predictions, as it shows the two sentence parts are related)
+         * Make the second line as long as the original
+         * Add a period
+
+         Parameters
+         ----------
+         query : First line
+         rhyme_word : Last word for second line
+
+         Returns
+         -------
+         Tokenized rhyme lines
+         """
+
+         query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
+         rhyme_word_token_ids = self.tokenizer.encode(
+             rhyme_word, add_special_tokens=False
+         )
+
+         if query_token_ids[-1] != self.comma_token_id:
+             query_token_ids.append(self.comma_token_id)
+
+         magic_correction = len(rhyme_word_token_ids) + 1  # 1 for the comma
+         return (
+             query_token_ids
+             + [self.tokenizer.mask_token_id] * (len(query_token_ids) - magic_correction)
+             + rhyme_word_token_ids
+             + [self.period_token_id]
+         )
+
+     def mutate(self):
+         """Mutate the current rhymes.
+
+         Returns
+         -------
+         Mutated rhymes
+         """
+         self.tokenized_rhymes_ = self._mutate(
+             self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
+         )
+
+         rhymes = []
+         for i in range(len(self.tokenized_rhymes_)):
+             rhymes.append(
+                 self.tokenizer.convert_tokens_to_string(
+                     self.tokenizer.convert_ids_to_tokens(
+                         self.tokenized_rhymes_[i], skip_special_tokens=True
+                     )
+                 )
+             )
+         return rhymes
+
+     def _mutate(
+         self,
+         tokenized_rhymes: np.ndarray,
+         position_probas: np.ndarray,
+         token_id_probas: np.ndarray,
+     ) -> np.ndarray:
+         replacements = []
+         for i in range(tokenized_rhymes.shape[0]):
+             mask_idx, masked_token_ids = self._mask_token(
+                 tokenized_rhymes[i], position_probas[i]
+             )
+             tokenized_rhymes[i] = masked_token_ids
+             replacements.append(mask_idx)
+
+         predictions = self._predict_masked_tokens(tokenized_rhymes)
+
+         for i, token_ids in enumerate(tokenized_rhymes):
+             replace_ix = replacements[i]
+             token_ids[replace_ix] = self._draw_replacement(
+                 predictions[i], token_id_probas, replace_ix
+             )
+             tokenized_rhymes[i] = token_ids
+
+         return tokenized_rhymes
+
+     def _mask_token(self, token_ids, position_probas):
+         """Mask line and return index to update."""
+         token_ids = self._mask_repeats(token_ids, position_probas)
+         ix = self._locate_mask(token_ids, position_probas)
+         token_ids[ix] = self.mask_token_id
+         return ix, token_ids
+
+     def _locate_mask(self, token_ids, position_probas):
+         """Return the index of an existing mask, or of a random token to mask."""
+         if self.mask_token_id in token_ids:
+             # Masks already present; just return the last one.
+             # We used to return the first, but that gave worse predictions.
+             return np.where(token_ids == self.tokenizer.mask_token_id)[0][-1]
+         return np.random.choice(range(len(position_probas)), p=position_probas)
+
+     def _mask_repeats(self, token_ids, position_probas):
+         """Repeated tokens are generally of lower quality."""
+         repeats = [
+             ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
+         ]
+         for ii in repeats:
+             if position_probas[ii] > 0:
+                 token_ids[ii] = self.mask_token_id
+             if position_probas[ii + 1] > 0:
+                 token_ids[ii + 1] = self.mask_token_id
+         return token_ids
+
+     def _predict_masked_tokens(self, tokenized_rhymes):
+         return self.model(tf.constant(tokenized_rhymes))[0]
+
+     def _draw_replacement(self, predictions, token_probas, replace_ix):
+         """Get probabilities, weigh them and draw."""
+         # TODO (HG): Can't we softmax when calling the model?
+         probas = tf.nn.softmax(predictions[replace_ix]).numpy() * token_probas
+         probas /= probas.sum()
+         return np.random.choice(range(len(probas)), p=probas)
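
How the class is meant to be driven, as a sketch: `start()` seeds one masked candidate line per rhyme word, and each `mutate()` call re-samples a single position in every line. The checkpoint name is an assumption (nothing in this commit pins a model); any TensorFlow masked-LM checkpoint should fit the constructor:

    from transformers import AutoTokenizer, TFAutoModelForMaskedLM

    from rhyme_with_ai.rhyme_generator import RhymeGenerator

    MODEL_NAME = "bert-base-uncased"  # assumption, not taken from this commit
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = TFAutoModelForMaskedLM.from_pretrained(MODEL_NAME)

    generator = RhymeGenerator(model, tokenizer)
    generator.start("Machines will take over the world soon", ["moon", "balloon"])
    for _ in range(20):  # more iterations leave fewer [MASK] tokens in the output
        rhymes = generator.mutate()
    print(rhymes)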
src/rhyme_with_ai/token_weighter.py ADDED
@@ -0,0 +1,17 @@
+ import numpy as np
+
+
+ class TokenWeighter:
+     """Build a 0/1 probability mask over the tokenizer's vocabulary."""
+
+     def __init__(self, tokenizer):
+         self.tokenizer_ = tokenizer
+         self.proba = self.get_token_proba()
+
+     def get_token_proba(self):
+         valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
+         return valid_token_mask
+
+     def _filter_short_partial(self, vocab):
+         # Keep tokens longer than one character that are not WordPiece
+         # continuations ("##..." pieces contain a "#").
+         valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
+         is_valid = np.zeros(len(vocab.keys()))
+         is_valid[valid_token_ids] = 1
+         return is_valid
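
The resulting `proba` vector is what `RhymeGenerator._draw_replacement` multiplies into the model's softmax: 1.0 keeps a token in play, 0.0 bans it (single characters and WordPiece continuations). A small sketch, with the checkpoint name again an assumption:

    from transformers import AutoTokenizer

    from rhyme_with_ai.token_weighter import TokenWeighter

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumption
    weighter = TokenWeighter(tokenizer)
    # One entry per vocabulary id; the sum is the number of allowed tokens.
    print(weighter.proba.shape, int(weighter.proba.sum()))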
src/rhyme_with_ai/utils.py ADDED
@@ -0,0 +1,49 @@
+ import itertools
+ import string
+
+
+ def color_new_words(new: str, old: str, color: str = "#eefa66") -> str:
+     """Color new words in strings with a span."""
+
+     def find_diff(new_, old_):
+         return [ii for ii, (n, o) in enumerate(zip(new_, old_)) if n != o]
+
+     new_words = new.split()
+     old_words = old.split()
+     forward = find_diff(new_words, old_words)
+     backward = find_diff(new_words[::-1], old_words[::-1])
+
+     if not forward or not backward:
+         # No difference.
+         return new
+
+     start, end = forward[0], len(new_words) - backward[0]
+     return (
+         " ".join(new_words[:start])
+         + " "
+         + f'<span style="background-color: {color}">'
+         + " ".join(new_words[start:end])
+         + "</span>"
+         + " "
+         + " ".join(new_words[end:])
+     )
+
+
+ def find_last_word(s):
+     """Find the last word in a string."""
+     # Note: will break on \n, \r, etc.
+     alpha_only_sentence = "".join([c for c in s if (c.isalpha() or (c == " "))]).strip()
+     return alpha_only_sentence.split()[-1]
+
+
+ def pairwise(iterable):
+     """s -> (s0, s1), (s1, s2), (s2, s3), ..."""
+     # https://stackoverflow.com/questions/5434891/iterate-a-list-as-pair-current-next-in-python
+     a, b = itertools.tee(iterable)
+     next(b, None)
+     return zip(a, b)
+
+
+ def sanitize(s):
+     """Remove punctuation from a string."""
+     return s.translate(str.maketrans("", "", string.punctuation))
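
These helpers are pure functions, so their behavior can be pinned down with a few calls (expected outputs in the comments):

    from rhyme_with_ai.utils import color_new_words, find_last_word, pairwise, sanitize

    print(find_last_word("Machines will take over the world soon!"))  # soon
    print(sanitize("world, soon."))                                   # world soon
    print(list(pairwise([1, 2, 3])))                                  # [(1, 2), (2, 3)]
    # Wraps "the moon" in a highlighted span:
    print(color_new_words("fly me to the moon", "fly me to a star"))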