File size: 13,649 Bytes
0376b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a4a080
0376b0a
 
2a4a080
0376b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
from typing import List, Union
from pythainlp.tokenize import subword_tokenize,word_tokenize
from pythainlp.util import sound_syllable
from pythainlp.util import remove_tonemark
from pythainlp.khavee import KhaveeVerifier
import pythainlp as pythai
from pythainlp.tokenize import word_tokenize
from pythainlp.tokenize import subword_tokenize
from pythainlp.util import sound_syllable
from pythainlp.util import isthai
from pythainlp.transliterate import pronunciate
from pythainlp.spell import correct
from tqdm import tqdm
import numpy as np
import pandas as pd
# Shared verifier instance; its is_sumpus() and check_aek_too() methods are
# called by check_sampas, get_sampassed, get_aek_too and tone_gen below.
kv = KhaveeVerifier()
from transformers import AutoTokenizer, AutoModelForCausalLM

# Tokenizer and causal language model are loaded once at import time and used
# as module-level globals by generator() and tone_gen().
# NOTE(review): local_files_only=False presumably allows fetching the
# checkpoint from the Hugging Face Hub when it is not cached - confirm.
tokenizer = AutoTokenizer.from_pretrained("Thanravee/KarveeSaimai", local_files_only=False)
model = AutoModelForCausalLM.from_pretrained("Thanravee/KarveeSaimai", local_files_only=False)

# split text on '-' into a list of wak and drop the soi words (anything after the first space in wak 2, 4, 6) -> list of wak (no soi)
def split_klong(klong_text):
  """Split klong text on '-' into a list of wak.

  Pieces that are empty or whitespace-only are discarded. For the pieces at
  index 1, 3 and 5 one leading space (if any) is removed and only the text
  before the next space is kept; every other piece simply has all of its
  spaces removed.
  """
  pieces = [piece for piece in klong_text.split('-') if piece.strip()]
  waks = []
  for position, piece in enumerate(pieces):
    if position in (1, 3, 5):
      # Keep only the first space-separated chunk of the short wak.
      trimmed = piece[1:] if piece.startswith(' ') else piece
      waks.append(trimmed.split(' ')[0])
    else:
      waks.append(piece.replace(' ', ''))
  return waks

# subword tokenize wak with ssg and dict
def subword_token(wak, engine='ssg'):
  """Tokenize a wak into subwords, retrying with the 'dict' engine.

  Args:
    wak: text of a single wak.
    engine: subword_tokenize engine tried first. It was previously accepted
      but ignored; it is now honoured. The default keeps the old behaviour.

  Returns:
    The token list from `engine` when it has 2 or 5 items, otherwise the
    token list produced by the 'dict' engine.
  """
  subword_tokenized = subword_tokenize(wak, engine=engine)
  # 2 and 5 are the accepted lengths; anything else gets a second attempt
  # with the dictionary-based engine.
  if len(subword_tokenized) not in (2, 5):
    subword_tokenized = subword_tokenize(wak, engine='dict')
  return subword_tokenized


# check number of syllables -> [True, True, True, True, True, True, True, True] (len=8)
def subword_num(splitted_klong):
  """Check the subword count of every wak against the expected length.

  Positions 0, 2, 4, 6 must tokenize to 5 subwords, positions 1, 3, 5 to 2
  and position 7 to 4. Positions beyond 7 contribute nothing to the result.

  Returns:
    A list of booleans, one per checked wak, in order.
  """
  expected_len = {0: 5, 1: 2, 2: 5, 3: 2, 4: 5, 5: 2, 6: 5, 7: 4}
  return [
    len(subword_token(wak)) == expected_len[position]
    for position, wak in enumerate(splitted_klong)
    if position in expected_len
  ]

# check what word tone is
def find_tone(word):
  """Classify a word by tone mark or syllable sound.

  Returns:
    "eak or dead" when the word contains the first tone mark or
    sound_syllable() reports 'dead'; "tou" when it contains the second tone
    mark; otherwise False.
  """
  letters = list(word)
  # The tone-mark test comes first so sound_syllable() is skipped when it hits.
  if "่" in letters:
    return "eak or dead"
  if sound_syllable(word) == 'dead':
    return "eak or dead"
  if "้" in letters:
    return "tou"
  return False

