"""Checker and next-wak generator for Thai Khlong Si Suphap verse.

Input text is a sequence of wak separated by '-'. The existing waks are
validated for syllable count, eak/tou tone positions and sampas (rhyme).
If they are valid, the KarveeSaimai causal language model is used to append
the next wak. The app is served through a Gradio text interface.
"""
from typing import List, Tuple, Union

import gradio as gr
import numpy as np  # noqa: F401  (kept: part of the module's original imports)
import pandas as pd
import pythainlp as pythai  # noqa: F401  (kept: part of the module's original imports)
import torch
import torch.nn.functional as F
from pythainlp.khavee import KhaveeVerifier
from pythainlp.spell import correct
from pythainlp.tokenize import subword_tokenize, word_tokenize
from pythainlp.transliterate import pronunciate
from pythainlp.util import isthai, remove_tonemark, sound_syllable  # noqa: F401
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

kv = KhaveeVerifier()

tokenizer = AutoTokenizer.from_pretrained("Thanravee/KarveeSaimai", local_files_only=False)
model = AutoModelForCausalLM.from_pretrained("Thanravee/KarveeSaimai", local_files_only=False)

# Expected syllable count of each wak (index 0 is wak 1).
_SYLLABLES_PER_WAK = (5, 2, 5, 2, 5, 2, 5, 4)

# Wak indices (wak 2, 4, 6) where only the first whitespace-separated chunk is
# kept; anything after it is treated as a soi word and dropped.
_SOI_WAKS = (1, 3, 5)

_EAK = "eak or dead"
_TOU = "tou"

# Per wak: (syllable position, required tone) pairs checked by check_eaktou.
_EAKTOU_RULES = (
    ((3, _EAK), (4, _TOU)),  # wak 1
    (),                      # wak 2: no tone requirement
    ((1, _EAK),),            # wak 3
    ((0, _EAK), (1, _TOU)),  # wak 4
    ((2, _EAK),),            # wak 5
    ((1, _EAK),),            # wak 6
    ((1, _EAK), (4, _TOU)),  # wak 7
    ((0, _EAK), (1, _TOU)),  # wak 8
)

# Pairs of wak indices whose final sounds must rhyme, with user-facing labels.
_SAMPAS_PAIRS = ((1, 2), (1, 4), (3, 6))
_SAMPAS_LABELS = ['2 and 3', '2 and 5', '4 and 7']

# Number of top-probability tokens searched for a rhyming word.
_SAMPAS_CANDIDATES = 500

_TONE_MARKS = ['่', '้', '๊', '๋', '์']

# Generation plan keyed by the number of waks already written:
# (sequence of (word_mark, sampas) steps, text appended after the new wak).
_GEN_PLAN = {
    # wak 2: (none, none)
    1: ((('no', False), ('no', False)), '-'),
    # wak 3: (none, aek, none, none, none + rhyme with wak 2)
    2: ((('no', False), ('aek', False), ('no', False), ('no', False), ('no', True)), '-'),
    # wak 4: (aek, too)
    3: ((('aek', False), ('too', False)), '-'),
    # wak 5: (none, none, aek, none, none + rhyme with wak 2)
    4: ((('no', False), ('no', False), ('aek', False), ('no', False), ('no', True)), '-'),
    # wak 6: (none, aek)
    5: ((('no', False), ('aek', False)), '-'),
    # wak 7: (none, aek, none, none, too + rhyme with wak 4)
    6: ((('no', False), ('aek', False), ('no', False), ('no', False), ('too', True)), '-'),
    # wak 8: (aek, too, none, none), then end the poem with a newline
    7: ((('aek', False), ('too', False), ('no', False), ('no', False)), '\n'),
}


def split_klong(klong_text: str) -> List[str]:
    """Split raw klong text on '-' into a list of waks, dropping soi words.

    Empty or whitespace-only segments are ignored. For waks 2, 4 and 6 only
    the first whitespace-separated chunk is kept (the rest is soi). Every
    other wak has all whitespace removed.
    """
    waks = [wak for wak in klong_text.split('-') if wak.strip()]
    splitted_klong = []
    for index, wak in enumerate(waks):
        if index in _SOI_WAKS:
            # split() rather than split(' '): several leading spaces or a
            # newline must not turn the kept chunk into an empty string.
            splitted_klong.append(wak.split()[0])
        else:
            splitted_klong.append(''.join(wak.split()))
    return splitted_klong


def subword_token(wak: str, engine: str = 'ssg', expected_lengths=(2, 5)) -> List[str]:
    """Syllable-tokenize *wak*, falling back to the 'dict' engine on a mismatch.

    The wak is first segmented with *engine*. If the number of syllables is
    not in *expected_lengths*, the 'dict' engine's segmentation is returned
    instead.
    """
    tokens = subword_tokenize(wak, engine=engine)
    if len(tokens) not in expected_lengths:
        tokens = subword_tokenize(wak, engine='dict')
    return tokens


def _tokenize_wak(index: int, wak: str) -> List[str]:
    """Syllable-tokenize the wak at position *index* using its expected length."""
    return subword_token(wak, expected_lengths=(_SYLLABLES_PER_WAK[index],))


