# KLONG4 / app.py — Thanravee, "Update app.py" (commit 00fa1f8)
from typing import List, Union
from pythainlp.tokenize import subword_tokenize,word_tokenize
from pythainlp.util import sound_syllable
from pythainlp.util import remove_tonemark
from pythainlp.khavee import KhaveeVerifier
import pythainlp as pythai
from pythainlp.tokenize import word_tokenize
from pythainlp.tokenize import subword_tokenize
from pythainlp.util import sound_syllable
from pythainlp.util import isthai
from pythainlp.transliterate import pronunciate
from pythainlp.spell import correct
from tqdm import tqdm
import numpy as np
import pandas as pd
kv = KhaveeVerifier()  # shared PyThaiNLP verifier used for rhyme (is_sumpus) and tone (check_aek_too) checks
from transformers import AutoTokenizer, AutoModelForCausalLM
# Causal LM checkpoint "Thanravee/KarveeSaimai" from the Hugging Face Hub —
# presumably fine-tuned for Thai klong text; downloaded on first run.
tokenizer = AutoTokenizer.from_pretrained("Thanravee/KarveeSaimai", local_files_only=False)
model = AutoModelForCausalLM.from_pretrained("Thanravee/KarveeSaimai", local_files_only=False)
# split text from \n to list and drop soi word -> splitted wak list (no soi)
def split_klong(klong_text):
    """Split a klong string on '-' into its wak.

    For wak at positions 2/4/6 (indices 1, 3, 5) only the first
    space-separated word is kept — the trailing soi word is dropped.
    All other wak have their spaces removed entirely.
    """
    waks = [w for w in klong_text.split('-') if w.strip()]
    cleaned = []
    for idx, wak in enumerate(waks):
        if idx in (1, 3, 5):
            # Trim at most one leading space, then keep the first word only.
            if wak.startswith(' '):
                wak = wak[1:]
            cleaned.append(wak.split(' ')[0])
        else:
            cleaned.append(wak.replace(' ', ''))
    return cleaned
# subword tokenize wak with ssg and dict
def subword_token(wak, engine='ssg'):
    """Tokenize a wak into subword (syllable) units.

    Bug fix: the `engine` parameter was previously ignored — the first call
    always hardcoded engine='ssg'. It is now honoured (default unchanged).
    Falls back to the 'dict' engine when the result is neither 5 nor 2
    units, the only syllable counts valid in a klong wak.
    """
    subwords = subword_tokenize(wak, engine=engine)
    if len(subwords) not in (5, 2):
        subwords = subword_tokenize(wak, engine='dict')
    return subwords
# check number of syllables -> [True, True, True, True, True, True, True, True] (len=8)
def subword_num(splitted_klong):
    """Check the syllable count of each wak.

    Returns one bool per wak (True when the count matches the klong
    pattern: 5 syllables for wak 1/3/5/7, 2 for wak 2/4/6, 4 for wak 8).
    Wak beyond the 8th produce no entry, as before.
    """
    # Expected subword count per wak index.
    expected = {0: 5, 1: 2, 2: 5, 3: 2, 4: 5, 5: 2, 6: 5, 7: 4}
    results = []
    for idx, wak in enumerate(splitted_klong):
        target = expected.get(idx)
        if target is not None:
            results.append(len(subword_token(wak)) == target)
    return results
# check what word tone is
def find_tone(word):
    """Classify a syllable's tone marking.

    Returns "eak or dead" for a mai-ek mark or a dead syllable,
    "tou" for a mai-tho mark, and False otherwise.
    """
    # The mai-ek test runs first and short-circuits, so sound_syllable
    # is only consulted for unmarked words.
    if "่" in word or sound_syllable(word) == 'dead':
        return "eak or dead"
    if "้" in word:
        return "tou"
    return False
# check eaktou -> list[True, True, True, True, True, True, True, True] (len=8)
def check_eaktou(splitted_klong):
    """Check mai-ek / mai-tho placement for each wak.

    Returns one bool per wak; a wak passes when every constrained
    syllable position carries the required tone class from find_tone().
    """
    # (syllable index, required find_tone() result) pairs per wak index.
    tone_rules = {
        0: ((3, "eak or dead"), (4, "tou")),
        1: (),  # wak 2 has no tone constraint (always True)
        2: ((1, "eak or dead"),),
        3: ((0, "eak or dead"), (1, "tou")),
        4: ((2, "eak or dead"),),
        5: ((1, "eak or dead"),),
        6: ((1, "eak or dead"), (4, "tou")),
        7: ((0, "eak or dead"), (1, "tou")),
    }
    checked = []
    for idx in range(len(splitted_klong)):
        rules = tone_rules.get(idx)
        if rules is None:
            # Wak beyond the 8th produce no entry, as before.
            continue
        syllables = subword_token(splitted_klong[idx])
        checked.append(all(find_tone(syllables[pos]) == mark for pos, mark in rules))
    return checked