# check eaktou -> list[True, True, True, True, True, True, True, True] (len=8)
def check_eaktou(splitted_klong):
  """Check the required tone of specific subwords in each wak.

  For every wak position 0-7 a fixed set of (subword index, expected tone)
  pairs is tested with find_tone(); position 1 has no requirement and always
  yields True. Positions beyond 7 are tokenized but add nothing to the
  result.

  Returns:
    A list of booleans, one per checked wak, in order.
  """
  eak = "eak or dead"
  tou = 'tou'
  requirements = {
    0: ((3, eak), (4, tou)),
    1: (),
    2: ((1, eak),),
    3: ((0, eak), (1, tou)),
    4: ((2, eak),),
    5: ((1, eak),),
    6: ((1, eak), (4, tou)),
    7: ((0, eak), (1, tou)),
  }
  results = []
  for position, wak in enumerate(splitted_klong):
    subwords = subword_token(wak)
    if position not in requirements:
      continue
    # all() stops at the first failing pair, so later subword indices are
    # only touched once the earlier requirement holds.
    results.append(
      all(find_tone(subwords[index]) == tone for index, tone in requirements[position])
    )
  return results

# last sound of wak from pronunciate tokenized last word of each wak
# ex [เสียงลือเสียงเล่าอ้าง] -> [อ้าง]
def sound_words(splitted_klong):
  """Return the final pronounced syllable of each wak.

  Each wak is cut at its first space, word-tokenized with the "newmm"
  engine, and the last word is passed to pronunciate() with the "w2p"
  engine. The returned text has the 'ฺ' mark removed and only the part after
  the last '-' is kept.
  """
  endings = []
  for wak in splitted_klong:
    head = wak.split(" ")[0]
    words = word_tokenize(head, engine="newmm")
    spoken = pronunciate(words[-1], engine="w2p")
    endings.append(spoken.replace('ฺ', '').split('-')[-1])
  return endings

# check sampas -> [True, True, True]
# [0] = sampas wak 2-3, [1] = sampas wak 2-5, [2] = sampas wak 4-7
def check_sampas(sound_list):
  """Check rhyme (sampas) between the ending sounds of selected wak.

  With two or fewer sounds there is nothing to compare and [True] is
  returned. Otherwise the result holds kv.is_sumpus() for indices (1, 2),
  then (1, 4) when more than four sounds exist, then (3, 6) when more than
  six exist.
  """
  total = len(sound_list)
  if total <= 2:
    return [True]
  results = [kv.is_sumpus(sound_list[1], sound_list[2])]
  if total > 4:
    results.append(kv.is_sumpus(sound_list[1], sound_list[4]))
  if total > 6:
    results.append(kv.is_sumpus(sound_list[3], sound_list[6]))
  return results

def main_check(klong_text):
  """Validate klong text: syllable counts, then tones, then rhymes.

  Returns:
    True when every check passes. Otherwise a tuple of an error label and
    the location of the first failure: a 1-based wak number for syllable and
    eaktou errors, or a string naming the wak pair for sampas errors.
  """
  waks = split_klong(klong_text)

  syllable_flags = subword_num(waks)
  if False in syllable_flags:
    return 'syllable format error', syllable_flags.index(False) + 1

  tone_flags = check_eaktou(waks)
  if False in tone_flags:
    return 'eaktou format error', tone_flags.index(False) + 1

  rhyme_flags = check_sampas(sound_words(waks))
  if False in rhyme_flags:
    pair_labels = ['2 and 3', '2 and 5', '4 and 7']
    return 'sampas format error', pair_labels[rhyme_flags.index(False)]

  return True
    
    
def gen_prob_next_token(text:str, model, tokenizer):
  """Rank the tokenizer vocabulary by next-token probability after `text`.

  Args:
    text: prompt whose continuation is scored.
    model: object called with the input ids; its result must expose `.logits`.
    tokenizer: callable returning a mapping with 'input_ids' and exposing
      a `.vocab` mapping of token -> token id.

  Returns:
    A DataFrame sorted by descending probability with the columns 'index',
    'token', 'token_id' and 'prob', restricted to rows whose token passes
    isthai().
  """
  import torch
  import torch.nn.functional as F

  encoded = tokenizer(text, return_tensors="pt")

  # Inference only: skip autograd bookkeeping for the forward pass.
  with torch.no_grad():
    logits = model(encoded['input_ids']).logits

  # The logits at the last input position score the next token.
  probs = F.softmax(logits[:, -1, :], dim=-1).squeeze()

  # Order the vocabulary by token id so that row i lines up with probs[i].
  # NOTE(review): assumes the vocabulary size equals the logits width -
  # the assignment below raises otherwise.
  df = pd.DataFrame(tokenizer.vocab.items(), columns=['token', 'token_id']).sort_values('token_id').reset_index(drop=True)
  df['prob'] = probs.detach().numpy()

  possible_token = df.sort_values('prob', ascending=False).reset_index()
  # Blank out tokens that fail isthai() so dropna() removes their rows.
  possible_token['token'] = [x if isthai(x) else None for x in possible_token['token']]
  return possible_token.dropna()


