import torch
import re
import streamlit as st
import pandas as pd
from transformers import PreTrainedTokenizerFast, DistilBertForSequenceClassification, BartForConditionalGeneration

from tokenization_kobert import KoBertTokenizer
from tokenizers import SentencePieceBPETokenizer


@st.cache(allow_output_mutation=True)
def get_topic():
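    """Load the cached DistilBERT topic-classification model and the KoBERT tokenizer."""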
    model = DistilBertForSequenceClassification.from_pretrained(
        'alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9)
    model.eval()
    tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
    return model, tokenizer

@st.cache(allow_output_mutation=True)
def get_date():
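    """Load the cached KoBART model used for date prediction and its tokenizer."""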
    model = BartForConditionalGeneration.from_pretrained('alex6095/SanctiMoly-Bart')
    model.eval()
    tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization')
    return model, tokenizer

class RegexSubstitution(object):
    """Regex substitution class for transform"""
    def __init__(self, regex, sub=''):
        if isinstance(regex, re.Pattern):
            self.regex = regex
        else:
            self.regex = re.compile(regex)
        self.sub = sub
    def __call__(self, target):
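        # The substitution is applied twice so that matches exposed by the first pass
        # (e.g. nested parentheses) are also removed.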
        if isinstance(target, list):
            return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target]
        else:
            return self.regex.sub(self.sub, self.regex.sub(self.sub, target))
            
default_text = '''질병관리청은 23일 지방자치단체가 보건당국과 협의 없이 단독으로 인플루엔자(독감) 백신 접종 중단을 결정해서는 안 된다는 입장을 밝혔다.
    질병청은 이날 참고자료를 배포하고 “향후 전체 국가 예방접종사업이 차질 없이 진행되도록 지자체가 자체적으로 접종 유보 여부를 결정하지 않도록 안내를 했다”고 설명했다.
    독감백신을 접종한 후 고령층을 중심으로 전국에서 사망자가 잇따르자 서울 영등포구보건소는 전날, 경북 포항시는 이날 관내 의료기관에 접종을 보류해달라는 공문을 내려보냈다. 이는 예방접종과 사망 간 직접적 연관성이 낮아 접종을 중단할 상황은 아니라는 질병청의 판단과는 다른 것이다.
    질병청은 지난 21일 전문가 등이 참여한 ‘예방접종 피해조사반’의 분석 결과를 바탕으로 독감 예방접종 사업을 일정대로 진행하기로 했다. 특히 고령 어르신과 어린이, 임신부 등 독감 고위험군은 백신을 접종하지 않았을 때 합병증 피해가 클 수 있다면서 접종을 독려했다. 하지만 접종사업 유지 발표 이후에도 사망 보고가 잇따르자 질병청은 이날 ‘예방접종 피해조사반 회의’와 ‘예방접종 전문위원회’를 개최해 독감백신과 사망 간 관련성, 접종사업 유지 여부 등에 대해 다시 결론 내리기로 했다. 회의 결과는 이날 오후 7시 넘어 발표될 예정이다.
'''
topics_raw = ['IT/과학', '경제', '문화', '미용/건강', '사회', '생활', '스포츠', '연예', '정치']


topic_model, topic_tokenizer = get_topic()
date_model, date_tokenizer = get_date()
st.sidebar.header('Menu')

name = st.sidebar.selectbox('Model', ['Topic Classification', 'Date Prediction'])

if name == 'Topic Classification':
    title = 'News Topic Classification'
    model, tokenizer = topic_model, topic_tokenizer
elif name == 'Date Prediction':
    title = 'News Date Prediction'
    model, tokenizer = date_model, date_tokenizer

st.title(title)

text = st.text_area("Input news :", value=default_text)
st.markdown("## Original News Data")
st.write(text)

if name == 'Topic Classification':
    st.markdown("## Predict Topic")
    col1, col2 = st.columns(2)
    if text:
        with st.spinner('processing..'):
            text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text)
            encoded_dict = tokenizer(
                text=text,
                add_special_tokens=True,
                max_length=512,
                truncation=True,
                return_tensors='pt',
                return_length=True
            )
            input_ids = encoded_dict['input_ids']
            input_ids_len = encoded_dict['length'].unsqueeze(0)
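            # Build an attention mask that marks positions up to the tokenized length as valid.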
            attn_mask = torch.arange(input_ids.size(1))
            attn_mask = attn_mask[None, :] < input_ids_len[:, None]
            outputs = model(input_ids=input_ids, attention_mask=attn_mask)
            _, preds = torch.max(outputs.logits, 1)
        col1.write(topics_raw[preds.squeeze(0)])
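        # Display per-topic probabilities as a bar chart next to the predicted label.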
        softmax = torch.nn.Softmax(dim=1)
        prob = softmax(outputs.logits).squeeze(0).detach()
        chart_data = pd.DataFrame({
            'Topic': topics_raw,
            'Probability': prob
        })
        chart_data = chart_data.set_index('Topic')
        col2.bar_chart(chart_data)
        
elif name == 'Date Prediction':
    st.markdown("## Predict 3 possible Date")
    if text:
        with st.spinner('processing..'):
            text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text)
            raw_input_ids = tokenizer.encode(text)
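            # Wrap the encoded article with the BOS/EOS tokens that the BART model expects.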
            input_ids = [tokenizer.bos_token_id] + \
                raw_input_ids + [tokenizer.eos_token_id]
            outputs = model.generate(torch.tensor([input_ids]),
                                     early_stopping=True,
                                     do_sample=True,         # use sampling instead of greedy decoding
                                     max_length=50,          # maximum decoding length of 50 tokens
                                     top_k=50,               # sample only from the 50 highest-probability tokens
                                     top_p=0.95,             # sample only from the smallest set with cumulative probability 0.95
                                     num_return_sequences=3  # decode 3 candidate outputs
                                     )
        pred_print = []
        for output in outputs:
            pred_print.append(tokenizer.decode(output.squeeze().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True))
        st.write(", ".join(pred_print))