# To run: funix main.py
import os
import typing

from funix import funix
from funix.hint import HTML
from transformers import AutoModelForCausalLM, AutoTokenizer

low_memory = True  # set to True to reduce generation memory usage (e.g. on mobile devices)

hf_token = os.environ.get("HF_TOKEN")  # access token for the TURX model repositories

ku_gpt_tokenizer = AutoTokenizer.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
chj_gpt_tokenizer = AutoTokenizer.from_pretrained("TURX/chj-gpt2", token=hf_token)
wakagpt_tokenizer = AutoTokenizer.from_pretrained("TURX/wakagpt", token=hf_token)
ku_gpt_model = AutoModelForCausalLM.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
chj_gpt_model = AutoModelForCausalLM.from_pretrained("TURX/chj-gpt2", token=hf_token)
wakagpt_model = AutoModelForCausalLM.from_pretrained("TURX/wakagpt", token=hf_token)
print("Models loaded successfully.")
model_name_map = {
    "Kyoto University GPT-2 (Modern)": "ku-gpt2",
    "CHJ GPT-2 (Classical)": "chj-gpt2",
    "Waka GPT": "wakagpt",
}
waka_type_map = {
    "kana": "[ไปฎๅ]",
    "original": "[ๅŽŸๆ–‡]",
    "aligned": "[ๆ•ดๅฝข]",
}
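
# WakaGPT marks the poem representation with a special token: [ไปฎๅ] kana only
# with separators, [ๅŽŸๆ–‡] kana + kanji without separators, [ๆ•ดๅฝข] kana + kanji
# with separators. A full prompt, as assembled in waka() below, looks like:
#   [่ฉžๆ›ธ] ็ง‹็ซ‹ใคๆ—ฅใ‚ˆใ‚ใ‚‹
#   [ไฝœ่€…] ่—คๅŽŸๆ•่กŒๆœ่‡ฃ
#   [ไปฎๅ] ใ‚ใใใฌใจโˆ’ใ‚ใซใฏใ•ใ‚„ใ‹ใซโˆ’...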

@funix(
    title=" Home",
    description="""
<h1>Japanese Language Models</h1><hr>
Final Project, STAT 453 Spring 2024, University of Wisconsin-Madison<br>
Author: Ruixuan Tu (ruixuan@cs.wisc.edu, https://turx.tokyo)<hr>
Navigate the apps using the left sidebar.
"""
)
def home():
    return

@funix(disable=True)
def __generate(tokenizer: AutoTokenizer, model: AutoModelForCausalLM, prompt: str,
               do_sample: bool, num_beams: int, num_beam_groups: int, max_new_tokens: int,
               temperature: float, top_k: int, top_p: float, repetition_penalty: float,
               num_return_sequences: int) -> typing.List[str]:
    global low_memory
    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(inputs, low_memory=low_memory, do_sample=do_sample,
                             num_beams=num_beams, num_beam_groups=num_beam_groups,
                             max_new_tokens=max_new_tokens, temperature=temperature,
                             top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty,
                             num_return_sequences=num_return_sequences)
    # batch_decode returns one decoded string per generated sequence
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
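
# Example (a sketch, with an arbitrary prompt): sample three continuations
# from the modern-Japanese model.
#   __generate(ku_gpt_tokenizer, ku_gpt_model, "ไปŠๆ—ฅใฏใ„ใ„ๅคฉๆฐ—", do_sample=True,
#              num_beams=1, num_beam_groups=1, max_new_tokens=30, temperature=0.7,
#              top_k=50, top_p=0.95, repetition_penalty=1.0, num_return_sequences=3)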

@funix(
    title="Custom Prompt Japanese GPT-2",
    description="""
<h1>Japanese GPT-2</h1><hr>
Let a GPT-2 model complete a Japanese sentence for you.
""",
    argument_labels={
        "prompt": "Prompt in Japanese",
        "model_type": "Model Type",
        "do_sample": "Do Sample",
        "num_beams": "Number of Beams",
        "num_beam_groups": "Number of Beam Groups",
        "max_new_tokens": "Max New Tokens to Generate",
        "temperature": "Temperature",
        "top_k": "Top K",
        "top_p": "Top P",
        "repetition_penalty": "Repetition Penalty",
        "num_return_sequences": "Number of Sequences to Return",
    },
    widgets={
        "num_beams": "slider[1,10,1]",
        "num_beam_groups": "slider[1,5,1]",
        "max_new_tokens": "slider[1,512,1]",
        "temperature": "slider[0.0,1.0,0.01]",
        "top_k": "slider[1,100,1]",
        "top_p": "slider[0.0,1.0,0.01]",
        "repetition_penalty": "slider[1.0,2.0,0.01]",
        "num_return_sequences": "slider[1,5,1]",
    }
)
def prompt(prompt: str = "ใ“ใ‚“ใซใกใฏใ€‚",
           model_type: typing.Literal["Kyoto University GPT-2 (Modern)", "CHJ GPT-2 (Classical)", "Waka GPT"] = "Kyoto University GPT-2 (Modern)",
           do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1,
           max_new_tokens: int = 100, temperature: float = 0.7, top_k: int = 50,
           top_p: float = 0.95, repetition_penalty: float = 1.0,
           num_return_sequences: int = 1) -> HTML:
    model_name = model_name_map[model_type]
    if model_name == "ku-gpt2":
        tokenizer = ku_gpt_tokenizer
        model = ku_gpt_model
    elif model_name == "chj-gpt2":
        tokenizer = chj_gpt_tokenizer
        model = chj_gpt_model
    elif model_name == "wakagpt":
        tokenizer = wakagpt_tokenizer
        model = wakagpt_model
    else:
        raise NotImplementedError(f"Unsupported model: {model_name}")
    generated = __generate(tokenizer, model, prompt, do_sample, num_beams, num_beam_groups,
                           max_new_tokens, temperature, top_k, top_p, repetition_penalty,
                           num_return_sequences)
    return HTML("".join(f"<p>{i}</p>" for i in generated))
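
# funix renders a return value annotated as funix.hint.HTML as raw HTML, so
# each generated sequence above is shown as its own paragraph.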

@funix(
    title="WakaGPT Poem Composer",
    description="""
<h1>WakaGPT Poem Composer</h1><hr>
Generate a Japanese waka poem in 5-7-5-7-7 form using WakaGPT. A sample poem (Kokinshu 169) is provided below:<br>
Preface: ็ง‹็ซ‹ใคๆ—ฅใ‚ˆใ‚ใ‚‹<br>
Author: ๆ•่กŒ ่—คๅŽŸๆ•่กŒๆœ่‡ฃ (018)<br>
Kana (Kana Only with Separator): ใ‚ใใใฌใจโˆ’ใ‚ใซใฏใ•ใ‚„ใ‹ใซโˆ’ใฟใˆใญใจใ‚‚โˆ’ใ‹ใ›ใฎใŠใจใซใโˆ’ใŠใจใ‚ใ‹ใ‚Œใฌใ‚‹<br>
Original (Kana + Kanji without Separator): ใ‚ใใใฌใจใ‚ใซใฏใ•ใ‚„ใ‹ใซ่ฆ‹ใˆใญใจใ‚‚้ขจใฎใŠใจใซใใŠใจใ‚ใ‹ใ‚Œใฌใ‚‹<br>
Aligned (Kana + Kanji with Separator): ใ‚ใใใฌใจโˆ’ใ‚ใซใฏใ•ใ‚„ใ‹ใซโˆ’่ฆ‹ใˆใญใจใ‚‚โˆ’้ขจใฎใŠใจใซใโˆ’ใŠใจใ‚ใ‹ใ‚Œใฌใ‚‹
""",
    argument_labels={
        "preface": "Preface (Kotobagaki) in Japanese (optional)",
        "author": "Author Name in Japanese (optional)",
        "first_line": "First Line of Poem in Japanese (optional)",
        "type": "Waka Type",
        "remaining_lines": "Remaining Lines of Poem",
        "do_sample": "Do Sample",
        "num_beams": "Number of Beams",
        "num_beam_groups": "Number of Beam Groups",
        "temperature": "Temperature",
        "top_k": "Top K",
        "top_p": "Top P",
        "repetition_penalty": "Repetition Penalty",
        "num_return_sequences": "Maximum Number of Sequences to Return",
    },
    widgets={
        "remaining_lines": "slider[1,5,1]",
        "num_beams": "slider[1,10,1]",
        "num_beam_groups": "slider[1,5,1]",
        "temperature": "slider[0.0,1.0,0.01]",
        "top_k": "slider[1,100,1]",
        "top_p": "slider[0.0,1.0,0.01]",
        "repetition_penalty": "slider[1.0,2.0,0.01]",
        "num_return_sequences": "slider[1,5,1]",
    }
)
def waka(preface: str = "", author: str = "",
         first_line: str = "ใ‚ใใใฌใจโˆ’ใ‚ใซใฏใ•ใ‚„ใ‹ใซโˆ’ใฟใˆใญใจใ‚‚",
         type: typing.Literal["Kana", "Original", "Aligned"] = "Kana", remaining_lines: int = 2,
         do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1,
         temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95,
         repetition_penalty: float = 1.0, num_return_sequences: int = 1) -> HTML:
    waka_prompt = ""
    if preface:
        waka_prompt += "[่ฉžๆ›ธ] " + preface + "\n"
    if author:
        waka_prompt += "[ไฝœ่€…] " + author + "\n"
    # 5-7-5-7-7 morae, one character (= one WakaGPT token) per mora
    token_counts = [5, 7, 5, 7, 7]
    # e.g. remaining_lines=2 leaves the last two lines: 7 + 7 = 14 new tokens,
    # before accounting for the separator tokens below
    max_new_tokens = sum(token_counts[-remaining_lines:])
    first_line = first_line.strip()
    # add separators
    if type.lower() in ["kana", "aligned"]:
        if first_line == "":
            max_new_tokens += 4  # all four separators still have to be generated
        else:
            if not first_line.endswith("โˆ’"):  # ensure the prompt ends with a separator
                first_line += "โˆ’"
            max_new_tokens += remaining_lines - 1  # remaining separators
    waka_prompt += waka_type_map[type.lower()] + " " + first_line
    info = f"""
Prompt: {waka_prompt}<br>
Max New Tokens: {max_new_tokens}<br>
"""
    generated = __generate(wakagpt_tokenizer, wakagpt_model, waka_prompt, do_sample,
                           num_beams, num_beam_groups, max_new_tokens, temperature,
                           top_k, top_p, repetition_penalty, num_return_sequences)
    removed = 0
    checked_generated = []
    if type.lower() == "kana":
        # kana output is one character per mora, so the 5-7-5-7-7 structure
        # can be validated by counting characters between separators
        def check(seq):
            nonlocal removed
            # strip the echoed prompt, keeping first_line plus the completion
            poem = first_line + seq[len(waka_prompt) - 1:]
            parts = poem.split("โˆ’")
            if len(parts) == 5 and all(len(part) == token_counts[i] for i, part in enumerate(parts)):
                checked_generated.append(poem)
            else:
                removed += 1
        for i in generated:
            check(i)
    else:
        # kanji spellings vary in length, so no structural check is applied
        checked_generated = [first_line + i[len(waka_prompt) - 1:] for i in generated]
    generated = [f"<p>{i}</p>" for i in checked_generated]
    return info + f"Removed Malformed: {removed}<br>Results:<br>{''.join(generated)}"

if __name__ == "__main__":
    print(prompt("ใ“ใ‚“ใซใกใฏ", "Kyoto University GPT-2 (Modern)", num_beams=5, num_return_sequences=5))
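    # A similar smoke test for WakaGPT (a sketch, using the default sample line):
    # print(waka(first_line="ใ‚ใใใฌใจโˆ’ใ‚ใซใฏใ•ใ‚„ใ‹ใซโˆ’ใฟใˆใญใจใ‚‚", remaining_lines=2))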