# File size: 8,253 Bytes
# d8a9bd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# To run: funix main.py

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import typing
from funix import funix
from funix.hint import HTML

# Forwarded as `low_memory=` to `model.generate` by `__generate`.
low_memory = True  # Set to True to run on mobile devices

# Checkpoint identifiers passed to `from_pretrained`, in load order.
_CHECKPOINTS = (
    "ku-nlp/gpt2-medium-japanese-char",
    "TURX/chj-gpt2",
    "TURX/wakagpt",
)

# All three tokenizers are loaded first, then all three models, each in
# checkpoint order.
ku_gpt_tokenizer, chj_gpt_tokenizer, wakagpt_tokenizer = (
    AutoTokenizer.from_pretrained(checkpoint) for checkpoint in _CHECKPOINTS
)
ku_gpt_model, chj_gpt_model, wakagpt_model = (
    AutoModelForCausalLM.from_pretrained(checkpoint) for checkpoint in _CHECKPOINTS
)

print("Models loaded successfully.")

# Maps the model label selectable in the UI to the internal model identifier
# that `prompt` dispatches on.
model_name_map = {
    "Kyoto University GPT-2 (Modern)": "ku-gpt2",
    "CHJ GPT-2 (Classical)": "chj-gpt2",
    "Waka GPT": "wakagpt",
}

# Maps the lower-cased waka type to the bracketed tag that `waka` places in
# front of the poem text when building the WakaGPT prompt.
# The values were mis-decoded (UTF-8 read through a single-byte codepage);
# they are restored here to the intended Japanese tags.
waka_type_map = {
    "kana": "[仮名]",
    "original": "[原文]",
    "aligned": "[整形]",
}


@funix(
    title=" Home",
    description=(
        "\n"
        "<h1>Japanese Language Models</h1><hr>\n"
        "Final Project, STAT 453 Spring 2024, University of Wisconsin-Madison<br>\n"
        "Author: Ruixuan Tu (ruixuan@cs.wisc.edu, https://turx.asia)<hr>\n"
        "Navigate the apps using the left sidebar.\n"
    ),
)
def home():
    # Landing page: takes no input and produces no output; everything the
    # page shows comes from the decorator's title and description.
    return None


@funix(disable=True)
def __generate(tokenizer: AutoTokenizer, model: AutoModelForCausalLM, prompt: str,
               do_sample: bool, num_beams: int, num_beam_groups: int, max_new_tokens: int, temperature: float, top_k: int, top_p: float, repetition_penalty: float, num_return_sequences: int
               ) -> typing.List[str]:
    """Tokenize `prompt`, run `model.generate`, and decode every returned sequence.

    The module-level `low_memory` flag and all sampling / beam-search
    arguments are forwarded to `model.generate` unchanged.

    Returns:
        The result of `tokenizer.batch_decode(..., skip_special_tokens=True)`,
        i.e. one decoded string per generated sequence. Callers in this file
        iterate over it and slice the prompt off the front, so each string
        presumably starts with the prompt text (confirm against the
        transformers documentation).
    """
    # `low_memory` is only read here, so no `global` declaration is needed.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output_ids = model.generate(
        input_ids,
        low_memory=low_memory,
        do_sample=do_sample,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_return_sequences=num_return_sequences,
    )
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)