# last sound of wak from pronunciate tokenized last word of each wak
# ex [เสียงลือเสียงเล่าอ้าง] -> [อ้าง]
def sound_words(splitted_klong):
    """Return the final pronounced syllable of the last word in each wak.

    e.g. [เสียงลือเสียงเล่าอ้าง] -> last syllable of the pronunciation of อ้าง.
    """
    finals = []
    for wak in splitted_klong:
        # Keep only the part before the first space (drops any soi word).
        if " " in wak:
            wak = wak.split(" ")[0]
        words = word_tokenize(wak, engine="newmm")
        pronounced = pronunciate(words[-1], engine="w2p")
        # Strip phinthu marks, then keep the last '-'-separated syllable.
        finals.append(pronounced.replace('ฺ', '').split('-')[-1])
    return finals
# check sampas -> [True, True, True]
# [0] = sampas wak 2-3, [1] = sampas wak 2-4, [2] sampas wak 4-7
def check_sampas(sound_list):
    """Check the klong's rhyme (sampas) pairs on the per-wak final sounds.

    Result entries: [0] wak 2-3 rhyme, [1] wak 2-5 rhyme, [2] wak 4-7 rhyme.
    NOTE: the else belongs to the last if only — when fewer than 7 sounds
    exist, a True placeholder is appended for the final pair (preserved
    behavior).
    """
    results = []
    n = len(sound_list)
    if n > 2:
        results.append(kv.is_sumpus(sound_list[1], sound_list[2]))
    if n > 4:
        results.append(kv.is_sumpus(sound_list[1], sound_list[4]))
    if n > 6:
        results.append(kv.is_sumpus(sound_list[3], sound_list[6]))
    else:
        results.append(True)
    return results
def main_check(klong_text):
    """Validate a klong: syllable counts, then tone marks, then rhymes.

    Returns True when everything passes; otherwise a tuple of
    (error-kind string, 1-based position / pair label) for the first
    failing check.
    """
    waks = split_klong(klong_text)

    # 1) Syllable counts per wak.
    syllable_results = subword_num(waks)
    if False in syllable_results:
        return 'syllable format error', syllable_results.index(False) + 1

    # 2) Mai-ek / mai-tho placement.
    eaktou_results = check_eaktou(waks)
    if False in eaktou_results:
        return 'eaktou format error', eaktou_results.index(False) + 1

    # 3) Rhyme pairs on the final sounds.
    sampas_results = check_sampas(sound_words(waks))
    if False in sampas_results:
        pair_names = ['2 and 3', '2 and 5', '4 and 7']
        return 'sampas format error', pair_names[sampas_results.index(False)]

    return True
def gen_prob_next_token(text: str, model, tokenizer):
    """Compute next-token probabilities for `text` and return Thai-only candidates.

    Returns a DataFrame sorted by descending probability (columns include
    'token', 'token_id', 'prob'); non-Thai tokens are dropped.

    Cleanup: removed three no-op notebook leftovers (bare tuple
    expressions inspecting shapes) and the redundant local pandas import
    (pandas is already imported at module level).
    """
    import torch.nn.functional as F  # torch is not imported at module level

    encoded = tokenizer(text, return_tensors="pt")
    # Logits exist for every input position; only the last one predicts
    # the next token.
    logits = model(encoded['input_ids']).logits
    probs = F.softmax(logits[:, -1, :], dim=-1).squeeze()

    # Align probabilities with the vocabulary by token id.
    vocab_df = pd.DataFrame(tokenizer.vocab.items(), columns=['token', 'token_id'])
    vocab_df = vocab_df.sort_values('token_id').reset_index(drop=True)
    vocab_df['prob'] = probs.detach().numpy()

    possible_token = vocab_df.sort_values('prob', ascending=False).reset_index()
    # Keep Thai tokens only: mark non-Thai entries None, then drop them.
    possible_token['token'] = [tok if isthai(tok) else None for tok in possible_token['token']]
    possible_token = possible_token.dropna()
    return possible_token
