alex6095 committed on
Commit
962c41a
•
1 Parent(s): 7f9ccdf

Create app.py

Files changed (1)
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
+ import torch
+ import re
+ import streamlit as st
+ import pandas as pd
+ from transformers import PreTrainedTokenizerFast, DistilBertForSequenceClassification, BartForConditionalGeneration
+
+ from tokenization_kobert import KoBertTokenizer
+ from tokenizers import SentencePieceBPETokenizer
+
+
+ @st.cache(allow_output_mutation=True)
+ def get_topic():
+     model = DistilBertForSequenceClassification.from_pretrained(
+         'alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9)
+     model.eval()
+     tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
+     return model, tokenizer
+
+ @st.cache(allow_output_mutation=True)
+ def get_date():
+     model = BartForConditionalGeneration.from_pretrained('alex6095/SanctiMoly-Bart')
+     model.eval()
+     tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization')
+     return model, tokenizer
+
+ class RegexSubstitution(object):
+     """Regex substitution class for transform"""
+     def __init__(self, regex, sub=''):
+         if isinstance(regex, re.Pattern):
+             self.regex = regex
+         else:
+             self.regex = re.compile(regex)
+         self.sub = sub
+     def __call__(self, target):
+         if isinstance(target, list):
+             return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target]
+         else:
+             return self.regex.sub(self.sub, self.regex.sub(self.sub, target))
+
+ default_text = '''질병관리청은 23일 지방자치단체가 보건당국과 협의 없이 단독으로 인플루엔자(독감) 백신 접종 중단을 결정해서는 안 된다는 입장을 밝혔다.
+ 질병청은 이날 참고자료를 배포하고 “향후 전체 국가 예방접종사업이 차질 없이 진행되도록 지자체가 자체적으로 접종 유보 여부를 결정하지 않도록 안내를 했다”고 설명했다.
+ 독감백신을 접종한 후 고령층을 중심으로 전국에서 사망자가 잇따르자 서울 영등포구보건소는 전날, 경북 포항시는 이날 관내 의료기관에 접종을 보류해달라는 공문을 내려보냈다. 이는 예방접종과 사망 간 직접적 연관성이 낮아 접종을 중단할 상황은 아니라는 질병청의 판단과는 다른 것이다.
+ 질병청은 지난 21일 전문가 등이 참여한 ‘예방접종 피해조사반’의 분석 결과를 바탕으로 독감 예방접종 사업을 일정대로 진행하기로 했다. 특히 고령 어르신과 어린이, 임신부 등 독감 고위험군은 백신을 접종하지 않았을 때 합병증 피해가 클 수 있다면서 접종을 독려했다. 하지만 접종사업 유지 발표 이후에도 사망 보고가 잇따르자 질병청은 이날 ‘예방접종 피해조사반 회의’와 ‘예방접종 전문위원회’를 개최해 독감백신과 사망 간 관련성, 접종사업 유지 여부 등에 대해 다시 결론 내리기로 했다. 회의 결과는 이날 오후 7시 넘어 발표될 예정이다.
+ '''
+ topics_raw = ['IT/과학', '경제', '문화', '미용/건강', '사회', '생활', '스포츠', '연예', '정치']
+
+
+ #topic_model, topic_tokenizer = get_topic()
+ #date_model, date_tokenizer = get_date()
+
+ name = st.sidebar.selectbox('Model', ['Topic Classification', 'Date Prediction'])
+ if name == 'Topic Classification':
+     title = 'News Topic Classification'
+     model, tokenizer = get_topic()
+ elif name == 'Date Prediction':
+     title = 'News Date Prediction'
+     model, tokenizer = get_date()
+
+ st.title(title)
+
+ text = st.text_area("Input news :", value=default_text)
+ st.markdown("## Original News Data")
+ st.write(text)
+
+ if name == 'Topic Classification':
+     st.markdown("## Predict Topic")
+     col1, col2 = st.columns(2)
+     if text:
+         with st.spinner('processing..'):
+             text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text)
+             encoded_dict = tokenizer(
+                 text=text,
+                 add_special_tokens=True,
+                 max_length=512,
+                 truncation=True,
+                 return_tensors='pt',
+                 return_length=True
+             )
+             input_ids = encoded_dict['input_ids']
+             input_ids_len = encoded_dict['length'].unsqueeze(0)
+             attn_mask = torch.arange(input_ids.size(1))
+             attn_mask = attn_mask[None, :] < input_ids_len[:, None]
+             outputs = model(input_ids=input_ids, attention_mask=attn_mask)
+             _, preds = torch.max(outputs.logits, 1)
+             col1.write(topics_raw[preds.squeeze(0)])
+             softmax = torch.nn.Softmax(dim=1)
+             prob = softmax(outputs.logits).squeeze(0).detach()
+             chart_data = pd.DataFrame({
+                 'Topic': topics_raw,
+                 'Probability': prob
+             })
+             chart_data = chart_data.set_index('Topic')
+             col2.bar_chart(chart_data)
+
+ elif name == 'Date Prediction':
+     st.markdown("## Predict Date")
+     if text:
+         with st.spinner('processing..'):
+             text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text)
+             raw_input_ids = tokenizer.encode(text)
+             input_ids = [tokenizer.bos_token_id] + \
+                 raw_input_ids + [tokenizer.eos_token_id]
+             outputs = model.generate(torch.tensor([input_ids]),
+                                      early_stopping=True,
+                                      repetition_penalty=2.0,
+                                      do_sample=True,  # use a sampling strategy
+                                      max_length=50,  # maximum decoding length is 50
+                                      top_k=50,  # exclude tokens outside the top-50 probability rank from sampling
+                                      top_p=0.95,  # sample only from the 95% cumulative-probability candidate set
+                                      num_return_sequences=3  # decode three candidate sequences
+                                      )
+             for output in outputs:
+                 pred_print = tokenizer.decode(output.squeeze().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
+                 st.write(pred_print)
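
A minimal usage sketch (not part of the commit): assuming the script is run from this repo so that tokenization_kobert.py is importable and the alex6095/SanctiMolyTopic checkpoint is reachable, the topic classifier wired up in app.py can be exercised outside Streamlit roughly as follows. The sample sentence is a hypothetical placeholder.

import torch
from transformers import DistilBertForSequenceClassification
from tokenization_kobert import KoBertTokenizer  # local module shipped with this Space

topics_raw = ['IT/과학', '경제', '문화', '미용/건강', '사회', '생활', '스포츠', '연예', '정치']

# Load the same checkpoint and tokenizer that app.py uses for topic classification.
model = DistilBertForSequenceClassification.from_pretrained(
    'alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9)
model.eval()
tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')

# Classify one (placeholder) news snippet and print the top topic label.
enc = tokenizer("백신 접종 관련 뉴스 본문", truncation=True, max_length=512, return_tensors='pt')
with torch.no_grad():
    logits = model(input_ids=enc['input_ids'], attention_mask=enc['attention_mask']).logits
print(topics_raw[logits.argmax(dim=1).item()])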