# filter out broken words and keep only the first few candidates that pass (the cap depends on fast_gen)
def gen_rules(probs, fast_gen=True):
  """Filter candidate tokens and return the spell-corrected survivors.

  A token is kept when check_word() yields more than one character, it
  tokenizes to exactly one subword and its pronunciation contains no '-'.

  Args:
    probs: iterable of candidate tokens, best first.
    fast_gen: when True the limiter is 5, otherwise 100000000.

  Returns:
    The corrected tokens in input order. The historical `<= limiter` test is
    preserved, so up to limiter + 1 tokens are returned.
  """
  passed = []
  limiter = 5 if fast_gen else 100000000
  for prob in probs:
    # Stop once the list is full. Previously the loop kept tokenizing and
    # pronouncing every remaining candidate without ever appending.
    if len(passed) > limiter:
      break
    if len(check_word(prob)) > 1 and len(subword_token(prob)) == 1 and '-' not in pronunciate(prob):
      passed.append(correct(prob))
  return passed


def check_word(word):
  """Return the characters of `word` that count towards its length.

  When the word contains '์', the result is every character except the last
  two. Otherwise it is every character that is not one of the five marks
  listed below.
  """
  characters = list(word)
  if '์' in characters:
    return characters[:-2]
  ignored_marks = ('่', '้', '๊', '๋', '์')
  return [char for char in characters if char not in ignored_marks]

def generator(klong):
  """Return filtered next-word candidates for the given klong text.

  Scores the vocabulary with the module-level model and tokenizer, then
  passes the ranked token list through gen_rules().
  """
  ranked_tokens = gen_prob_next_token(klong, model, tokenizer)['token'].tolist()
  return gen_rules(ranked_tokens)

# get word with sampas
def get_sampassed(data:list, sampaswith):
  """Return the words in `data` that rhyme (sampas) with `sampaswith`.

  Both sides are reduced to the last '-'-separated part of their
  pronunciation before being compared with kv.is_sumpus().

  Args:
    data: iterable of candidate words.
    sampaswith: word or sound the candidates must rhyme with.

  Returns:
    The rhyming candidates, in input order (possibly empty).

  Raises:
    RuntimeError: if there was at least one candidate and every comparison
      raised IndexError, i.e. nothing could actually be checked.
  """
  passed = []
  skipped = 0
  total = 0
  # Reduce the target once. It used to be re-pronounced and reassigned on
  # every iteration, feeding its own output back into pronunciate().
  target_sound = pronunciate(sampaswith).split('-')[-1]
  for possible_word in tqdm(data):
    total += 1
    possible_sampas = pronunciate(possible_word).split('-')[-1] # reduce word dimension
    try:
      rhymes = kv.is_sumpus(possible_sampas, target_sound)
    except IndexError:
      skipped += 1
      continue
    if rhymes:
      passed.append(possible_word)
  # Explicit check instead of an assert: the old assert compared the number
  # of matches with the number of skips, so it also fired on 0 == 0 and
  # vanished under -O.
  if total and skipped == total:
    raise RuntimeError('sampas check was skipped for every candidate word')
  return passed

# get word with aek or too
def get_aek_too(data:list, ktype='aek'):
  """Keep the words for which kv.check_aek_too() returns `ktype`.

  Args:
    data: iterable of candidate words (wrapped in tqdm for progress output).
    ktype: value the verifier result is compared against; defaults to 'aek'.

  Returns:
    The matching words, in input order.
  """
  return [word for word in tqdm(data) if kv.check_aek_too(word) == ktype]

