# To run: funix main.py
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import typing
from funix import funix
from funix.hint import HTML
low_memory = True  # reduce peak memory during generation (e.g. to run on low-RAM/mobile devices); passed to model.generate()
import os
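# Token for the TURX checkpoints on the Hugging Face Hub (assumed gated/private,
# since only they are loaded with a token); set HF_TOKEN in the environment before launching.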
hf_token = os.environ.get("HF_TOKEN")
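# All tokenizers and models are loaded once at import time: the first launch
# downloads the checkpoints, and later requests incur no per-call cold start.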
ku_gpt_tokenizer = AutoTokenizer.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
chj_gpt_tokenizer = AutoTokenizer.from_pretrained("TURX/chj-gpt2", token=hf_token)
wakagpt_tokenizer = AutoTokenizer.from_pretrained("TURX/wakagpt", token=hf_token)
ku_gpt_model = AutoModelForCausalLM.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
chj_gpt_model = AutoModelForCausalLM.from_pretrained("TURX/chj-gpt2", token=hf_token)
wakagpt_model = AutoModelForCausalLM.from_pretrained("TURX/wakagpt", token=hf_token)
print("Models loaded successfully.")
model_name_map = {
    "Kyoto University GPT-2 (Modern)": "ku-gpt2",
    "CHJ GPT-2 (Classical)": "chj-gpt2",
    "Waka GPT": "wakagpt",
}
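# Prompt tags understood by WakaGPT (matching the sample formats in the
# composer's description below): [仮名] = kana only with separators,
# [原文] = original text (kana + kanji, no separators),
# [整形] = aligned (kana + kanji with separators).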
waka_type_map = {
    "kana": "[仮名]",
    "original": "[原文]",
    "aligned": "[整形]",
}
@funix(
    title=" Home",
    description="""
Japanese Language Models
Final Project, STAT 453 Spring 2024, University of Wisconsin-Madison
Author: Ruixuan Tu (ruixuan@cs.wisc.edu, https://turx.tokyo)
Navigate the apps using the left sidebar.
"""
)
def home():
    return
@funix(disable=True)
def __generate(tokenizer: AutoTokenizer, model: AutoModelForCausalLM, prompt: str,
               do_sample: bool, num_beams: int, num_beam_groups: int, max_new_tokens: int,
               temperature: float, top_k: int, top_p: float, repetition_penalty: float,
               num_return_sequences: int) -> typing.List[str]:
    global low_memory
    # Encode the prompt, generate, and decode every returned sequence
    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(inputs, low_memory=low_memory, do_sample=do_sample, num_beams=num_beams,
                             num_beam_groups=num_beam_groups, max_new_tokens=max_new_tokens,
                             temperature=temperature, top_k=top_k, top_p=top_p,
                             repetition_penalty=repetition_penalty,
                             num_return_sequences=num_return_sequences)
    # batch_decode returns a list with one string per sequence, so the return
    # annotation is typing.List[str] rather than str
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
@funix(
    title="Custom Prompt Japanese GPT-2",
    description="""
Japanese GPT-2
Let a GPT-2 model complete a Japanese sentence for you.
""",
    argument_labels={
        "prompt": "Prompt in Japanese",
        "model_type": "Model Type",
        "max_new_tokens": "Max New Tokens to Generate",
        "do_sample": "Do Sample",
        "num_beams": "Number of Beams",
        "num_beam_groups": "Number of Beam Groups",
        "temperature": "Temperature",
        "top_k": "Top K",
        "top_p": "Top P",
        "repetition_penalty": "Repetition Penalty",
        "num_return_sequences": "Number of Sequences to Return",
    },
    widgets={
        "num_beams": "slider[1,10,1]",
        "num_beam_groups": "slider[1,5,1]",
        "max_new_tokens": "slider[1,512,1]",
        "temperature": "slider[0.0,1.0,0.01]",
        "top_k": "slider[1,100,1]",
        "top_p": "slider[0.0,1.0,0.01]",
        "repetition_penalty": "slider[1.0,2.0,0.01]",
        "num_return_sequences": "slider[1,5,1]",
    }
)
def prompt(prompt: str = "こんにちは。",
           model_type: typing.Literal["Kyoto University GPT-2 (Modern)", "CHJ GPT-2 (Classical)", "Waka GPT"] = "Kyoto University GPT-2 (Modern)",
           do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1, max_new_tokens: int = 100,
           temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.0,
           num_return_sequences: int = 1) -> HTML:
    # Map the human-readable model choice to the matching tokenizer/model pair
    model_name = model_name_map[model_type]
    if model_name == "ku-gpt2":
        tokenizer = ku_gpt_tokenizer
        model = ku_gpt_model
    elif model_name == "chj-gpt2":
        tokenizer = chj_gpt_tokenizer
        model = chj_gpt_model
    elif model_name == "wakagpt":
        tokenizer = wakagpt_tokenizer
        model = wakagpt_model
    else:
        raise NotImplementedError(f"Unsupported model: {model_name}")
    generated = __generate(tokenizer, model, prompt, do_sample, num_beams, num_beam_groups, max_new_tokens,
                           temperature, top_k, top_p, repetition_penalty, num_return_sequences)
    return HTML("".join([f"{i}<br/>" for i in generated]))
@funix(
    title="WakaGPT Poem Composer",
    description="""
WakaGPT Poem Composer
Generate a Japanese waka poem in 5-7-5-7-7 form using WakaGPT. A sample poem (Kokinshu 169) is provided below:
Preface: 秋立つ日よめる
Author: 敏行 藤原敏行朝臣 (018)
Kana (Kana only with Separator): あききぬと−めにはさやかに−みえねとも−かせのおとにそ−おとろかれぬる
Original (Kana + Kanji without Separator): あききぬとめにはさやかに見えねとも風のおとにそおとろかれぬる
Aligned (Kana + Kanji with Separator): あききぬと−めにはさやかに−見えねとも−風のおとにそ−おとろかれぬる
""",
    argument_labels={
        "preface": "Preface (Kotobagaki) in Japanese (optional)",
        "author": "Author Name in Japanese (optional)",
        "first_line": "First Line of Poem in Japanese (optional)",
        "type": "Waka Type",
        "remaining_lines": "Remaining Lines of Poem",
        "do_sample": "Do Sample",
        "num_beams": "Number of Beams",
        "num_beam_groups": "Number of Beam Groups",
        "temperature": "Temperature",
        "top_k": "Top K",
        "top_p": "Top P",
        "repetition_penalty": "Repetition Penalty",
        "num_return_sequences": "Number of Sequences to Return (at Maximum)",
    },
    widgets={
        "remaining_lines": "slider[1,5,1]",
        "num_beams": "slider[1,10,1]",
        "num_beam_groups": "slider[1,5,1]",
        "temperature": "slider[0.0,1.0,0.01]",
        "top_k": "slider[1,100,1]",
        "top_p": "slider[0.0,1.0,0.01]",
        "repetition_penalty": "slider[1.0,2.0,0.01]",
        "num_return_sequences": "slider[1,5,1]",
    }
)
def waka(preface: str = "", author: str = "", first_line: str = "あききぬと−めにはさやかに−みえねとも",
         type: typing.Literal["Kana", "Original", "Aligned"] = "Kana", remaining_lines: int = 2,
         do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1, temperature: float = 0.7,
         top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.0,
         num_return_sequences: int = 1) -> HTML:
    waka_prompt = ""
    if preface:
        waka_prompt += "[詞書] " + preface + "\n"  # [詞書] = preface (kotobagaki) tag
    if author:
        waka_prompt += "[作者] " + author + "\n"  # [作者] = author tag
    # A waka has five lines of 5-7-5-7-7 morae; the character-level tokenizer
    # emits one token per kana, so the mora counts double as token budgets.
    token_counts = [5, 7, 5, 7, 7]
    max_new_tokens = sum(token_counts[-remaining_lines:])
    first_line = first_line.strip()
    # add separators
    if type.lower() in ["kana", "aligned"]:
        if first_line == "":
            max_new_tokens += 4  # all four separators must be generated
        else:
            if first_line[-1] != "−":
                first_line += "−"
            max_new_tokens += remaining_lines - 1  # remaining separators
    waka_prompt += waka_type_map[type.lower()] + " " + first_line
    info = f"Prompt: {waka_prompt}<br/>Max New Tokens: {max_new_tokens}<br/>"
    generated = __generate(wakagpt_tokenizer, wakagpt_model, waka_prompt, do_sample, num_beams,
                           num_beam_groups, max_new_tokens, temperature, top_k, top_p,
                           repetition_penalty, num_return_sequences)
    removed = 0
    checked_generated = []
    if type.lower() == "kana":
        # In kana mode the 5-7-5-7-7 structure can be verified character by
        # character, so discard any candidate that does not scan.
        def check(seq):
            # strip the echoed prompt, keeping only the newly generated text
            # (the -1 compensates for a one-character offset in the decoded output)
            poem = first_line + seq[len(waka_prompt) - 1:]
            parts = poem.split("−")
            if len(parts) == 5 and all(len(part) == token_counts[i] for i, part in enumerate(parts)):
                checked_generated.append(poem)
            else:
                nonlocal removed
                removed += 1
        for i in generated:
            check(i)
    else:
        checked_generated = [first_line + i[len(waka_prompt) - 1:] for i in generated]
    generated = [f"{i}<br/>" for i in checked_generated]
    return info + f"Removed Malformed: {removed}<br/>Results:<br/>{''.join(generated)}"
if __name__ == "__main__":
    print(prompt("こんにちは", "Kyoto University GPT-2 (Modern)", num_beams=5, num_return_sequences=5))
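    # A hypothetical smoke test for the composer (uses the sample first line
    # from Kokinshu 169): complete the poem in kana mode and keep up to three
    # well-formed candidates.
    # print(waka(first_line="あききぬと−めにはさやかに−みえねとも", type="Kana",
    #            remaining_lines=2, num_return_sequences=3))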