jinymusim committed
Commit
eec25bc
1 Parent(s): 6e97e97

Upload 3 files


Meter validator

Files changed (4)
  1. .gitattributes +1 -0
  2. BPE_validator_1697833311028 +3 -0
  3. poet_utils.py +418 -0
  4. validators.py +259 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ BPE_validator_1697833311028 filter=lfs diff=lfs merge=lfs -text
BPE_validator_1697833311028 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e16e690f718daaf133d6e78e82d416ef62a84db3f76f70592460e53da2f6a8fa
+ size 498951742
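
The added file is a Git LFS pointer: the ~499 MB validator weights live in LFS storage and only this stub is versioned. A minimal sketch (standard library only; the local file path is an assumption) of checking a downloaded copy against the pointer's oid and size:

import hashlib

EXPECTED_OID = "e16e690f718daaf133d6e78e82d416ef62a84db3f76f70592460e53da2f6a8fa"
EXPECTED_SIZE = 498951742

def verify_lfs_object(path: str) -> bool:
    # Hash in 1 MiB chunks so the large blob never sits in memory at once
    digest = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
            size += len(chunk)
    return digest.hexdigest() == EXPECTED_OID and size == EXPECTED_SIZE

# Assumes the LFS object was fetched to the working directory
print(verify_lfs_object("BPE_validator_1697833311028"))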
poet_utils.py ADDED
@@ -0,0 +1,418 @@
+ import re
+ import numpy as np
+
+ # Most common rhyme schemes
+ RHYME_SCHEMES = ["ABAB", "ABBA",
+                  "XAXA", "ABCB",
+                  "AABB", "AABA",
+                  "AAAA", "AABC",
+                  "XXXX", "AXAX",
+                  "AABBCC", "AABCCB",
+                  "ABABCC", "AABCBC",
+                  "AAABAB", "ABABXX",  # comma added: it was missing, silently concatenating "ABABXX" and "ABABCD"
+                  "ABABCD", "ABABAB",
+                  "ABABBC", "ABABCB",
+                  "ABBAAB", "AABABB",
+                  "ABCBBB", "ABCBCD",
+                  "ABBACC", "AABBCD",
+                  None]
+
+ NORMAL_SCHEMES = ["ABAB", "ABBA", "AABB", "AABBCC", "ABABCC", "ABBACC", "ABBAAB"]
+
+ # Most common verse endings
+ VERSE_ENDS = ['ní', 'ou', 'em', 'la', 'ch', 'ti', 'tí', 'je', 'li', 'al', 'ce', 'ky', 'ku', 'ně', 'jí', 'ly', 'il', 'en', 'né',
+               'lo', 'ne', 'vá', 'ny', 'se', 'na', 'ím', 'st', 'le', 'ný', 'ci', 'mi', 'ka', 'ná', 'lí', 'cí', 'ží', 'čí', 'ám',
+               'hu', 'ho', 'ří', 'dí', 'nu', 'dy', 'ší', 'ví', 'du', 'ta', 'as', 'tě', 'ře', 'ru', 'vé', 'ým', 'at', 'ek', 'el',
+               'te', 'tu', 'ká', 'ji', 'ět', 'ni', 'še', 'vy', 'dá', 'it', 'tá', 'ty', 'lý', 'lá', 'mu', 'va', 'ém', 'ěl', 'no',
+               'že', 'vu', 'ál', 'há', 'ků', 'vý', 'bě', 'hy', 'lé', 'sy', 'me', 'es', 'ra', 'ak', 'ad', 'ry', 'zí', 'et', 'rá',
+               'de', 'vě', 'ři', 'lu', 'át', 'da', 'ko', 'ha', 'té', 'to', 'ed', 'ít', 'ký', 'ši', 'íš', 'sí', 'íc', 'ze', 'si',
+               'be', 'má', 'mě', 'by', 'su', 'tý', 'ej', 'či', 'če', 'my', 'ké', 'án', 'ma', 'ům', 'or', 'nů', 'áš', 'dě', 'ec',
+               'mí', 'ev', 'ád', 'ut', 'am', 'yl', 'ul', 'tů', 'bu', 'ás', 'ba', 'ud', 'ář', 'ie', 'od', 'pí', 'ůj', 'eš', 'hý',
+               'bí', 'íž', 'dé', 'an', 'sa', 've', 'lů', 'ín', 'id', 'in', 'mů', 'di', 'hů', 'ic', 'on', 'eň', 'zy', 'ol', 'vo',
+               'ži', 'sů', 'ík', 'vi', 'oj', 'uk', 'uh', 'oc', 'iž', 'sá', 'ěv', 'dý', 'av', 'iv', 'rů', 'ot', 'py', 'mé', 'um',
+               'zd', 'dů', 'ar', 'rý', 'aň', 'sk', 'ok', 'om', 'už', 'ěk', 'ov', 'er', 'uď', 'bi', 'áz', 'ýt', 'ěm', 'ik', 'eď',
+               'ob', 'ák', 'ůh', 'ár', 'sť', 'ro', 'yt', 'ěj', 'mý', 'us', 'ěn', 'ii', 'hé', 'áj', 'pá', 'íh', 'ih', 'zi', 'bá',
+               'eč', 'ré', 'ír', 'ců', 'uj', 'dl', 'áh', 'ův', 'aj', 'eh', 'éž', 'pu', 'ýš', 'zu', 'im', 're', 'up', 'os', 'ah',
+               'rt', 'mo', 'áň', 'sl', 'íl', 'cy', 'ys', 'hl', 'oh', 'ěz', 'ěs', 'ež', 'ií', 'vů', 'kl', 'az', 'cý', 'pe', 'ěd',
+               'do', 'yn', 'šť', 'ez', 'ůl', 'ub', 'ln', 'yk', 'pý', 'ěc', 'ať', 'já', 'op', 'eb', 'áč', 'ív', 'áv', 'jů', 'sý',
+               'is', ' a', 'iť', 'ěř', 'za', 'uť', 'ěh', 'pě', 'íp', 'áž', 'ěď', 'bů', 'ep', 'iš', 'yš', 'ia', 'pa', 'un', 'ěť',
+               'pů', 'eř', 'tr', 'nt', 'pi', 'tl', 'eť', 'ju', 'oď', 'řů', 'ýr', 'rh', 'ur', 'zý', 'ěž', 'ýn', 'ip', 'bý', 'pé',
+               'íň', 'zů', 'čů', 'uč', 'éb', 'ap', 'ón', 'uř', 'ůr', 'íř', 'ač', 'co', 'íč', 'až', 'ls', 'ůž', 'ěr', 'oč', 'ič',
+               'ař', 'ěš', 'uv', 'ůz', 'oň', 'bé', 'sé', 'yč', 'áť', 'jď', 'ri', 'íť', 'oš', 'ůň', 'ék', 'uc', 'rk', 'bo', 'ýl',
+               'oť', 'íz', 'lh', 'so', 'áb', 'ja', 'ij', 'ůn', 'rv', 'žů', 'ab', 'he', 'íd', 'ér', 'uš', 'ýž', 'fá', 'rs', 'rn',
+               'iz', 'ib', 'ki', 'éd', 'év', 'rd', 'yb', 'oz', 'oř', 'ét', 'ož', 'ga', 'yň', 'rp', 'nd', 'of', 'rť', 'iď', 'ýv',
+               'yz', None]
+ # Years to bucket to
+ POET_YEARS_BUCKETS = [1800, 1820, 1840, 1860, 1880, 1900, 1920, 1940, 1960, None]
+ # Possible meter types
+ METER_TYPES = ["J", "T", "D", "A", "X", "Y", "N", "H", "P", None]
+ # Translation of meter names to one-char types
+ METER_TRANSLATE = {
+     "J": "J",
+     "T": "T",
+     "D": "D",
+     "A": "A",
+     "X": "X",
+     "Y": "Y",
+     "hexameter": "H",
+     "pentameter": "P",
+     "N": "N"
+ }
+ # Tokenizer special tokens
+ PAD = "<|PAD|>"
+ UNK = "<|UNK|>"
+ EOS = "<|EOS|>"
+ # Basic characters to consider in rhyme and syllables (43)
+ VALID_CHARS = ["", " ", 'a', 'á', 'b', 'c', 'č', 'd', 'ď', 'e', 'é', 'ě',
+                'f', 'g', 'h', 'i', 'í', 'j', 'k', 'l', 'm', 'n', 'ň',
+                'o', 'ó', 'p', 'q', 'r', 'ř', 's', 'š', 't', 'ť', 'u',
+                'ú', 'ů', 'v', 'w', 'x', 'y', 'ý', 'z', 'ž']
+
+
+ class TextManipulation:
+     """Static class with string-manipulation methods
+
+     Returns:
+         str: returned by all methods
+     """
+
+     @staticmethod
+     def _remove_most_nonchar(raw_text, lower_case=True):
+         """Remove most non-alphabetic, non-whitespace characters
+
+         Args:
+             raw_text (str): Text to manipulate
+             lower_case (bool, optional): Whether the resulting text should be lowercased. Defaults to True.
+
+         Returns:
+             str: Cleaned-up text
+         """
+         text = re.sub(r'[–\„\“\’\;\:()\]\[\_\*\‘\”\'\-\—\"]+', "", raw_text)
+         return text.lower() if lower_case else text
+
+     @staticmethod
+     def _remove_all_nonchar(raw_text):
+         """Remove all possible non-alphabetic characters
+
+         Args:
+             raw_text (str): Text to manipulate
+
+         Returns:
+             str: Cleaned-up text
+         """
+         return re.sub(r'([^\w\s]+|[0-9]+)', '', raw_text)
+
+     @staticmethod
+     def _year_bucketor(raw_year):
+         """Bucketize a year string to the nearest bucket boundary; bad inputs return the "NaN" string
+
+         Args:
+             raw_year (str): Year string to bucketize
+
+         Returns:
+             str: Bucketized year string
+         """
+         if TextAnalysis._is_year(raw_year) and raw_year != "NaN":
+             year_index = np.argmin(np.abs(np.asarray(POET_YEARS_BUCKETS[:-1]) - int(raw_year)))
+             return str(POET_YEARS_BUCKETS[year_index])
+         else:
+             return "NaN"
+
+
+ class TextAnalysis:
+     """Static class with methods for analyzing strings
+
+     Returns:
+         Union[str, bool, dict, numpy.ndarray]: Analyzed input
+     """
+
+     # Possible keys if the returned type is dict
+     POET_PARAM_LIST = ["RHYME", "YEAR", "METER", "LENGTH", "END", "TRUE_LENGTH", "TRUE_END"]
+
+     @staticmethod
+     def _is_meter(meter: str):
+         """Return whether the string is a meter type
+
+         Args:
+             meter (str): string to analyze
+
+         Returns:
+             bool: Whether the string is a meter type
+         """
+         return meter in METER_TYPES[:-1]
+
+     @staticmethod
+     def _is_year(year: str):
+         """Return whether the string is a year or the special "NaN" value
+
+         Args:
+             year (str): string to analyze
+
+         Returns:
+             bool: Whether the string is a year or the special "NaN" value
+         """
+         return (year.isdigit() and int(year) > 1_000 and int(year) < 10_000) or year == "NaN"
+
+     @staticmethod
+     def _rhyme_like(rhyme: str):
+         """Return whether the string is structured like a rhyme scheme
+
+         Args:
+             rhyme (str): string to analyze
+
+         Returns:
+             bool: Whether the string is structured like a rhyme scheme
+         """
+         return rhyme.isupper() and len(rhyme) >= 3 and len(rhyme) <= 6
+
+     @staticmethod
+     def _rhyme_vector(rhyme: str) -> np.ndarray:
+         """Create a one-hot encoded rhyme scheme vector from the given string
+
+         Args:
+             rhyme (str): string to construct the vector from
+
+         Returns:
+             numpy.ndarray: One-hot encoded rhyme scheme vector
+         """
+         rhyme_vec = np.zeros(len(RHYME_SCHEMES))
+         if rhyme in RHYME_SCHEMES:
+             rhyme_vec[RHYME_SCHEMES.index(rhyme)] = 1
+         else:
+             # Unknown schemes map to the last (None) bucket
+             rhyme_vec[-1] = 1
+         return rhyme_vec
+
+     @staticmethod
+     def _rhyme_or_not(rhyme_str: str) -> np.ndarray:
+         """Create a flag vector indicating whether the given string is in our list of rhyme schemes
+
+         Args:
+             rhyme_str (str): string to construct the vector from
+
+         Returns:
+             numpy.ndarray: Boolean flag vector
+         """
+         rhyme_vector = np.zeros(2)
+         if rhyme_str in RHYME_SCHEMES:
+             rhyme_vector[0] = 1
+         else:
+             rhyme_vector[1] = 1
+         return rhyme_vector
+
+     @staticmethod
+     def _metre_vector(metre: str) -> np.ndarray:
+         """Create a one-hot encoded metre vector from the given string
+
+         Args:
+             metre (str): string to construct the vector from
+
+         Returns:
+             numpy.ndarray: One-hot encoded metre vector
+         """
+         metre_vec = np.zeros(len(METER_TYPES))
+         if metre in METER_TYPES:
+             metre_vec[METER_TYPES.index(metre)] = 1
+         else:
+             # Unknown metres fall into the second-to-last bucket
+             metre_vec[-2] = 1
+         return metre_vec
+
+     @staticmethod
+     def _first_line_analysis(text: str):
+         """Analyze a parameter line for RHYME, METER, YEAR
+
+         Args:
+             text (str): parameter line string
+
+         Returns:
+             dict: Dictionary with the analysis result
+         """
+         line_striped = text.strip()
+         if not line_striped:
+             return {}
+         poet_params = {}
+         # Look for each possible parameter
+         for param in line_striped.split():
+             if TextAnalysis._is_meter(param):
+                 poet_params["METER"] = param
+             elif TextAnalysis._is_year(param):
+                 # The year is bucketized so it fits
+                 poet_params["YEAR"] = TextManipulation._year_bucketor(param)
+             elif TextAnalysis._rhyme_like(param):
+                 poet_params["RHYME"] = param
+         return poet_params
+
+     @staticmethod
+     def _is_line_length(length: str):
+         """Return whether the string is a number-of-syllables parameter
+
+         Args:
+             length (str): string to analyze
+
+         Returns:
+             bool: Whether the string is a number-of-syllables parameter
+         """
+         return length.isdigit() and int(length) > 1 and int(length) < 100
+
+     @staticmethod
+     def _is_line_end(end: str):
+         """Return whether the string is a valid ending-syllable/sequence parameter
+
+         Args:
+             end (str): string to analyze
+
+         Returns:
+             bool: Whether the string is a valid ending-syllable/sequence parameter
+         """
+         return end.isalpha() and len(end) <= 5
+
+     @staticmethod
+     def _continuos_line_analysis(text: str):
+         """Analyze a content line for LENGTH, TRUE_LENGTH, END, TRUE_END
+
+         Args:
+             text (str): content line to analyze
+
+         Returns:
+             dict: Dictionary with the analysis result
+         """
+         # Strip the line of most separators and check whether it is empty
+         line_striped = TextManipulation._remove_most_nonchar(text).strip()
+         if not line_striped:
+             return {}
+         line_params = {}
+         # Look for parameters in the order LENGTH, END, TRUE_LENGTH, TRUE_END
+         if TextAnalysis._is_line_length(line_striped.split()[0]):
+             line_params["LENGTH"] = int(line_striped.split()[0])
+         if len(line_striped.split()) > 1 and TextAnalysis._is_line_end(line_striped.split()[1]):
+             line_params["END"] = line_striped.split()[1]
+         if len(line_striped.split()) > 3:
+             line_params["TRUE_LENGTH"] = len(SyllableMaker.syllabify(" ".join(line_striped.split()[3:])))
+         # TRUE_END needs only alphabetic characters, so all other characters are removed
+         line_only_char = TextManipulation._remove_all_nonchar(line_striped).strip()
+         if len(line_only_char) > 2:
+             line_params["TRUE_END"] = SyllableMaker.syllabify(line_only_char)[-1]
+
+         return line_params
+
+     @staticmethod
+     def _is_param_line(text: str):
+         """Return whether the line is a parameter line (parameters RHYME, METER, YEAR)
+
+         Args:
+             text (str): line to analyze
+
+         Returns:
+             bool: Whether the line is a parameter line
+         """
+         line_striped = text.strip()
+         if not line_striped:
+             return False
+         small_analysis = TextAnalysis._first_line_analysis(line_striped)
+         return "RHYME" in small_analysis.keys() or "METER" in small_analysis.keys() or "YEAR" in small_analysis.keys()
+
+
+ # Non-original code!
+ # Taken from Barbora Štěpánková
+ class SyllableMaker:
+     """Static class with methods for separating a string into a list of syllables
+
+     Returns:
+         list: List of syllables
+     """
+
+     @staticmethod
+     def syllabify(text: str) -> list[str]:
+         words = re.findall(r"[aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzžAÁBCČDĎEÉĚFGHIÍJKLMNŇOÓPQRŘSŠTŤUÚŮVWXYÝZŽäöüÄÜÖ]+", text)
+         syllables: list[str] = []
+
+         i = 0
+         while i < len(words):
+             word = words[i]
+
+             # Non-syllabic prepositions (k, v, s, z) are attached to the following word
+             if word.lower() in ("k", "v", "s", "z") and i < len(words) - 1 and len(words[i + 1]) > 1:
+                 i += 1
+                 word = word + words[i]
+
+             letter_counter = 0
+
+             # Get syllables: mask the word and split the mask
+             for syllable_mask in SyllableMaker.__split_mask(SyllableMaker.__create_word_mask(word)):
+                 word_syllable = ""
+                 for character in syllable_mask:
+                     word_syllable += word[letter_counter]
+                     letter_counter += 1
+
+                 syllables.append(word_syllable)
+
+             i += 1
+
+         return syllables
+
+     @staticmethod
+     def __create_word_mask(word: str) -> str:
+         word = word.lower()
+
+         vocals = r"[aeiyouáéěíýóůúäöü]"
+         consonants = r"[bcčdďfghjklmnňpqrřsštťvwxzž]"
+
+         replacements = [
+             # double letters
+             ('ch', 'c0'),
+             ('rr', 'r0'),
+             ('ll', 'l0'),
+             ('nn', 'n0'),
+             ('th', 't0'),
+
+             # au, ou, ai, oi
+             (r'[ao]u', '0V'),
+             (r'[ao]i', '0V'),
+
+             # eu at the beginning of the word
+             (r'^eu', '0V'),
+
+             # now all vocals
+             (vocals, 'V'),
+
+             # r, l that act like vocals in syllables
+             (r'([^V])([rl])(0*[^0Vrl]|$)', r'\1V\3'),
+
+             # sp, st, sk, št, Cř, Cl, Cr, Cv
+             (r's[pt]', 's0'),
+             (r'([^V0lr]0*)[řlrv]', r'\g<1>0'),
+             (r'([^V0]0*)sk', r'\1s0'),
+             (r'([^V0]0*)št', r'\1š0'),
+
+             (consonants, 'K')
+         ]
+
+         for (original, replacement) in replacements:
+             word = re.sub(original, replacement, word)
+
+         return word
+
+     @staticmethod
+     def __split_mask(mask: str) -> list[str]:
+         replacements = [
+             # vocal at the beginning
+             (r'(^0*V)(K0*V)', r'\1/\2'),
+             (r'(^0*V0*K0*)K', r'\1/K'),
+
+             # dividing the middle of the word
+             (r'(K0*V(K0*$)?)', r'\1/'),
+             (r'/(K0*)K', r'\1/K'),
+             (r'/(0*V)(0*K0*V)', r'/\1/\2'),
+             (r'/(0*V0*K0*)K', r'/\1/K'),
+
+             # add the last consonant to the previous syllable
+             (r'/(K0*)$', r'\1/')
+         ]
+
+         for (original, replacement) in replacements:
+             mask = re.sub(original, replacement, mask)
+
+         if len(mask) > 0 and mask[-1] == "/":
+             mask = mask[0:-1]
+
+         return mask.split("/")
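
For context, a minimal usage sketch of the utilities above (the parameter line and the Czech phrase are illustrative; the module is assumed to be importable as poet_utils):

from poet_utils import TextAnalysis, SyllableMaker

# A parameter line carrying meter "J" (iamb), a year, and a rhyme scheme
print(TextAnalysis._first_line_analysis("J 1885 ABAB"))
# -> {'METER': 'J', 'YEAR': '1880', 'RHYME': 'ABAB'}  (1885 is bucketed to the nearest boundary, 1880)

# Syllabify an illustrative Czech phrase
print(SyllableMaker.syllabify("kde domov můj"))
# -> ['kde', 'do', 'mov', 'můj']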
validators.py ADDED
@@ -0,0 +1,259 @@
+ import torch
+ import transformers
+ import jellyfish
+ from tqdm import tqdm
+ from transformers import AutoModelForMaskedLM
+ from torch.utils.data import DataLoader, Dataset
+ from pytorch_optimizer import SAM, GSAM, ProportionScheduler, AdamP
+
+ from .poet_utils import RHYME_SCHEMES, METER_TYPES
+
+
+ class ValidatorInterface(torch.nn.Module):
+     """PyTorch model interface. Abstract class for all validators
+
+     Args:
+         torch (_type_): Child of torch.nn.Module for integration with torch and huggingface
+     """
+     def __init__(self, *args, **kwargs) -> None:
+         """Constructor. The child class needs to construct the parent"""
+         super().__init__(*args, **kwargs)
+
+     def forward(self, input_ids=None, attention_mask=None, *args, **kwargs):
+         """Compute model output and model loss
+
+         Args:
+             input_ids (_type_, optional): Model inputs. Defaults to None.
+             attention_mask (_type_, optional): Attention mask where padding starts. Defaults to None.
+
+         Raises:
+             NotImplementedError: Abstract class
+         """
+         raise NotImplementedError()
+
+     def predict(self, input_ids=None, *args, **kwargs):
+         """Compute model outputs
+
+         Args:
+             input_ids (_type_, optional): Model inputs. Defaults to None.
+
+         Raises:
+             NotImplementedError: Abstract class
+         """
+         raise NotImplementedError()
+
+     def validate(self, input_ids=None, *args, **kwargs):
+         """Validate the model given some labels; doesn't use the loss
+
+         Args:
+             input_ids (_type_, optional): Model inputs. Defaults to None.
+
+         Raises:
+             NotImplementedError: Abstract class
+         """
+         raise NotImplementedError()
+
+
+ class RhymeValidator(ValidatorInterface):
+     def __init__(self, pretrained_model, *args, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+
+         self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)
+
+         self.config = self.model.config
+         self.model_size = self.config.hidden_size
+
+         self.rhyme_regressor = torch.nn.Linear(self.model_size, len(RHYME_SCHEMES))  # Common rhyme scheme classifier head
+
+         self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.05)
+
+     def forward(self, input_ids=None, attention_mask=None, rhyme=None, *args, **kwargs):
+         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))
+
+         last_hidden = outputs['hidden_states'][-1]
+         # Classify from the first ([CLS]) token representation
+         rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
+
+         # NOTE: CrossEntropyLoss normally expects raw logits; softmaxed probabilities are passed here
+         softmaxed = torch.softmax(rhyme_regression, dim=1)
+         rhyme_loss = self.loss_fnc(softmaxed, rhyme)
+
+         return {"model_output": softmaxed,
+                 "loss": rhyme_loss + outputs.loss}
+
+     def predict(self, input_ids=None, *args, **kwargs):
+         outputs = self.model(input_ids=input_ids)
+
+         last_hidden = outputs['hidden_states'][-1]
+         rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
+
+         softmaxed = torch.softmax(rhyme_regression, dim=1)
+
+         return softmaxed
+
+     def validate(self, input_ids=None, rhyme=None, k: int = 2, *args, **kwargs):
+         outputs = self.model(input_ids=input_ids)
+
+         last_hidden = outputs['hidden_states'][-1]
+         rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
+
+         softmaxed = torch.softmax(rhyme_regression, dim=1)
+         softmaxed = softmaxed.flatten()
+
+         predicted_val = torch.argmax(softmaxed)
+         predicted_top_k = torch.topk(softmaxed, k).indices
+         label_val = torch.argmax(rhyme.flatten())
+
+         validation_true_val = (label_val == predicted_val).float().sum().numpy()
+         top_k_presence = 0
+         if label_val in predicted_top_k:
+             top_k_presence = 1
+
+         # Levenshtein distance between predicted and true schemes (the None bucket maps to "")
+         predicted_scheme = RHYME_SCHEMES[predicted_val] if RHYME_SCHEMES[predicted_val] is not None else ""
+         label_scheme = RHYME_SCHEMES[label_val] if RHYME_SCHEMES[label_val] is not None else ""
+         levenshtein = jellyfish.levenshtein_distance(predicted_scheme, label_scheme)
+
+         hit_pred = softmaxed[label_val].detach().numpy()
+
+         return {"acc": validation_true_val,
+                 "top_k": top_k_presence,
+                 "lev_distance": levenshtein,
+                 "predicted_label": hit_pred
+                 }
+
+
+ class MeterValidator(ValidatorInterface):
+     def __init__(self, pretrained_model, *args, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)
+
+         self.config = self.model.config
+         self.model_size = self.config.hidden_size
+
+         self.meter_regressor = torch.nn.Linear(self.model_size, len(METER_TYPES))  # Meter type classifier head
+
+         self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.05)
+
+     def forward(self, input_ids=None, attention_mask=None, metre=None, *args, **kwargs):
+         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))
+
+         last_hidden = outputs['hidden_states'][-1]
+         meter_regression = self.meter_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
+
+         # NOTE: as above, probabilities rather than raw logits go into CrossEntropyLoss
+         softmaxed = torch.softmax(meter_regression, dim=1)
+         meter_loss = self.loss_fnc(softmaxed, metre)
+
+         return {"model_output": softmaxed,
+                 "loss": meter_loss + outputs.loss}
+
+     def predict(self, input_ids=None, *args, **kwargs):
+         outputs = self.model(input_ids=input_ids)
+
+         last_hidden = outputs['hidden_states'][-1]
+         meter_regression = self.meter_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
+
+         softmaxed = torch.softmax(meter_regression, dim=1)
+
+         return softmaxed
+
+     def validate(self, input_ids=None, metre=None, k: int = 2, *args, **kwargs):
+         outputs = self.model(input_ids=input_ids)
+
+         last_hidden = outputs['hidden_states'][-1]
+         meter_regression = self.meter_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
+
+         softmaxed = torch.softmax(meter_regression, dim=1)
+         softmaxed = softmaxed.flatten()
+
+         predicted_val = torch.argmax(softmaxed)
+         predicted_top_k = torch.topk(softmaxed, k).indices
+         label_val = torch.argmax(metre.flatten())
+
+         validation_true_val = (label_val == predicted_val).float().sum().numpy()
+         top_k_presence = 0
+         if label_val in predicted_top_k:
+             top_k_presence = 1
+
+         hit_pred = softmaxed[label_val].detach().numpy()
+
+         return {"acc": validation_true_val,
+                 "top_k": top_k_presence,
+                 "predicted_label": hit_pred
+                 }
+
+
+ class ValidatorTrainer:
+     def __init__(self, model: ValidatorInterface, args: dict, train_dataset: Dataset, data_collator, device):
+         self.model = model
+         self.args = args
+         self.epochs = args.get("epochs", 1)
+         self.batch_size = args.get("batch_size", 1)
+         self.lr = args.get("lr", 3e-4)
+         self.weight_decay = args.get("weight_decay", 0.0)
+
+         self.train_loader = DataLoader(train_dataset, self.batch_size, True, collate_fn=data_collator)
+
+         # SAM values
+         self.device = device
+         self.optimizer = SAM(self.model.parameters(), torch.optim.AdamW, lr=self.lr, weight_decay=self.weight_decay)
+         self.scheduler = transformers.get_constant_schedule_with_warmup(self.optimizer, len(train_dataset) // self.batch_size)
+
+         # GSAM values
+         #self.device = device
+         #self.base_optim = AdamP(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+         #self.scheduler = transformers.get_constant_schedule_with_warmup(self.base_optim, len(train_dataset)//self.batch_size)
+         #self.rho_scheduler = ProportionScheduler(self.scheduler, max_lr=self.lr)
+         #self.optimizer = GSAM(self.model.parameters(), self.base_optim, self.model, self.rho_scheduler, alpha=0.05)
+
+     def train(self):
+         for epoch in tqdm(range(self.epochs)):
+             self.model.train()
+
+             # SAM attempt: two forward/backward passes per batch
+             for step, batch in enumerate(self.train_loader):
+                 # First pass
+                 loss = self.model(input_ids=batch["input_ids"].to(self.device), attention_mask=batch["attention_mask"].to(self.device),
+                                   rhyme=None if batch["rhyme"] is None else batch["rhyme"].to(self.device),
+                                   metre=None if batch["metre"] is None else batch["metre"].to(self.device))['loss']
+                 loss.backward()
+                 self.optimizer.first_step(zero_grad=True)
+                 # Second pass
+                 loss = self.model(input_ids=batch["input_ids"].to(self.device), attention_mask=batch["attention_mask"].to(self.device),
+                                   rhyme=None if batch["rhyme"] is None else batch["rhyme"].to(self.device),
+                                   metre=None if batch["metre"] is None else batch["metre"].to(self.device))['loss']
+                 loss.backward()
+                 self.optimizer.second_step(zero_grad=True)
+                 self.scheduler.step()
+
+                 if step % 100 == 0:
+                     print(f'Step {step}, loss : {loss.item()}', flush=True)
+
+             # GSAM attempt
+             #for step, batch in enumerate(self.train_loader):
+             #    def closure():
+             #        self.optimizer.base_optimizer.zero_grad()
+             #        with torch.enable_grad():
+             #            outputs = self.model(input_ids=batch["input_ids"].to(self.device), attention_mask=batch["attention_mask"].to(self.device),
+             #                                 rhyme=None if batch["rhyme"] is None else batch["rhyme"].to(self.device),
+             #                                 metre=None if batch["metre"] is None else batch["metre"].to(self.device))
+             #            loss = torch.nn.functional.cross_entropy(outputs['model_output'].to(self.device), batch['rhyme'].to(self.device) if isinstance(self.model, RhymeValidator) else batch['metre'].to(self.device))
+             #            loss.backward()
+             #        return outputs['model_output'], loss.detach()
+             #    predictions, loss = self.optimizer.step(closure)
+             #    self.scheduler.step()
+             #    self.optimizer.update_rho_t()
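
For context, a minimal sketch of loading and evaluating one of these validators on a single example. The checkpoint name, package layout, and label are illustrative assumptions, not part of this commit; since validators.py imports ".poet_utils" relatively, both modules are assumed to sit in a package (here hypothetically called poet):

import torch
from transformers import AutoTokenizer

from poet.poet_utils import METER_TYPES
from poet.validators import MeterValidator

# Illustrative Czech masked-LM checkpoint; any MLM checkpoint should work
CHECKPOINT = "ufal/robeczech-base"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
validator = MeterValidator(pretrained_model=CHECKPOINT)
validator.eval()

inputs = tokenizer("Kde domov můj, kde domov můj", return_tensors="pt")

# One-hot metre label for "J" (iamb), mirroring TextAnalysis._metre_vector
metre = torch.zeros(1, len(METER_TYPES))
metre[0, METER_TYPES.index("J")] = 1.0

with torch.no_grad():
    print(validator.validate(input_ids=inputs["input_ids"], metre=metre, k=2))
    # -> {'acc': ..., 'top_k': ..., 'predicted_label': ...}

Note that an untrained classifier head gives arbitrary predictions here; the sketch only shows the calling convention of validate().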