@funix(
    title="Custom Prompt Japanese GPT-2",
    description="""
<h1>Japanese GPT-2</h1><hr>
Let a GPT-2 model to complete a Japanese sentence for you.
""",
    argument_labels={
        "prompt": "Prompt in Japanese",
        "model_type": "Model Type",
        # This key used to appear twice; the later value ("Max New Tokens")
        # was the one that took effect, so it is the one kept.
        "max_new_tokens": "Max New Tokens",
        "do_sample": "Do Sample",
        "num_beams": "Number of Beams",
        "num_beam_groups": "Number of Beam Groups",
        "temperature": "Temperature",
        "top_k": "Top K",
        "top_p": "Top P",
        "repetition_penalty": "Repetition Penalty",
        "num_return_sequences": "Number of Sequences to Return",
    },
    widgets={
        "num_beams": "slider[1,10,1]",
        "num_beam_groups": "slider[1,5,1]",
        "max_new_tokens": "slider[1,512,1]",
        # Sampling requires a strictly positive temperature, so 0.0 is not offered.
        "temperature": "slider[0.01,1.0,0.01]",
        # `top_k` is an int, so the slider moves in whole steps.
        "top_k": "slider[1,100,1]",
        "top_p": "slider[0.0,1.0,0.01]",
        "repetition_penalty": "slider[1.0,2.0,0.01]",
        "num_return_sequences": "slider[1,5,1]",
    },
)
def prompt(
    prompt: str = "こんにちは。",
    model_type: typing.Literal["Kyoto University GPT-2 (Modern)", "CHJ GPT-2 (Classical)", "Waka GPT"] = "Kyoto University GPT-2 (Modern)",
    do_sample: bool = True,
    num_beams: int = 1,
    num_beam_groups: int = 1,
    max_new_tokens: int = 100,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    num_return_sequences: int = 1,
) -> HTML:
    # Complete `prompt` with the model selected by `model_type` and return
    # every generated sequence wrapped in its own <p> element.
    #
    # `model_type` is translated through `model_name_map` (unknown labels
    # raise KeyError there); an identifier without a loaded tokenizer/model
    # pair raises NotImplementedError. All remaining arguments are forwarded
    # to `__generate` unchanged.
    import html

    model_name = model_name_map[model_type]
    backends = {
        "ku-gpt2": (ku_gpt_tokenizer, ku_gpt_model),
        "chj-gpt2": (chj_gpt_tokenizer, chj_gpt_model),
        "wakagpt": (wakagpt_tokenizer, wakagpt_model),
    }
    try:
        tokenizer, model = backends[model_name]
    except KeyError:
        raise NotImplementedError(f"Unsupported model: {model_name}") from None

    generated = __generate(tokenizer, model, prompt, do_sample, num_beams, num_beam_groups, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_return_sequences)
    # The text comes from user input and model output, so escape it before
    # embedding it in markup.
    return HTML("".join(f"<p>{html.escape(text)}</p>" for text in generated))


