# NOTE(review): the lines "Spaces:" / "Runtime error" below the title were
# Hugging Face Spaces page chrome captured along with the code paste; kept
# only as this comment so the file remains valid Python.
import gradio as gr


def update(name):
    """Greeting callback wired to the Run button: echoes the chosen topic."""
    return f"Welcome to Gradio, {name}!"


demo = gr.Blocks()
with demo:
    gr.Markdown("๊ฐ ์ง๋ฌธ์ ๋๋ต ํ Enter ํด์ฃผ์ธ์.\n\n")
    with gr.Row():
        topic = gr.Textbox(
            label="Topic",
            placeholder="๋ํ ์ฃผ์ ๋ฅผ ์ ํด์ฃผ์ธ์ (e.g. ์ฌ๊ฐ ์ํ, ์ผ๊ณผ ์ง์ , ๊ฐ์ธ ๋ฐ ๊ด๊ณ, etc...)",
        )
    with gr.Row():
        # NOTE(review): the original built two identical columns and re-bound
        # addr/age/sex to the second set of widgets, shadowing the first.
        # The second column's handles are suffixed _2 so nothing is shadowed;
        # the rendered UI is unchanged.
        with gr.Column():
            addr = gr.Textbox(label="์ง์ญ", placeholder="e.g. ์ฌ๊ฐ ์ํ, ์ผ๊ณผ ์ง์ , ๊ฐ์ธ ๋ฐ ๊ด๊ณ, etc...")
            age = gr.Textbox(label="๋์ด", placeholder="e.g. 20๋ ๋ฏธ๋ง, 40๋, 70๋ ์ด์, etc...")
            sex = gr.Textbox(label="์ฑ๋ณ", placeholder="e.g. ๋จ์ฑ, ์ฌ์ฑ, etc...")
        with gr.Column():
            addr_2 = gr.Textbox(label="์ง์ญ", placeholder="e.g. ์ฌ๊ฐ ์ํ, ์ผ๊ณผ ์ง์ , ๊ฐ์ธ ๋ฐ ๊ด๊ณ, etc...")
            age_2 = gr.Textbox(label="๋์ด", placeholder="e.g. 20๋ ๋ฏธ๋ง, 40๋, 70๋ ์ด์, etc...")
            sex_2 = gr.Textbox(label="์ฑ๋ณ", placeholder="e.g. ๋จ์ฑ, ์ฌ์ฑ, etc...")
    out = gr.Textbox()
    btn = gr.Button("Run")
    # BUG FIX: the original passed `inputs=inp`, but no `inp` was ever defined,
    # so building the app raised NameError. `topic` is the textbox whose value
    # the `update` callback formats into the greeting.
    btn.click(fn=update, inputs=topic, outputs=out)
