File size: 3,732 Bytes
21d2052
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc8d63b
21d2052
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85c2795
070689f
21d2052
 
85c2795
77a934c
21d2052
 
3569768
21d2052
 
070689f
21d2052
 
 
 
e324e06
21d2052
b14d53a
 
21d2052
 
 
 
 
 
 
 
 
 
 
 
dc8d63b
21d2052
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85c2795
 
21d2052
 
 
 
85c2795
 
21d2052
 
 
 
 
 
85c2795
 
21d2052
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# -*- coding: utf-8 -*-
import numpy as np
import streamlit as st
# NOTE(review): AutoModelWithLMHead is deprecated in recent transformers
# releases (AutoModelForCausalLM is the documented replacement for GPT-style
# checkpoints) — confirm against the pinned transformers version.
from transformers import AutoModelWithLMHead, PreTrainedTokenizerFast

# Hugging Face Hub id of the emotion-conditioned Korean GPT model.
model_dir = "snoop2head/kogpt-conditional-2"
# Tokenizer configured with the special tokens the checkpoint was trained with.
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    model_dir,
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
)


@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load and cache the causal-LM weights for *model_name*.

    ``allow_output_mutation=True`` stops Streamlit from hashing the returned
    torch model on every rerun — hashing a model is slow and raises
    UnhashableTypeError on several Streamlit versions, so a bare ``@st.cache``
    breaks as soon as the cached model is touched.

    Returns the loaded transformers model.
    """
    # AutoModelWithLMHead is deprecated; AutoModelForCausalLM is the
    # documented replacement for GPT-style (causal) checkpoints and loads
    # the same underlying GPT2LMHeadModel for this repo.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model


# Instantiate the (Streamlit-cached) model once at import time; subsequent
# script reruns reuse the cached instance.
model = load_model(model_dir)
print("loaded model completed")


def find_nth(haystack, needle, n):
    """Return the index of the n-th non-overlapping occurrence of *needle*
    in *haystack*, or -1 if there are fewer than n occurrences.

    For n <= 1 this is equivalent to ``haystack.find(needle)``.
    """
    pos = haystack.find(needle)
    for _ in range(n - 1):
        if pos < 0:
            break  # fewer occurrences than requested
        pos = haystack.find(needle, pos + len(needle))
    return pos


def infer(input_ids, max_length, temperature, top_k, top_p):
    """Sample one continuation from the module-level ``model``.

    All decoding knobs (temperature / top-k / top-p) are forwarded to
    ``model.generate``; sampling is always on and a single sequence is
    returned.
    """
    generation_kwargs = {
        "input_ids": input_ids,
        "max_length": max_length,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": True,
        "num_return_sequences": 1,
    }
    return model.generate(**generation_kwargs)


# prompts
st.title("์ฃผ์–ด์ง„ ๊ฐ์ •์— ๋งž๊ฒŒ ๋ฌธ์žฅ์„ ๋งŒ๋“œ๋Š” KoGPT์ž…๋‹ˆ๋‹ค ๐Ÿฆ„")
st.write("์ขŒ์ธก์— ๊ฐ์ •์ƒํƒœ์˜ ๋ณ€ํ™”๋ฅผ ์ฃผ๊ณ , CTRL+Enter(CMD+Enter)๋ฅผ ๋ˆ„๋ฅด์„ธ์š” ๐Ÿค—")

# text and sidebars
# Seed text the model continues from; capped at 30 characters.
default_value = "์ˆ˜์ƒํ•œ ๋ฐค๋“ค์ด ๊ณ„์†๋˜๋˜ ๋‚  ์–ธ์  ๊ฐ€๋ถ€ํ„ฐ ๋‚˜๋Š”"
sent = st.text_area("Text", default_value, max_chars=30, height=50)
# Decoding knobs, read later by infer_sentence via module scope.
max_length = st.sidebar.slider("์ƒ์„ฑ ๋ฌธ์žฅ ๊ธธ์ด๋ฅผ ์„ ํƒํ•ด์ฃผ์„ธ์š”!", min_value=42, max_value=64)
temperature = st.sidebar.slider(
    "Temperature", value=0.9, min_value=0.0, max_value=1.0, step=0.05
)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=1.0)

print("slider sidebars rendering completed")

# make input sentence
emotion_list = ["ํ–‰๋ณต", "๋†€๋žŒ", "๋ถ„๋…ธ", "ํ˜์˜ค", "์Šฌํ””", "๊ณตํฌ", "์ค‘๋ฆฝ"]
main_emotion = st.sidebar.radio("์ฃผ์š” ๊ฐ์ •์„ ์„ ํƒํ•˜์„ธ์š”", emotion_list)
# Reversed in place so the second radio shows the options in opposite order
# (note: this mutates emotion_list for the rest of the script).
emotion_list.reverse()
sub_emotion = st.sidebar.radio("๋‘ ๋ฒˆ์งธ ๊ฐ์ •์„ ์„ ํƒํ•˜์„ธ์š”", emotion_list)

print("radio sidebars rendering completed")

# create condition sentence
# Random emotion intensities; the normal-distribution parameters presumably
# mirror the logit statistics seen during training — TODO confirm.
# Draw order matters for reproducibility of the global numpy RNG stream.
random_main_logit = np.random.normal(loc=3.368, scale=1.015, size=1)[0].round(1)
random_sub_logit = np.random.normal(loc=1.333, scale=0.790, size=1)[0].round(1)
# Two-sentence control prefix prepended to the user text; infer_sentence
# strips it back out of the generated output.
condition_sentence = f"{random_main_logit}๋งŒํผ {main_emotion}๊ฐ์ •์ธ ๋ฌธ์žฅ์ด๋‹ค. {random_sub_logit}๋งŒํผ {sub_emotion}๊ฐ์ •์ธ ๋ฌธ์žฅ์ด๋‹ค. "
condition_plus_input = condition_sentence + sent
print(condition_plus_input)


def infer_sentence(
    condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=2
):
    """Generate a continuation for *condition_plus_input* and strip the
    emotion-condition prefix from the decoded text.

    ``max_length``, ``temperature`` and ``top_p`` are read from module
    scope (the sidebar sliders); generation goes through ``infer``.

    Returns the cleaned, whitespace-stripped generated sentence.
    """
    encoded_prompt = tokenizer.encode(
        condition_plus_input, add_special_tokens=False, return_tensors="pt"
    )
    # An empty prompt encodes to a zero-length tensor; pass None so the
    # model generates unconditionally instead of choking on an empty input.
    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
    generated_sequence = output_sequences[0]

    # Decode token ids back to text.
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

    # Truncate at the first pad token. BUGFIX: the old slice used
    # text[:text.find(stop_token)] directly, so a missing pad token
    # (find() == -1) silently chopped off the last character.
    stop_token = tokenizer.pad_token
    if stop_token:
        pad_index = text.find(stop_token)
        if pad_index >= 0:
            text = text[:pad_index]

    # Drop the two-sentence condition prefix: everything up to and including
    # the second "๋ฌธ์žฅ์ด๋‹ค." (4 chars + trailing period = 5 chars).
    condition_index = find_nth(text, "๋ฌธ์žฅ์ด๋‹ค", 2)
    # BUGFIX: guard the -1 (marker absent) case, which previously sliced
    # text[4:] and mangled the output.
    if condition_index >= 0:
        text = text[condition_index + 5 :]
    return text.strip()


# Generate once per script rerun with the current sidebar settings and
# render the result in the main pane.
return_text = infer_sentence(
    condition_plus_input=condition_plus_input, tokenizer=tokenizer
)

print(return_text)

st.write(return_text)