# -*- coding: utf-8 -*-
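"""
Streamlit demo: emotion-conditioned Korean samhaengshi (three-line acrostic poem)
generation with the KoGPT model "snoop2head/kogpt-conditional-2".

Each line of the poem is generated from an emotion-condition prompt plus the
previously generated line and the next letter of the user's input
(see make_residual_conditional_samhaengshi below).
"""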
import numpy as np
import streamlit as st
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast


model_dir = "snoop2head/kogpt-conditional-2"
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    model_dir,
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
)


@st.cache(allow_output_mutation=True)
def load_model(model_name):
    # cache the loaded model across Streamlit reruns
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model


model = load_model(model_dir)
print("loaded model completed")


def find_nth(haystack, needle, n):
    start = haystack.find(needle)
    while start >= 0 and n > 1:
        start = haystack.find(needle, start + len(needle))
        n -= 1
    return start
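# e.g. (illustrative, not part of the original script):
#   find_nth("abcabc", "b", 2) == 4; returns -1 when the n-th occurrence does not exist.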


def infer(input_ids, max_length, temperature, top_k, top_p):
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
    )
    return output_sequences


# page title and usage hint (Korean UI text)
st.title("삼행시의 달인 KoGPT입니다 🦄")  # "This is KoGPT, the master of samhaengshi 🦄"
st.write("텍스트를 입력하고 CTRL+Enter(CMD+Enter)을 누르세요 🤗")  # "Enter text and press CTRL+Enter (CMD+Enter) 🤗"

# text input and sidebar controls
default_value = "박수민"  # default input: a three-letter Korean given name
sent = st.text_area("Text", default_value, max_chars=4, height=275)
max_length = st.sidebar.slider("생성 문장 길이를 선택해주세요!", min_value=42, max_value=64)  # "Choose the generated text length!"
temperature = st.sidebar.slider(
    "Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05
)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)

print("slider sidebars rendering completed")

# emotion selection for the conditioning sentence
emotion_list = ["행복", "중립", "분노", "혐오", "놀람", "슬픔", "공포"]  # happiness, neutral, anger, disgust, surprise, sadness, fear
main_emotion = st.sidebar.radio("주요 감정을 선택하세요", emotion_list)  # "Select the main emotion"
sub_emotion = st.sidebar.radio("두 번째 감정을 선택하세요", emotion_list)  # "Select the second emotion"

print("radio sidebars rendering completed")

# create condition sentence: sample pseudo-logits for the two chosen emotions
random_main_logit = np.random.normal(loc=3.368, scale=1.015, size=1)[0].round(1)
random_sub_logit = np.random.normal(loc=1.333, scale=0.790, size=1)[0].round(1)
# "A sentence that is {main_emotion} to the degree of {random_main_logit}.
#  A sentence that is {sub_emotion} to the degree of {random_sub_logit}."
condition_sentence = f"{random_main_logit}만큼 {main_emotion}감정인 문장이다. {random_sub_logit}만큼 {sub_emotion}감정인 문장이다. "
condition_plus_input = condition_sentence + sent
print(condition_plus_input)
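# illustrative value of condition_plus_input (numbers vary per run):
#   "3.4만큼 행복감정인 문장이다. 1.3만큼 중립감정인 문장이다. 박수민"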


def infer_sentence(
    condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=2
):
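    # NOTE: the caller below passes only the prompt and the tokenizer, so this local
    # top_k default (2) is used for generation instead of the sidebar "Top-k" slider value.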
    encoded_prompt = tokenizer.encode(
        condition_plus_input, add_special_tokens=False, return_tensors="pt"
    )
    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt
    output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
    print(output_sequences)

    generated_sequence = output_sequences[0]
    print(generated_sequence)

    # print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
    # generated_sequences = generated_sequence.tolist()
    # Decode text
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    print(text)
    # Remove all text after the stop token
    stop_token = tokenizer.pad_token
    print(stop_token)
    text = text[: text.find(stop_token) if stop_token else None]
    print(text)

    # drop the two condition sentences ("...문장이다. ...문장이다. ") and keep only the generation
    condition_index = find_nth(text, "문장이다", 2)
    text = text[condition_index + 5 :]
    text = text.strip()
    return text


def make_residual_conditional_samhaengshi(input_letter, condition_sentence):
    # accumulate one generated line per letter of the input string
    list_samhaengshi = []

    # iterating over the input letter string
    for index, letter_item in enumerate(input_letter):
        # initializing the input_letter
        if index == 0:
            residual_text = letter_item
            # print('residual_text:', residual_text)

        # infer and add to the output
        conditional_input = f"{condition_sentence} {residual_text}"
        inferred_sentence = infer_sentence(conditional_input, tokenizer)
        if index != 0:
            # remove the previous line from the model output, keeping only the new line
            print("inferred_sentence:", inferred_sentence)
            inferred_sentence = inferred_sentence.replace(
                list_samhaengshi[index - 1], ""
            ).strip()
        list_samhaengshi.append(inferred_sentence)

        # until the end of the input_letter, give the previous residual_text to the next iteration
        if index < len(input_letter) - 1:
            residual_sentence = list_samhaengshi[index]
            next_letter = input_letter[index + 1]
            residual_text = (
                f"{residual_sentence} {next_letter}"  #  previous sentence + next letter
            )
            print("residual_text", residual_text)

        elif index == len(input_letter) - 1:  # last letter: the poem is complete
            return list_samhaengshi
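
# Rough sketch of the residual chaining above for the default 3-letter input "박수민"
# (illustrative only; actual lines depend on sampling):
#   index 0: prompt = condition + "박"            -> line_1
#   index 1: prompt = condition + line_1 + " 수"  -> output, minus line_1 -> line_2
#   index 2: prompt = condition + line_2 + " 민"  -> output, minus line_2 -> line_3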


return_text = make_residual_conditional_samhaengshi(
    input_letter=sent, condition_sentence=condition_sentence
)

print(return_text)

st.write(return_text)
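
# To run the demo locally (assuming this file is saved as app.py and that
# streamlit, transformers, and torch are installed):
#   streamlit run app.py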