Spaces:
Runtime error
Runtime error
import numpy as np | |
#import itertools | |
import gradio as gr | |
import pandas as pd | |
# make function using import pip to install torch | |
import pip | |
#pip.main(['install', 'torch']) | |
#pip.main(['install', 'transformers']) | |
import torch | |
import transformers | |
import random | |
# saved_model | |
# Loaded (model, tokenizer) pairs keyed by checkpoint name, so the expensive
# from_pretrained calls run once per process rather than once per request.
_MODEL_CACHE = {}


def load_model():
    """Return the KoGPT2 language model and its tokenizer.

    The pair is built on the first call and reused afterwards; rebuilding it
    for every request would re-read the whole checkpoint each time.

    Returns:
        A ``(model, tokenizer)`` tuple for the ``skt/kogpt2-base-v2``
        checkpoint, with the model's embedding matrix resized to the
        tokenizer's vocabulary size.
    """
    pretrained_model_name = "skt/kogpt2-base-v2"
    if pretrained_model_name not in _MODEL_CACHE:
        tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(
            pretrained_model_name,
            # KoGPT2 leaves its special tokens as None unless they are given
            # up front, so they must be specified explicitly here.
            bos_token='</s>', eos_token='</s>', unk_token='<unk>',
            pad_token='<pad>', mask_token='<mask>'
        )
        model = transformers.GPT2LMHeadModel.from_pretrained(
            pretrained_model_name
        )
        # Keep the embedding table in step with the tokenizer's vocabulary.
        model.resize_token_embeddings(len(tokenizer))
        _MODEL_CACHE[pretrained_model_name] = (model, tokenizer)
    return _MODEL_CACHE[pretrained_model_name]
# main | |
def inference(prompt):
    """Sample three candidate continuations of *prompt* with KoGPT2.

    Args:
        prompt: The text written so far.

    Returns:
        A list of three strings. Each is the text generated after the
        prompt, cut off at the first ``'.'`` (the period itself is dropped).
        A blank prompt yields three empty strings.
    """
    num_candidates = 3

    # A blank prompt gives the model nothing to continue; answer with empty
    # candidates instead of pushing an empty tensor through generate().
    if not prompt or not prompt.strip():
        return [""] * num_candidates

    model, tokenizer = load_model()
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    prompt_length = input_ids.shape[-1]

    gen_ids = model.generate(
        input_ids,
        # max_length would count the prompt tokens as well, so a long prompt
        # would leave no room for new text; budget only the continuation.
        max_new_tokens=128,
        repetition_penalty=2.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        use_cache=True,
        do_sample=True,
        top_k=50,
        top_p=0.92,
        num_return_sequences=num_candidates,
    )

    outputs = []
    for gen_id in gen_ids:
        # Decode only the tokens that follow the prompt. Removing the prompt
        # by string replacement would also delete any later repetition of it
        # and would remove nothing if decoding does not reproduce the prompt
        # text exactly.
        continuation = tokenizer.decode(
            gen_id[prompt_length:].tolist(), skip_special_tokens=True
        )
        # Keep just the first sentence of the continuation.
        outputs.append(continuation.split('.')[0])
    return outputs
def restore(inputs, outputs):
    """Return *outputs* appended to the end of *inputs*.

    Used by the "Select" buttons to append a chosen candidate to the text
    already in the input box.
    """
    return inputs + outputs
# Gradio UI: an input box on the left, three generated candidates on the
# right (each with a button that appends it to the input), and a "Generate"
# button that fills the three candidate boxes.
# NOTE(review): gr.Box may be unavailable in newer Gradio releases - confirm
# against the version this app is deployed with.
with gr.Blocks() as demo:
    gr.Markdown("## Kogpt2 Generation for Writing Education")
    with gr.Row():
        text_input = gr.Textbox(lines=20, label="Input")
        with gr.Column():
            with gr.Box():
                text_output1 = gr.Textbox(lines=1, label="Output1")
                output1_btn = gr.Button("Select output1")
            with gr.Box():
                text_output2 = gr.Textbox(lines=1, label="Output2")
                output2_btn = gr.Button("Select output2")
            with gr.Box():
                text_output3 = gr.Textbox(lines=1, label="Output3")
                output3_btn = gr.Button("Select output3")
    text_button = gr.Button("Generate")

    # "Generate" produces one candidate per output box.
    text_button.click(
        inference,
        inputs=[text_input],
        outputs=[text_output1, text_output2, text_output3]
    )

    # Each "Select" button appends its candidate to the input text.
    for _candidate_box, _select_btn in (
        (text_output1, output1_btn),
        (text_output2, output2_btn),
        (text_output3, output3_btn),
    ):
        _select_btn.click(
            restore,
            inputs=[text_input, _candidate_box],
            outputs=[text_input]
        )

# Passing share=True to launch() creates a link reachable from outside.
demo.launch()