def subword_num(splitted_klong: List[str]) -> List[bool]:
    """Check the syllable count of each wak.

    Returns one bool per wak, for at most 8 waks; waks past the 8th are
    ignored. Example result: [True] * 8.
    """
    return [
        len(_tokenize_wak(index, wak)) == expected
        for index, (wak, expected) in enumerate(zip(splitted_klong, _SYLLABLES_PER_WAK))
    ]


def find_tone(word: str) -> Union[str, bool]:
    """Classify a syllable.

    Returns "eak or dead" (it has mai ek or is a dead syllable), "tou" (it
    has mai tho), or False otherwise.
    """
    if "่" in word or sound_syllable(word) == 'dead':
        return _EAK
    if "้" in word:
        return _TOU
    return False


def check_eaktou(splitted_klong: List[str]) -> List[bool]:
    """Check the eak/tou positions of each wak (at most 8).

    Assumes subword_num already passed, so every tokenized wak has its
    expected length and the positions in _EAKTOU_RULES are in range.
    """
    checked = []
    for index, (wak, rules) in enumerate(zip(splitted_klong, _EAKTOU_RULES)):
        if not rules:
            checked.append(True)
            continue
        syllables = _tokenize_wak(index, wak)
        checked.append(all(find_tone(syllables[pos]) == tone for pos, tone in rules))
    return checked


def sound_words(splitted_klong: List[str]) -> List[str]:
    """Return the last pronounced syllable of the last word of each wak.

    Example: [เสียงลือเสียงเล่าอ้าง] -> [อ้าง].
    """
    sound_list = []
    for wak in splitted_klong:
        chunks = wak.split()
        head = chunks[0] if chunks else wak
        words = word_tokenize(head, engine="newmm")
        pronunciation = pronunciate(words[-1], engine="w2p")
        sound_list.append(pronunciation.replace('ฺ', '').split('-')[-1])
    return sound_list


def check_sampas(sound_list: List[str]) -> List[bool]:
    """Check the rhymes between wak 2-3, 2-5 and 4-7.

    Only pairs whose later wak exists are checked. When wak 7 is missing, a
    trailing True is appended (kept for compatibility with the original
    return shape).
    """
    checked = [
        kv.is_sumpus(sound_list[first], sound_list[second])
        for first, second in _SAMPAS_PAIRS
        if len(sound_list) > second
    ]
    if len(sound_list) <= 6:
        checked.append(True)
    return checked


def main_check(klong_text: str) -> Union[bool, Tuple[str, Union[int, str]]]:
    """Validate klong text.

    Returns True when valid. Otherwise returns (error_kind, location), where
    location is the 1-based wak number, or the wak pair label for rhyme
    errors.
    """
    splitted_klong = split_klong(klong_text)

    syllables_ok = subword_num(splitted_klong)
    if False in syllables_ok:
        return 'syllable format error', syllables_ok.index(False) + 1

    eaktou_ok = check_eaktou(splitted_klong)
    if False in eaktou_ok:
        return 'eaktou format error', eaktou_ok.index(False) + 1

    sampas_ok = check_sampas(sound_words(splitted_klong))
    if False in sampas_ok:
        return 'sampas format error', _SAMPAS_LABELS[sampas_ok.index(False)]
    return True