# filter broken word and get passed only 100 words
def gen_rules(probs, fast_gen=True):
    """Filter candidate words, keeping only well-formed single-syllable ones.

    A candidate passes when it has more than one base character, tokenizes
    to exactly one subword, and its pronunciation has a single syllable
    (no '-'). Each passing word is spell-corrected before being kept.

    Performance fix: the quota test was the LAST conjunct, so the expensive
    check_word/subword_token/pronunciate calls ran for every remaining
    candidate even after the quota was full. The cheap check now runs first
    and the loop breaks — the returned list is identical (at most limit+1
    entries, preserving the original off-by-one).
    """
    passed = []
    limit = 5 if fast_gen else 100000000
    for candidate in probs:
        if len(passed) > limit:
            break
        if (len(check_word(candidate)) > 1
                and len(subword_token(candidate)) == 1
                and '-' not in pronunciate(candidate)):
            passed.append(correct(candidate))
    return passed
def check_word(word):
    """Return the base characters of `word` with tone marks removed.

    When the word carries a thanthakhat ('์'), the last two characters are
    dropped instead (the silenced letter plus its mark) and tone marks are
    NOT filtered — preserved behavior.
    """
    tone_marks = {'่', '้', '๊', '๋', '์'}
    letters = [ch for ch in word if ch not in tone_marks]
    if '์' in word:
        letters = list(word)[:-2]
    return letters
def generator(klong):
    """Return filtered candidate next-words for the given klong prefix."""
    candidates = gen_prob_next_token(klong, model, tokenizer)
    return gen_rules(candidates['token'].tolist())
# get word with sampas
def get_sampassed(data: list, sampaswith):
    """Filter `data` down to words whose final syllable rhymes with `sampaswith`.

    Bug fix: the target syllable was re-pronunciated INSIDE the loop on every
    iteration, re-processing the already-reduced syllable from the previous
    pass (and doing loop-invariant work n times). It is now reduced once,
    before the loop.
    """
    passed = []
    counter_exception = 0
    # Reduce the rhyme target to its last pronounced syllable, once.
    target = pronunciate(sampaswith).split('-')[-1]
    for possible_word in tqdm(data):
        possible_sampas = pronunciate(possible_word).split('-')[-1]
        try:
            if kv.is_sumpus(possible_sampas, target):
                passed.append(possible_word)
        except IndexError:
            counter_exception += 1
            continue
    # If this fails, every candidate was skipped by the exception handler,
    # which shouldn't happen. NOTE(review): assert is stripped under -O;
    # consider raising explicitly.
    assert len(passed) != counter_exception
    return passed
# get word with aek or too
def get_aek_too(data: list, ktype='aek'):
    """Keep only the words whose tone class (per KhaveeVerifier) equals `ktype`."""
    return [word for word in tqdm(data) if kv.check_aek_too(word) == ktype]
def tone_gen(klong_text, gened_word, word_mark='no', sampas=False):
    """Pick the next word for the klong under a tone and/or rhyme constraint.

    klong_text: the klong generated so far ('-' separates wak).
    gened_word: words already used; the chosen word is appended to it.
    word_mark: 'no', 'aek' (mai-ek) or 'too' (mai-tho) tone requirement.
    sampas: when True, the word must rhyme with an earlier wak's final sound.

    Returns (chosen_word, gened_word). Each branch returns the FIRST
    candidate not already in gened_word.
    NOTE(review): if every candidate is already used (or the combination
    word_mark/sampas matches no branch) the function falls through and
    implicitly returns None, which crashes the caller's tuple unpack —
    confirm this cannot happen in practice.
    """
    splitted_klong = split_klong(klong_text)
    if word_mark == 'no' and sampas == False:
        # Unconstrained: first unused filtered candidate.
        probs = generator(klong_text)
        for prob in probs:
            if prob not in gened_word:
                gened_word.append(prob)
                return prob, gened_word
    elif word_mark == 'aek' and sampas == False:
        # Require mai-ek tone class.
        probs = generator(klong_text)
        aek = get_aek_too(probs)
        for prob in aek:
            if prob not in gened_word:
                gened_word.append(prob)
                return prob, gened_word
    elif word_mark == 'too' and sampas == False:
        # Require mai-tho tone class.
        probs = generator(klong_text)
        too = get_aek_too(probs, 'too')
        for prob in too:
            if prob not in gened_word:
                gened_word.append(prob)
                return prob, gened_word
    elif sampas == True and word_mark == 'no':
        # Rhyme with the final sound of wak 2 (sound_words index 1);
        # only the top-500 raw candidates are considered.
        probs = gen_prob_next_token(klong_text, model, tokenizer)
        probs = probs['token'][:500]
        passed = get_sampassed(probs, sound_words(splitted_klong)[1])
        for prob in passed:
            if prob not in gened_word:
                gened_word.append(prob)
                return prob, gened_word
    elif sampas == True and word_mark == 'too':
        # Rhyme with the final sound of wak 4 (index 3) AND carry mai-tho.
        probs = gen_prob_next_token(klong_text, model, tokenizer)
        probs = probs['token'][:500]
        passed = get_sampassed(probs, sound_words(splitted_klong)[3])
        for prob in passed:
            if prob not in gened_word and kv.check_aek_too(prob) == 'too':
                gened_word.append(prob)
                return prob, gened_word
