import streamlit as st
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast


# Load the fine-tuned KoGPT2 model and its tokenizer from the Hugging Face Hub.
model = GPT2LMHeadModel.from_pretrained("jason9693/soongsil-univ-gpt-v1")
tokenizer = PreTrainedTokenizerFast.from_pretrained("jason9693/soongsil-univ-gpt-v1")
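
# Optional: Streamlit reruns this script on every interaction, so caching the model
# load avoids reloading the weights each time. A minimal sketch, assuming a Streamlit
# release that provides st.cache_resource (older versions exposed st.cache instead);
# the helper name below is illustrative, not part of the original app:
#
# @st.cache_resource
# def load_model_and_tokenizer(name="jason9693/soongsil-univ-gpt-v1"):
#     return (
#         GPT2LMHeadModel.from_pretrained(name),
#         PreTrainedTokenizerFast.from_pretrained(name),
#     )
#
# model, tokenizer = load_model_and_tokenizer()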

# Community-board labels shown in the UI, mapped to the special tokens that tag
# the selected board when building the prompt:
#   "์ˆญ์‹ค๋Œ€ ์—ํƒ€"    = Soongsil Univ. Everytime board
#   "๋ชจ๋‘์˜ ์—ฐ์• "    = "Everyone's Romance" board
#   "๋Œ€ํ•™์ƒ ์žก๋‹ด๋ฐฉ"  = college small-talk board
category_map = {
    "์ˆญ์‹ค๋Œ€ ์—ํƒ€": "<unused5>",
    "๋ชจ๋‘์˜ ์—ฐ์• ": "<unused3>",
    "๋Œ€ํ•™์ƒ ์žก๋‹ด๋ฐฉ": "<unused4>"
}

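# App header (Korean): project title and badges, followed by a short description
# stating that the underlying model is SKT-AI's KoGPT2 fine-tuned on about 220,000
# university-community posts (Everytime, Campuspick), which took roughly 3 days to
# train; the final "## ์‹œ์—ฐํ•˜๊ธฐ" heading means "Try it out".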
st.markdown("""# University Community KoGPT2 : ์ˆญ์‹ค๋Œ€ ์—๋ธŒ๋ฆฌํƒ€์ž„๋ด‡

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1p6DIxsesi3eJNPwFwvMw0MeM5LkSGoPW?usp=sharing)	[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jason9693/UCK-GPT2/issues)	![GitHub](https://img.shields.io/github/license/jason9693/UCK-GPT2)

## ๋Œ€ํ•™ ์ปค๋ฎค๋‹ˆํ‹ฐ ๊ฒŒ์‹œ๊ธ€ ์ƒ์„ฑ๊ธฐ

SKT-AI์—์„œ ๊ณต๊ฐœํ•œ [KoGPT2](https://github.com/SKT-AI/KoGPT2) ๋ชจ๋ธ์„ ํŒŒ์ธํŠœ๋‹ํ•˜์—ฌ ๋Œ€ํ•™ ์ปค๋ฎค๋‹ˆํ‹ฐ ๊ฒŒ์‹œ๊ธ€์„ ์ƒ์„ฑํ•˜๋Š” ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. ์ด ์—๋ธŒ๋ฆฌํƒ€์ž„, ์บ ํผ์Šคํ”ฝ ๋ฐ์ดํ„ฐ 22๋งŒ๊ฐœ๋ฅผ ์ด์šฉํ•ด์„œ ํ•™์Šต์„ ์ง„ํ–‰ํ–ˆ์œผ๋ฉฐ, ํ•™์Šต์—๋Š” ๋Œ€๋žต **3์ผ**์ •๋„ ์†Œ์š”๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

* [GPT ๋…ผ๋ฌธ ๋ฆฌ๋ทฐ ๋งํฌ](https://www.notion.so/Improve-Language-Understanding-by-Generative-Pre-Training-GPT-afb4b5ef6e984961ac022b700c152b6b)

## ์‹œ์—ฐํ•˜๊ธฐ
""")


# Demo controls: a Korean seed phrase (default: "์กฐ๋งŒ์‹ ๊ธฐ๋…๊ด€", Cho Man-sik Memorial
# Hall), the target community board, and a button that triggers generation.
seed = st.text_input("Seed", "์กฐ๋งŒ์‹ ๊ธฐ๋…๊ด€")
category = st.selectbox("Category", list(category_map.keys()))
go = st.button("Generate")


st.markdown("## ์ƒ์„ฑ ๊ฒฐ๊ณผ")  # "Generated output"
if go:
    # Prepend the selected board's category token to the seed to build the prompt.
    input_context = category_map[category] + seed
    encodings = tokenizer(input_context, return_tensors="pt")
    outputs = model.generate(
        input_ids=encodings["input_ids"],
        max_length=250,
        num_return_sequences=1,
        no_repeat_ngram_size=3,   # never repeat the same 3-gram
        repetition_penalty=2.0,   # discourage repeated tokens
        do_sample=True,           # sample instead of greedy decoding
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id
    )
    # The model emits <unused2> where the original posts had line breaks;
    # restore real newlines before displaying the text.
    generated = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    st.write(generated.replace("<unused2>", "\n"))
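
# To try the demo locally (assuming this file is saved as app.py and that
# streamlit, transformers, and torch are installed):
#   streamlit run app.py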