# test_space / app.py  (Hugging Face Space, commit 65a4f6a)
# Gradio persona-input demo plus a console-based KoGPT chatbot (`main`).
import random
import time
import warnings

import gradio as gr
def update(name):
    """Build and return the greeting shown when the Run button is clicked."""
    greeting = f"Welcome to Gradio, {name}!"
    return greeting
demo = gr.Blocks()

with demo:
    # Header instruction (Korean): answer each question, then press Enter.
    gr.Markdown(f"๊ฐ ์งˆ๋ฌธ์— ๋Œ€๋‹ต ํ›„ Enter ํ•ด์ฃผ์„ธ์š”.\n\n")
    with gr.Row():
        topic = gr.Textbox(label="Topic", placeholder="๋Œ€ํ™” ์ฃผ์ œ๋ฅผ ์ •ํ•ด์ฃผ์„ธ์š” (e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...)")
    with gr.Row():
        # Persona fields for the first speaker (presumably the user, P01 —
        # mirrors ask_info() in main(); TODO confirm).
        with gr.Column():
            addr = gr.Textbox(label="์ง€์—ญ", placeholder="e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...")
            age = gr.Textbox(label="๋‚˜์ด", placeholder="e.g. 20๋Œ€ ๋ฏธ๋งŒ, 40๋Œ€, 70๋Œ€ ์ด์ƒ, etc...")
            sex = gr.Textbox(label="์„ฑ๋ณ„", placeholder="e.g. ๋‚จ์„ฑ, ์—ฌ์„ฑ, etc...")
        # Second persona column (presumably the chatbot, P02). The original
        # rebound addr/age/sex here, discarding the first column's handles;
        # use distinct names so both sets stay addressable.
        with gr.Column():
            addr2 = gr.Textbox(label="์ง€์—ญ", placeholder="e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...")
            age2 = gr.Textbox(label="๋‚˜์ด", placeholder="e.g. 20๋Œ€ ๋ฏธ๋งŒ, 40๋Œ€, 70๋Œ€ ์ด์ƒ, etc...")
            sex2 = gr.Textbox(label="์„ฑ๋ณ„", placeholder="e.g. ๋‚จ์„ฑ, ์—ฌ์„ฑ, etc...")
    out = gr.Textbox()
    btn = gr.Button("Run")
    # BUG FIX: the original passed `inputs=inp`, but `inp` was never defined,
    # raising NameError when the app is built. Wire the existing topic box
    # (update() takes exactly one argument).
    btn.click(fn=update, inputs=topic, outputs=out)