def gen_klong(klong_text_input, gened_word):
    """Generate and append the next wak to the klong.

    The tone pattern for the new wak is selected from how many wak
    already exist (len of split_klong). Each generated word is appended
    directly to klong_text; a '-' ends a wak, '\\n' ends the stanza.

    Returns (extended klong text, updated gened_word list).
    """
    splitted_klong = split_klong(klong_text_input)
    klong_text = klong_text_input
    # wak 2, 4, 6 — two generated words each
    if len(splitted_klong) in [1, 3, 5]:
        word_gen = 2  # NOTE(review): assigned but never read
        if len(splitted_klong) == 1:
            # tone pattern: (none, none(sampas))
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            klong_text = klong_text + '-'
        elif len(splitted_klong) == 3:
            # tone pattern: (aek, too(sampas))
            prob, gened_word = tone_gen(klong_text, gened_word, word_mark='aek')
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word, 'too')
            klong_text = klong_text + prob
            klong_text = klong_text + '-'
        elif len(splitted_klong) == 5:
            # tone pattern: (none, aek)
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word, word_mark='aek')
            klong_text = klong_text + prob
            klong_text = klong_text + '-'
    # wak 3, 5, 7 — five generated words each
    elif len(splitted_klong) in [2, 4, 6]:
        word_gen = 5  # NOTE(review): assigned but never read
        if len(splitted_klong) == 2:
            # tone pattern: (none, aek, none, none, none(sampas))
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word, word_mark='aek')
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            # NOTE(review): sampas_word is unused — tone_gen recomputes the
            # rhyme target from its own split of klong_text.
            sampas_word = sound_words(splitted_klong)[1]
            prob, gened_word = tone_gen(klong_text, gened_word, word_mark='no', sampas=True)
            klong_text = klong_text + prob
            klong_text = klong_text + '-'
        elif len(splitted_klong) == 4:
            # tone pattern: (none, none, aek, none, none(sampas))
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word, word_mark='aek')
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            # NOTE(review): unused, see above.
            sampas_word = sound_words(splitted_klong)[1]
            prob, gened_word = tone_gen(klong_text, gened_word, word_mark='no', sampas=True)
            klong_text = klong_text + prob
            klong_text = klong_text + '-'
        elif len(splitted_klong) == 6:
            # tone pattern: (none, aek, none, none, too(sampas))
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word, word_mark='aek')
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            prob, gened_word = tone_gen(klong_text, gened_word)
            klong_text = klong_text + prob
            # NOTE(review): unused, see above.
            sampas_word = sound_words(splitted_klong)[1]
            prob, gened_word = tone_gen(klong_text, gened_word, word_mark='too', sampas=True)
            klong_text = klong_text + prob
            klong_text = klong_text + '-'
    # wak 8 — four generated words, ends the stanza
    elif len(splitted_klong) == 7:
        # tone pattern: (eak, too, none, none)
        word_gen = 4  # NOTE(review): assigned but never read
        prob, gened_word = tone_gen(klong_text, gened_word, word_mark='aek')
        klong_text = klong_text + prob
        prob, gened_word = tone_gen(klong_text, gened_word, 'too')
        klong_text = klong_text + prob
        prob, gened_word = tone_gen(klong_text, gened_word)
        klong_text = klong_text + prob
        prob, gened_word = tone_gen(klong_text, gened_word)
        klong_text = klong_text + prob
        klong_text = klong_text + '\n'
    return klong_text, gened_word
# main
def main(klong_text):
    """Gradio entry point: validate the seed klong, then extend it.

    Returns the extended klong text on success, otherwise the
    (error-kind, position) tuple produced by main_check.

    Fix: main_check was called twice (once for the test, once for the
    error return), running the whole validation pipeline redundantly;
    it is now called once and the result reused. Also removed the unused
    `splitted`/`wak_num` locals.
    """
    gened_klong = []
    check_result = main_check(klong_text)  # True or an error tuple
    if check_result == True:
        klong_text, gened_klong = gen_klong(klong_text, gened_klong)
        return klong_text
    return check_result
import gradio as gr
# Text-in/text-out web UI: the user supplies the opening of a klong and
# receives either the extended text from `main` or a validation error tuple.
iface = gr.Interface(fn=main, inputs="text", outputs="text")
iface.launch(debug=True)