def gen_prob_next_token(text: str, model, tokenizer) -> pd.DataFrame:
    """Rank the Thai-only vocabulary tokens by next-token probability.

    Returns a DataFrame with columns 'index', 'token', 'token_id' and 'prob',
    sorted by descending probability. Only rows whose token passes isthai
    are kept.
    """
    encoded = tokenizer(text, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph on every call.
    with torch.no_grad():
        logits = model(encoded['input_ids']).logits
    # The last input position holds the scores for the next token.
    probs = F.softmax(logits[:, -1, :], dim=-1).squeeze().cpu().numpy()

    vocab = pd.DataFrame(list(tokenizer.vocab.items()), columns=['token', 'token_id'])
    vocab = vocab.sort_values('token_id').reset_index(drop=True)
    # Look probabilities up by token id, not row position, so the mapping
    # stays right if ids are non-contiguous or the output layer is padded.
    vocab['prob'] = probs[vocab['token_id'].to_numpy()]

    possible_token = vocab.sort_values('prob', ascending=False).reset_index()
    return possible_token[possible_token['token'].map(isthai)]


def gen_rules(probs: List[str], fast_gen: bool = True) -> List[str]:
    """Filter candidate tokens down to usable single-syllable words.

    A token is kept if it has more than one non-tone-mark character, is a
    single syllable, and its pronunciation has no '-'. Kept tokens are passed
    through spell correction. At most limiter + 1 words are returned: 6 with
    fast_gen, otherwise effectively unlimited.
    """
    limiter = 5 if fast_gen else 100000000
    passed = []
    for prob in probs:
        # Stop as soon as the quota is full instead of running the expensive
        # tokenizer/pronunciation checks over the rest of the vocabulary.
        if len(passed) > limiter:
            break
        if (len(check_word(prob)) > 1
                and len(subword_token(prob, expected_lengths=(1,))) == 1
                and '-' not in pronunciate(prob)):
            passed.append(correct(prob))
    return passed


def check_word(word: str) -> List[str]:
    """Return the characters of *word* without tone marks or thanthakhat.

    When the word contains thanthakhat ('์'), the last two characters are
    dropped first. This assumes the mark sits at the end, silencing the
    consonant before it.
    """
    chars = [*word]
    if '์' in chars:
        chars = chars[:-2]
    return [char for char in chars if char not in _TONE_MARKS]


def generator(klong: str) -> List[str]:
    """Return filtered next-word candidates for *klong* from the language model."""
    prob = gen_prob_next_token(klong, model, tokenizer)
    return gen_rules(prob['token'].tolist())


def get_sampassed(data, sampaswith: str) -> List[str]:
    """Return the words in *data* that rhyme with *sampaswith*.

    Candidates for which kv.is_sumpus raises IndexError are skipped.

    Raises:
        RuntimeError: if *data* is non-empty and every candidate was skipped.
    """
    # Computed once: re-pronunciating inside the loop would feed an already
    # pronounced form back into pronunciate and let the target drift.
    target = pronunciate(sampaswith).split('-')[-1]
    passed = []
    total = 0
    skipped = 0
    for possible_word in tqdm(data):
        total += 1
        candidate = pronunciate(possible_word).split('-')[-1]
        try:
            rhymes = kv.is_sumpus(candidate, target)
        except IndexError:
            skipped += 1
            continue
        if rhymes:
            passed.append(possible_word)
    if total and skipped == total:
        raise RuntimeError("get_sampassed skipped every candidate word; rhyme check unusable")
    return passed


def get_aek_too(data, ktype: str = 'aek') -> List[str]:
    """Return the words in *data* whose kv.check_aek_too result equals *ktype*."""
    return [word for word in tqdm(data) if kv.check_aek_too(word) == ktype]


def tone_gen(klong_text, gened_word, word_mark='no', sampas=False):
    """Choose the next word to append to *klong_text*.

    Args:
        klong_text: the poem so far.
        gened_word: words already used; the chosen word is appended to it in place.
        word_mark: 'no' (any tone), 'aek' or 'too'.
        sampas: if true, the word must rhyme. With word_mark 'no' it rhymes
            with wak 2. With word_mark 'too' it rhymes with wak 4 and must
            also be 'too' according to kv.check_aek_too.

    Returns:
        (word, gened_word)

    Raises:
        ValueError: for an unsupported word_mark/sampas combination.
        RuntimeError: if no unused candidate word is available.
    """
    required_mark = None
    if not sampas:
        if word_mark not in ('no', 'aek', 'too'):
            raise ValueError(f"unsupported word_mark {word_mark!r}")
        candidates = generator(klong_text)
        if word_mark != 'no':
            candidates = get_aek_too(candidates, word_mark)
    else:
        rhyme_wak = {'no': 1, 'too': 3}.get(word_mark)
        if rhyme_wak is None:
            raise ValueError(f"unsupported word_mark {word_mark!r} with sampas=True")
        if word_mark == 'too':
            required_mark = 'too'
        probs = gen_prob_next_token(klong_text, model, tokenizer)['token'].iloc[:_SAMPAS_CANDIDATES]
        target = sound_words([split_klong(klong_text)[rhyme_wak]])[0]
        candidates = get_sampassed(probs, target)

    for word in candidates:
        if word in gened_word:
            continue
        if required_mark is not None and kv.check_aek_too(word) != required_mark:
            continue
        gened_word.append(word)
        return word, gened_word
    raise RuntimeError(
        f"no unused candidate word found (word_mark={word_mark!r}, sampas={sampas!r})"
    )


def gen_klong(klong_text_input, gened_word):
    """Generate the next wak of the poem.

    Text is appended according to _GEN_PLAN for the number of waks already
    present. Input with 0 or at least 8 waks is returned unchanged.

    Returns:
        (new_klong_text, gened_word)
    """
    splitted_klong = split_klong(klong_text_input)
    plan = _GEN_PLAN.get(len(splitted_klong))
    if plan is None:
        return klong_text_input, gened_word

    steps, terminator = plan
    klong_text = klong_text_input
    # Without a trailing separator the new words would be glued onto the
    # previous wak.
    if not klong_text.rstrip().endswith('-'):
        klong_text += '-'
    for word_mark, sampas in steps:
        word, gened_word = tone_gen(klong_text, gened_word, word_mark=word_mark, sampas=sampas)
        klong_text += word
    return klong_text + terminator, gened_word


def main(klong_text):
    """Gradio entry point.

    Returns the text extended by one wak when the input is valid. Otherwise
    returns the error tuple from main_check.
    """
    check_result = main_check(klong_text)
    if check_result is not True:
        return check_result
    generated_text, _ = gen_klong(klong_text, [])
    return generated_text


iface = gr.Interface(fn=main, inputs="text", outputs="text")
iface.launch(debug=True)