@funix(
    title="WakaGPT Poem Composer",
    description="""
<h1>WakaGPT Poem Composer</h1><hr>
Generate a Japanese waka poem in 5-7-5-7-7 form using WakaGPT. A sample poem (Kokinshu 169) is provided below:<br>
    Preface: 秋立つ日よめる<br>
    Author: 敏行 藤原敏行朝臣 (018)<br>
    Kana (Kana only with Separator): あききぬと−めにはさやかに−みえねとも−かせのおとにそ−おとろかれぬる<br>
    Original (Kana + Kanji without Separator): あききぬとめにはさやかに見えねとも風のおとにそおとろかれぬる<br>
    Aligned (Kana + Kanji with Separator): あききぬと−めにはさやかに−見えねとも−風のおとにそ−おとろかれぬる
""",
    argument_labels={
        "preface": "Preface (Kotobagaki) in Japanese (optional)",
        "author": "Author Name in Japanese (optional)",
        "first_line": "First Line of Poem in Japanese (optional)",
        "type": "Waka Type",
        "remaining_lines": "Remaining Lines of Poem",
        "do_sample": "Do Sample",
        "num_beams": "Number of Beams",
        "num_beam_groups": "Number of Beam Groups",
        "temperature": "Temperature",
        "top_k": "Top K",
        "top_p": "Top P",
        "repetition_penalty": "Repetition Penalty",
        "num_return_sequences": "Number of Sequences to Return (at Maximum)",
    },
    widgets={
        "remaining_lines": "slider[1,5,1]",
        "num_beams": "slider[1,10,1]",
        "num_beam_groups": "slider[1,5,1]",
        # Sampling requires a strictly positive temperature, so 0.0 is not offered.
        "temperature": "slider[0.01,1.0,0.01]",
        # `top_k` is an int, so the slider moves in whole steps.
        "top_k": "slider[1,100,1]",
        "top_p": "slider[0.0,1.0,0.01]",
        "repetition_penalty": "slider[1.0,2.0,0.01]",
        "num_return_sequences": "slider[1,5,1]",
    },
)
def waka(preface: str = "", author: str = "", first_line: str = "あききぬと−めにはさやかに−みえねとも", type: typing.Literal["Kana", "Original", "Aligned"] = "Kana", remaining_lines: int = 2,
         do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1, temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.0, num_return_sequences: int = 1
         ) -> HTML:
    # Compose a waka with WakaGPT.
    #
    # Builds a prompt from the optional preface and author tags, the waka
    # type tag from `waka_type_map`, and the already written part of the poem
    # (`first_line`), then asks `__generate` for the remaining lines.
    #
    # This is a generator that yields twice: first a progress message, then
    # the final report. For the "Kana" and "Aligned" types, poems that do not
    # split into exactly 5-7-5-7-7 characters are dropped and counted as
    # malformed; for "Original" every generated poem is reported.
    import html

    separator = "−"  # U+2212, the line separator used by the kana/aligned formats
    token_counts = [5, 7, 5, 7, 7]
    waka_type = type.lower()
    uses_separator = waka_type in ("kana", "aligned")

    header = ""
    if preface:
        header += "[詞書] " + preface + "\n"
    if author:
        header += "[作者] " + author + "\n"

    # One token per character of each line that is still to be written.
    max_new_tokens = sum(token_counts[-remaining_lines:])
    first_line = first_line.strip()

    if uses_separator:
        if first_line == "":
            # NOTE(review): 4 separators corresponds to generating all five
            # lines; this presumably expects remaining_lines == 5 whenever
            # no first line is supplied -- confirm.
            max_new_tokens += 4
        else:
            # Terminate the given text with exactly one separator. The old
            # conditional expression appended `first_line` to itself when the
            # separator was already present.
            if not first_line.endswith(separator):
                first_line += separator
            max_new_tokens += remaining_lines - 1  # remaining separators

    waka_prompt = header + waka_type_map[waka_type] + " " + first_line
    info = f"""
    Prompt: {html.escape(waka_prompt)}<br>
    Max New Tokens: {max_new_tokens}<br>
    """
    yield info + "Generating Poem..."
    generated = __generate(wakagpt_tokenizer, wakagpt_model, waka_prompt, do_sample, num_beams, num_beam_groups, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_return_sequences)

    # NOTE(review): the continuation is taken from index len(waka_prompt) - 1,
    # which presumably compensates for the decoded text being one character
    # shorter than the prompt -- kept as in the original; confirm against the
    # tokenizer's round-trip behavior.
    continuation_start = len(waka_prompt) - 1
    poems = [first_line + sequence[continuation_start:] for sequence in generated]

    if uses_separator:
        def is_well_formed(poem):
            # Exactly five separator-delimited lines with 5-7-5-7-7 characters.
            parts = poem.split(separator)
            return len(parts) == len(token_counts) and all(
                len(part) == expected for part, expected in zip(parts, token_counts)
            )

        kept = [poem for poem in poems if is_well_formed(poem)]
    else:
        kept = poems
    removed = len(poems) - len(kept)

    # Escape model output before embedding it in markup.
    results_html = "".join(f"<p>{html.escape(poem)}</p>" for poem in kept)
    yield info + f"Removed Malformed: {removed}<br>Results:<br>{results_html}"


if __name__ == "__main__":
    # Smoke test when the file is executed directly: five beams, five
    # returned sequences from the Kyoto University model.
    print(prompt("こんにちは", "Kyoto University GPT-2 (Modern)", num_beams=5, num_return_sequences=5))