demo.launch()
def main(model_name):
    """Run an interactive Korean persona chatbot in the console.

    Loads the KoGPT tokenizer plus a causal-LM checkpoint named by
    ``model_name``, collects conversation metadata (topic, then region /
    age / gender for the user P01 and the bot P02) via ``input()``, and
    generates bot replies in a loop until the user types ``end``.

    NOTE(review): this function needs ``AutoTokenizer`` and
    ``AutoModelForCausalLM`` from the third-party ``transformers`` package
    (never imported anywhere in this file — add
    ``from transformers import AutoTokenizer, AutoModelForCausalLM`` at the
    top) and a CUDA device (``model.cuda()``).
    """
    warnings.filterwarnings("ignore")
    # KoGPT tokenizer, extended with dialogue markers (<sep>/<sos>/<eos>)
    # and the corpus' anonymization placeholders (#@...# tokens).
    tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
    special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@์ด๋ฆ„#', '#@๊ณ„์ •#', '#@์‹ ์›#', '#@์ „๋ฒˆ#', '#@๊ธˆ์œต#', '#@๋ฒˆํ˜ธ#', '#@์ฃผ์†Œ#', '#@์†Œ์†#', '#@๊ธฐํƒ€#']}
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Embedding table must grow to cover the tokens added above.
    model.resize_token_embeddings(len(tokenizer))
    model = model.cuda()

    info = ""  # persona/context prefix prepended to every generation prompt
    while True:  # one iteration per conversation session
        if info == "":
            # First session: always collect fresh info.
            print(
                f"์ง€๊ธˆ๋ถ€ํ„ฐ ๋Œ€ํ™” ์ •๋ณด๋ฅผ ์ž…๋ ฅ ๋ฐ›๊ฒ ์Šต๋‹ˆ๋‹ค.\n"
                f"๊ฐ ์งˆ๋ฌธ์— ๋Œ€๋‹ต ํ›„ Enter ํ•ด์ฃผ์„ธ์š”.\n"
                f"์•„๋ฌด ์ž…๋ ฅ ์—†์ด Enter ํ•  ๊ฒฝ์šฐ, ๋ฏธ๋ฆฌ ์ง€์ •๋œ ๊ฐ’ ์ค‘ ๋žœ๋ค์œผ๋กœ ์ •ํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.\n"
            )
            time.sleep(1)
            yon = "no"
        else:
            # Later sessions: offer to reuse the previous persona info.
            yon = input(
                f"์ด์ „ ๋Œ€ํ™” ์ •๋ณด๋ฅผ ๊ทธ๋Œ€๋กœ ์œ ์ง€ํ• ๊นŒ์š”? (yes : ์œ ์ง€, no : ์ƒˆ๋กœ ์ž‘์„ฑ) :"
            )

        if yon == "no":
            info = "์ผ์ƒ ๋Œ€ํ™” "
            topic = input("๋Œ€ํ™” ์ฃผ์ œ๋ฅผ ์ •ํ•ด์ฃผ์„ธ์š” (e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...) :")
            if topic == "":
                # Empty answer -> pick a preset topic at random, echo it.
                topic = random.choice(['์—ฌ๊ฐ€ ์ƒํ™œ', '์‹œ์‚ฌ/๊ต์œก', '๋ฏธ์šฉ๊ณผ ๊ฑด๊ฐ•', '์‹์Œ๋ฃŒ', '์ƒ๊ฑฐ๋ž˜(์‡ผํ•‘)', '์ผ๊ณผ ์ง์—…', '์ฃผ๊ฑฐ์™€ ์ƒํ™œ', '๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„', 'ํ–‰์‚ฌ'])
                print(topic)
            info += topic + "<sep>"

            def ask_info(who, ment):
                # Prompt for one speaker's region/age/gender; empty answers
                # fall back to a random preset. Returns "WHO:region age gender<sep>".
                print(ment)
                text = who + ":"
                addr = input("์–ด๋”” ์‚ฌ์„ธ์š”? (e.g. ์„œ์šธํŠน๋ณ„์‹œ, ์ œ์ฃผ๋„, etc...) :").strip()
                if addr == "":
                    addr = random.choice(['์„œ์šธํŠน๋ณ„์‹œ', '๊ฒฝ๊ธฐ๋„', '๋ถ€์‚ฐ๊ด‘์—ญ์‹œ', '๋Œ€์ „๊ด‘์—ญ์‹œ', '๊ด‘์ฃผ๊ด‘์—ญ์‹œ', '์šธ์‚ฐ๊ด‘์—ญ์‹œ', '๊ฒฝ์ƒ๋‚จ๋„', '์ธ์ฒœ๊ด‘์—ญ์‹œ', '์ถฉ์ฒญ๋ถ๋„', '์ œ์ฃผ๋„', '๊ฐ•์›๋„', '์ถฉ์ฒญ๋‚จ๋„', '์ „๋ผ๋ถ๋„', '๋Œ€๊ตฌ๊ด‘์—ญ์‹œ', '์ „๋ผ๋‚จ๋„', '๊ฒฝ์ƒ๋ถ๋„', '์„ธ์ข…ํŠน๋ณ„์ž์น˜์‹œ', '๊ธฐํƒ€'])
                    print(addr)
                text += addr + " "
                age = input("๋‚˜์ด๊ฐ€? (e.g. 20๋Œ€, 70๋Œ€ ์ด์ƒ, etc...) :").strip()
                if age == "":
                    age = random.choice(['20๋Œ€', '30๋Œ€', '50๋Œ€', '20๋Œ€ ๋ฏธ๋งŒ', '60๋Œ€', '40๋Œ€', '70๋Œ€ ์ด์ƒ'])
                    print(age)
                text += age + " "
                sex = input("์„ฑ๋ณ„์ด? (e.g. ๋‚จ์„ฑ, ์—ฌ์„ฑ, etc... (?)) :").strip()
                if sex == "":
                    sex = random.choice(['๋‚จ์„ฑ', '์—ฌ์„ฑ'])
                    print(sex)
                text += sex + "<sep>"
                return text

            info += ask_info(who="P01", ment=f"\n๋‹น์‹ ์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”.\n")
            info += ask_info(who="P02", ment=f"\n์ฑ—๋ด‡์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”.\n")

        # Echo the collected persona block for confirmation.
        pp = info.replace('<sep>', '\n')
        print(
            f"\n----------------\n"
            f"<์ž…๋ ฅ ์ •๋ณด ํ™•์ธ> (P01 : ๋‹น์‹ , P02 : ์ฑ—๋ด‡)\n"
            f"{pp}"
            f"----------------\n"
            f"๋Œ€ํ™”๋ฅผ ์ข…๋ฃŒํ•˜๊ณ  ์‹ถ์œผ๋ฉด ์–ธ์ œ๋“ ์ง€ 'end' ๋ผ๊ณ  ๋งํ•ด์ฃผ์„ธ์š”~\n"
        )

        talk = []       # alternating utterances, each tagged P01/P02
        switch = True   # one-shot "approaching the limit" warning
        switch2 = True  # one-shot "now trimming history" warning
        while True:  # one iteration per user turn
            inp = "P01<sos>"
            myinp = input("๋‹น์‹  : ")
            if myinp == "end":
                print("๋Œ€ํ™” ์ข…๋ฃŒ!")
                break
            inp += myinp + "<eos>"
            talk.append(inp)
            talk.append("P02<sos>")  # open the bot's turn for generation

            # Keep the prompt under the 512-token generation budget by
            # dropping the oldest utterances.
            while True:
                now_inp = info + "".join(talk)
                inpu = tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
                seq_len = inpu.input_ids.size(1)
                if seq_len > 512 * 0.8 and switch:
                    print(
                        f"<์ฃผ์˜> ํ˜„์žฌ ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๊ณง ์ตœ๋Œ€ ๊ธธ์ด์— ๋„๋‹ฌํ•ฉ๋‹ˆ๋‹ค. ({seq_len} / 512)"
                    )
                    switch = False
                if seq_len >= 512 and len(talk) > 1:
                    # BUG FIX: the original gated the trim itself on switch2,
                    # so history was trimmed only once and later over-long
                    # prompts broke out untrimmed. Only the warning is
                    # one-shot; trimming repeats as long as needed. The
                    # len(talk) > 1 guard keeps the pending "P02<sos>" turn
                    # and prevents an infinite loop if info alone is too long.
                    if switch2:
                        print("<์ฃผ์˜> ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๋„ˆ๋ฌด ๊ธธ์–ด์กŒ๊ธฐ ๋•Œ๋ฌธ์—, ์ดํ›„ ๋Œ€ํ™”๋Š” ๋งจ ์•ž์˜ ๋ฐœํ™”๋ฅผ ์กฐ๊ธˆ์”ฉ ์ง€์šฐ๋ฉด์„œ ์ง„ํ–‰๋ฉ๋‹ˆ๋‹ค.")
                        switch2 = False
                    talk = talk[1:]
                else:
                    break

            out = model.generate(
                inputs=inpu.input_ids.cuda(),
                attention_mask=inpu.attention_mask.cuda(),
                max_length=512,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.encode('<eos>')[0]
            )
            output = tokenizer.batch_decode(out)
            # Show only the newly generated continuation; the [:-5] strips
            # the trailing '<eos>' marker (5 characters).
            print("์ฑ—๋ด‡ : " + output[0][len(now_inp):-5])
            talk[-1] += output[0][len(now_inp):]

        again = input(f"๋‹ค๋ฅธ ๋Œ€ํ™”๋ฅผ ์‹œ์ž‘ํ• ๊นŒ์š”? (yes : ์ƒˆ๋กœ์šด ์‹œ์ž‘, no : ์ข…๋ฃŒ) :")
        if again == "no":
            break