def tone_gen(klong_text, gened_word, word_mark='no', sampas=False):
  """Pick the next word for `klong_text` that satisfies the given constraints.

  Args:
    klong_text: klong text generated so far.
    gened_word: list of words already used; the chosen word is appended to it.
    word_mark: 'no' for no tone constraint, 'aek' or 'too' to require that
      kv.check_aek_too() result.
    sampas: when true, the word must also rhyme with an earlier wak
      (index 1 for word_mark 'no', index 3 for word_mark 'too'); candidates
      are then taken from the top 500 raw tokens instead of generator().

  Returns:
    A tuple (word, gened_word).

  Raises:
    ValueError: for an unsupported word_mark / sampas combination.
    RuntimeError: when no unused candidate satisfies the constraints. These
      cases used to fall through and return None, which made callers fail
      with an unpacking TypeError.
  """
  require_too = False
  if not sampas:
    if word_mark not in ('no', 'aek', 'too'):
      raise ValueError(f'unsupported word_mark {word_mark!r}')
    candidates = generator(klong_text)
    if word_mark != 'no':
      candidates = get_aek_too(candidates, word_mark)
  else:
    # Index of the wak whose ending sound the new word has to rhyme with.
    rhyme_wak_index = {'no': 1, 'too': 3}
    if word_mark not in rhyme_wak_index:
      raise ValueError(f'unsupported word_mark {word_mark!r} with sampas')
    splitted_klong = split_klong(klong_text)
    top_tokens = gen_prob_next_token(klong_text, model, tokenizer)['token'][:500]
    rhyme_sound = sound_words(splitted_klong)[rhyme_wak_index[word_mark]]
    candidates = get_sampassed(top_tokens, rhyme_sound)
    require_too = word_mark == 'too'

  for candidate in candidates:
    if candidate in gened_word:
      continue
    if require_too and kv.check_aek_too(candidate) != 'too':
      continue
    gened_word.append(candidate)
    return candidate, gened_word

  raise RuntimeError(
    f'no unused candidate word for word_mark={word_mark!r}, sampas={sampas!r}'
  )
     
def gen_klong(klong_text_input, gened_word):
  """Generate the next wak and append it to the klong text.

  The number of wak already present (as returned by split_klong) selects a
  fixed sequence of tone_gen() constraints, one per generated word. The new
  wak is terminated with '-', or with a newline when it is the eighth wak.
  If the count is not between 1 and 7 the text is returned unchanged.

  Args:
    klong_text_input: klong text generated so far.
    gened_word: list of words already used; tone_gen() extends it.

  Returns:
    A tuple (klong_text, gened_word).
  """
  plain = ('no', False)
  aek = ('aek', False)
  too = ('too', False)
  # (word_mark, sampas) for each word to generate, keyed by the number of
  # wak already present in the text.
  patterns = {
    1: (plain, plain),
    2: (plain, aek, plain, plain, ('no', True)),
    3: (aek, too),
    4: (plain, plain, aek, plain, ('no', True)),
    5: (plain, aek),
    6: (plain, aek, plain, plain, ('too', True)),
    7: (aek, too, plain, plain),
  }

  wak_count = len(split_klong(klong_text_input))
  klong_text = klong_text_input
  if wak_count not in patterns:
    return klong_text, gened_word

  for word_mark, sampas in patterns[wak_count]:
    prob, gened_word = tone_gen(klong_text, gened_word, word_mark=word_mark, sampas=sampas)
    klong_text = klong_text + prob

  # The eighth wak ends the stanza; every other wak ends with a separator.
  klong_text = klong_text + ('\n' if wak_count == 7 else '-')
  return klong_text, gened_word

# main
def main(klong_text):
  """Validate the klong text and, if valid, extend it by one generated wak.

  Returns:
    The extended klong text when main_check() returns True; otherwise the
    error value returned by main_check().
  """
  # Run the validation once and reuse its result on the failure path.
  check_result = main_check(klong_text)
  if check_result is not True:
    return check_result
  klong_text, _ = gen_klong(klong_text, [])
  return klong_text

import gradio as gr

# Wrap main() in a Gradio interface with one text input and one text output.
iface = gr.Interface(fn=main, inputs="text", outputs="text")
# Launched at module load.
# NOTE(review): share=True presumably requests a public share link and
# debug=True keeps the process attached for logs - confirm against Gradio docs.
iface.launch(share=True, debug=True)