# Author: Unggi
# Commit: "layout complete"
# Revision: 8bfc8f9
import functools
# import itertools
import random

import gradio as gr
import numpy as np
import pandas as pd
# make function using import pip to install torch
import pip
# pip.main(['install', 'torch'])
# pip.main(['install', 'transformers'])
import torch
import transformers
# Load the pretrained KoGPT2 model and its tokenizer
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the KoGPT2 model and tokenizer, once per process.

    Returns:
        A ``(model, tokenizer)`` tuple: a ``transformers.GPT2LMHeadModel`` and a
        ``transformers.PreTrainedTokenizerFast`` for ``skt/kogpt2-base-v2``.

    The result is cached. Callers that invoke this once per request reuse the
    same objects instead of re-reading the tokenizer and model weights every
    time.
    """
    pretrained_model_name = "skt/kogpt2-base-v2"
    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(
        pretrained_model_name,  # kogpt2
        # KoGPT2's special tokens are None unless they are specified up front,
        # so they must always be passed explicitly here.
        bos_token="</s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
    )
    model = transformers.GPT2LMHeadModel.from_pretrained(
        pretrained_model_name  # kogpt2
    )
    # Make the embedding matrix match the tokenizer's vocabulary size
    # (including the special tokens configured above).
    model.resize_token_embeddings(len(tokenizer))
    return model, tokenizer
# main: generate candidate continuations for the input text
def inference(prompt, max_new_tokens=64):
    """Generate three candidate next sentences for ``prompt`` with KoGPT2.

    Args:
        prompt: The text written so far (the "Input" textbox value).
        max_new_tokens: Upper bound on the number of tokens sampled for each
            candidate, independent of how long ``prompt`` is.

    Returns:
        A list of three strings, one per output textbox. Each string is the
        text of a sampled continuation up to (not including) its first '.'.
        The prompt itself is never part of the returned text. An empty or
        whitespace-only prompt yields three empty strings without calling
        the model.
    """
    num_candidates = 3  # one per output textbox in the UI
    if not prompt or not prompt.strip():
        # Nothing to continue from; skip the model entirely.
        return [""] * num_candidates

    model, tokenizer = load_model()
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # The essay grows every time a suggestion is accepted. Keep only the most
    # recent tokens so that prompt + new tokens fit in the model's positional
    # window (config.n_positions).
    max_prompt_len = max(1, model.config.n_positions - max_new_tokens)
    input_ids = input_ids[:, -max_prompt_len:]
    prompt_len = input_ids.shape[-1]

    gen_ids = model.generate(
        input_ids,
        # max_new_tokens (not max_length) so long prompts still get new text.
        max_new_tokens=max_new_tokens,
        repetition_penalty=2.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        use_cache=True,
        do_sample=True,
        top_k=50,
        top_p=0.92,
        num_return_sequences=num_candidates,
    )

    outputs = []
    for gen_id in gen_ids:
        # Each returned sequence is prompt tokens followed by generated tokens.
        # Decode only the generated part instead of string-replacing the prompt,
        # which would also delete repeats of the prompt inside the continuation.
        continuation = tokenizer.decode(
            gen_id[prompt_len:].tolist(), skip_special_tokens=True
        )
        # Keep only the first sentence as the suggestion.
        outputs.append(continuation.split(".")[0])
    return outputs
def restore(inputs, outputs):
    """Return the current input with the selected candidate appended to it.

    Args:
        inputs: Text currently in the input box.
        outputs: The candidate text the user chose.

    Returns:
        ``inputs + outputs``.
    """
    return inputs + outputs
# Gradio UI: an input box on the left and three candidate continuations on
# the right. Each candidate has a button that appends it to the input.
demo = gr.Blocks()
with demo:
    gr.Markdown("## Kogpt2 Generation for Writing Education")
    with gr.Row():
        text_input = gr.Textbox(lines=20, label="Input")
        with gr.Column():
            with gr.Box():
                text_output1 = gr.Textbox(lines=1, label="Output1")
                output1_btn = gr.Button("Select output1")
            with gr.Box():
                text_output2 = gr.Textbox(lines=1, label="Output2")
                output2_btn = gr.Button("Select output2")
            with gr.Box():
                text_output3 = gr.Textbox(lines=1, label="Output3")
                output3_btn = gr.Button("Select output3")
    text_button = gr.Button("Generate")

    # Generate fills the three candidate boxes from the current input.
    text_button.click(
        inference,
        inputs=[text_input],
        outputs=[text_output1, text_output2, text_output3],
    )
    # Each "Select" button appends its candidate to the input text.
    # Listeners are registered inside the Blocks context, in the same order as before.
    for candidate_box, select_btn in (
        (text_output1, output1_btn),
        (text_output2, output2_btn),
        (text_output3, output3_btn),
    ):
        select_btn.click(
            restore,
            inputs=[text_input, candidate_box],
            outputs=[text_input],
        )

if __name__ == "__main__":
    # Setting launch(share=True) creates a publicly accessible link.
    demo.launch()