demo.launch()
def main(model_name):
    """Run an interactive Korean chatbot REPL on a fine-tuned KoGPT model.

    Loads the base KoGPT tokenizer (extended with the dialogue special
    tokens), loads the causal LM weights from *model_name* onto the GPU,
    then loops: collect speaker profiles, exchange turns via stdin/stdout,
    and drop the oldest turns whenever the prompt reaches the 512-token
    generation limit.

    Args:
        model_name: HF model id or path of the fine-tuned checkpoint
            (assumed to have been trained from kakaobrain/kogpt).

    Requires `transformers` (AutoTokenizer / AutoModelForCausalLM),
    `warnings`, `time`, `random` in scope, a CUDA device, and a terminal.
    """
    warnings.filterwarnings("ignore")
    # The tokenizer always comes from the base KoGPT checkpoint; only the
    # model weights come from *model_name*.
    tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
    special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@์ด๋ฆ#', '#@๊ณ์ #', '#@์ ์#', '#@์ ๋ฒ#', '#@๊ธ์ต#', '#@๋ฒํธ#', '#@์ฃผ์#', '#@์์#', '#@๊ธฐํ#']}
    tokenizer.add_special_tokens(special_tokens_dict)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))  # account for added specials
    model = model.cuda()
    # Loop-invariant: the id generation stops on. Hoisted out of the turn loop.
    eos_token_id = tokenizer.encode('<eos>')[0]

    def ask_info(who, ment):
        """Prompt for region/age/sex; return '<who>:<addr> <age> <sex><sep>'.

        An empty answer is replaced by a random preset value, which is
        echoed so the user sees what was picked.
        """
        print(ment)
        text = who + ":"
        addr = input("์ด๋ ์ฌ์ธ์? (e.g. ์์ธํน๋ณ์, ์ ์ฃผ๋, etc...) :").strip()
        if addr == "":
            addr = random.choice(['์์ธํน๋ณ์', '๊ฒฝ๊ธฐ๋', '๋ถ์ฐ๊ด์ญ์', '๋์ ๊ด์ญ์', '๊ด์ฃผ๊ด์ญ์', '์ธ์ฐ๊ด์ญ์', '๊ฒฝ์๋จ๋', '์ธ์ฒ๊ด์ญ์', '์ถฉ์ฒญ๋ถ๋', '์ ์ฃผ๋', '๊ฐ์๋', '์ถฉ์ฒญ๋จ๋', '์ ๋ผ๋ถ๋', '๋๊ตฌ๊ด์ญ์', '์ ๋ผ๋จ๋', '๊ฒฝ์๋ถ๋', '์ธ์ข ํน๋ณ์์น์', '๊ธฐํ'])
            print(addr)
        text += addr + " "
        age = input("๋์ด๊ฐ? (e.g. 20๋, 70๋ ์ด์, etc...) :").strip()
        if age == "":
            age = random.choice(['20๋', '30๋', '50๋', '20๋ ๋ฏธ๋ง', '60๋', '40๋', '70๋ ์ด์'])
            print(age)
        text += age + " "
        sex = input("์ฑ๋ณ์ด? (e.g. ๋จ์ฑ, ์ฌ์ฑ, etc... (?)) :").strip()
        if sex == "":
            sex = random.choice(['๋จ์ฑ', '์ฌ์ฑ'])
            print(sex)
        text += sex + "<sep>"
        return text

    info = ""
    while True:
        if info == "":
            # First session: no info yet, go straight to profile collection.
            print(
                "์ง๊ธ๋ถํฐ ๋ํ ์ ๋ณด๋ฅผ ์ ๋ ฅ ๋ฐ๊ฒ ์ต๋๋ค.\n"
                "๊ฐ ์ง๋ฌธ์ ๋๋ต ํ Enter ํด์ฃผ์ธ์.\n"
                "์๋ฌด ์ ๋ ฅ ์์ด Enter ํ ๊ฒฝ์ฐ, ๋ฏธ๋ฆฌ ์ง์ ๋ ๊ฐ ์ค ๋๋ค์ผ๋ก ์ ํ๊ฒ ๋ฉ๋๋ค.\n"
            )
            time.sleep(1)
            yon = "no"
        else:
            yon = input(
                "์ด์ ๋ํ ์ ๋ณด๋ฅผ ๊ทธ๋๋ก ์ ์งํ ๊น์? (yes : ์ ์ง, no : ์๋ก ์์ฑ) :"
            )
        if yon == "no":
            info = "์ผ์ ๋ํ "
            topic = input("๋ํ ์ฃผ์ ๋ฅผ ์ ํด์ฃผ์ธ์ (e.g. ์ฌ๊ฐ ์ํ, ์ผ๊ณผ ์ง์ , ๊ฐ์ธ ๋ฐ ๊ด๊ณ, etc...) :")
            if topic == "":
                topic = random.choice(['์ฌ๊ฐ ์ํ', '์์ฌ/๊ต์ก', '๋ฏธ์ฉ๊ณผ ๊ฑด๊ฐ', '์์๋ฃ', '์๊ฑฐ๋(์ผํ)', '์ผ๊ณผ ์ง์ ', '์ฃผ๊ฑฐ์ ์ํ', '๊ฐ์ธ ๋ฐ ๊ด๊ณ', 'ํ์ฌ'])
                print(topic)
            info += topic + "<sep>"
            info += ask_info(who="P01", ment="\n๋น์ ์ ๋ํด ์๋ ค์ฃผ์ธ์.\n")
            info += ask_info(who="P02", ment="\n์ฑ๋ด์ ๋ํด ์๋ ค์ฃผ์ธ์.\n")
        pp = info.replace('<sep>', '\n')
        print(
            "\n----------------\n"
            "<์ ๋ ฅ ์ ๋ณด ํ์ธ> (P01 : ๋น์ , P02 : ์ฑ๋ด)\n"
            f"{pp}"
            "----------------\n"
            "๋ํ๋ฅผ ์ข ๋ฃํ๊ณ ์ถ์ผ๋ฉด ์ธ์ ๋ ์ง 'end' ๋ผ๊ณ ๋งํด์ฃผ์ธ์~\n"
        )
        talk = []
        # Per-session one-shot warning flags (renamed from switch/switch2).
        warned_near_limit = False
        warned_trimming = False
        while True:
            inp = "P01<sos>"
            myinp = input("๋น์ : ")
            if myinp == "end":
                print("๋ํ ์ข ๋ฃ!")
                break
            inp += myinp + "<eos>"
            talk.append(inp)
            talk.append("P02<sos>")
            while True:
                now_inp = info + "".join(talk)
                inpu = tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
                seq_len = inpu.input_ids.size(1)
                if seq_len > 512 * 0.8 and not warned_near_limit:
                    print(
                        f"<์ฃผ์> ํ์ฌ ๋ํ ๊ธธ์ด๊ฐ ๊ณง ์ต๋ ๊ธธ์ด์ ๋๋ฌํฉ๋๋ค. ({seq_len} / 512)"
                    )
                    warned_near_limit = True
                # BUG FIX: the original gated the trim itself on switch2, so
                # after the first trim the history was never shortened again
                # and the prompt could grow past the 512-token budget. The
                # flag now only suppresses the repeated warning; trimming
                # continues while the prompt is too long and there is still
                # an old utterance to drop (last element is the pending
                # "P02<sos>" completion slot and must be kept).
                if seq_len >= 512 and len(talk) > 1:
                    if not warned_trimming:
                        print("<์ฃผ์> ๋ํ ๊ธธ์ด๊ฐ ๋๋ฌด ๊ธธ์ด์ก๊ธฐ ๋๋ฌธ์, ์ดํ ๋ํ๋ ๋งจ ์์ ๋ฐํ๋ฅผ ์กฐ๊ธ์ฉ ์ง์ฐ๋ฉด์ ์งํ๋ฉ๋๋ค.")
                        warned_trimming = True
                    talk = talk[1:]
                else:
                    break
            out = model.generate(
                inputs=inpu.input_ids.cuda(),
                attention_mask=inpu.attention_mask.cuda(),
                max_length=512,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=eos_token_id,
            )
            output = tokenizer.batch_decode(out)
            # Echo only the newly generated text; [:-5] drops the trailing
            # '<eos>' marker from the display.
            print("์ฑ๋ด : " + output[0][len(now_inp):-5])
            talk[-1] += output[0][len(now_inp):]
        again = input("๋ค๋ฅธ ๋ํ๋ฅผ ์์ํ ๊น์? (yes : ์๋ก์ด ์์, no : ์ข ๋ฃ) :")
        if again == "no":
            break