alex6095 commited on
Commit
c1eabe2
1 Parent(s): 43b149d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import re
3
  import streamlit as st
 
4
 
5
  from transformers import DistilBertForSequenceClassification
6
  from tokenization_kobert import KoBertTokenizer
@@ -8,12 +9,15 @@ from tokenization_kobert import KoBertTokenizer
8
 
9
  tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
10
 
 
11
  @st.cache(allow_output_mutation=True)
12
  def get_model():
13
- model = DistilBertForSequenceClassification.from_pretrained('alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9)
 
14
  model.eval()
15
  return model
16
 
 
17
  class RegexSubstitution(object):
18
  """Regex substitution class for transform"""
19
 
@@ -23,10 +27,10 @@ class RegexSubstitution(object):
23
  else:
24
  self.regex = re.compile(regex)
25
  self.sub = sub
26
-
27
  def __call__(self, target):
28
  if isinstance(target, list):
29
- return [ self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target ]
30
  else:
31
  return self.regex.sub(self.sub, self.regex.sub(self.sub, target))
32
 
@@ -41,21 +45,23 @@ topics_raw = ['IT/과학', '경제', '문화', '미용/건강', '사회', '생
41
 
42
  model = get_model()
43
 
44
- st.title("Topic estimate Model Test")
45
 
46
  text = st.text_area("Input news :", value=default_text)
47
 
48
  st.markdown("## Original News Data")
49
  st.write(text)
50
 
 
 
 
51
  if text:
52
- st.markdown("## Predict Topic")
53
  with st.spinner('processing..'):
54
  text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text)
55
  encoded_dict = tokenizer(
56
  text=text,
57
  add_special_tokens=True,
58
- max_length = 512,
59
  truncation=True,
60
  return_tensors='pt',
61
  return_length=True
@@ -68,4 +74,13 @@ if text:
68
 
69
  _, preds = torch.max(outputs.logits, 1)
70
 
71
- st.write(topics_raw[preds.squeeze(0)])
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import re
3
  import streamlit as st
4
+ import pandas as pd
5
 
6
  from transformers import DistilBertForSequenceClassification
7
  from tokenization_kobert import KoBertTokenizer
 
9
 
10
  tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
11
 
12
+
13
  @st.cache(allow_output_mutation=True)
14
  def get_model():
15
+ model = DistilBertForSequenceClassification.from_pretrained(
16
+ 'alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9)
17
  model.eval()
18
  return model
19
 
20
+
21
  class RegexSubstitution(object):
22
  """Regex substitution class for transform"""
23
 
 
27
  else:
28
  self.regex = re.compile(regex)
29
  self.sub = sub
30
+
31
  def __call__(self, target):
32
  if isinstance(target, list):
33
+ return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target]
34
  else:
35
  return self.regex.sub(self.sub, self.regex.sub(self.sub, target))
36
 
 
45
 
46
  model = get_model()
47
 
48
+ st.title("News Topic Classification")
49
 
50
  text = st.text_area("Input news :", value=default_text)
51
 
52
  st.markdown("## Original News Data")
53
  st.write(text)
54
 
55
+ st.markdown("## Predict Topic")
56
+ col1, col2 = st.columns(2)
57
+
58
  if text:
 
59
  with st.spinner('processing..'):
60
  text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text)
61
  encoded_dict = tokenizer(
62
  text=text,
63
  add_special_tokens=True,
64
+ max_length=512,
65
  truncation=True,
66
  return_tensors='pt',
67
  return_length=True
 
74
 
75
  _, preds = torch.max(outputs.logits, 1)
76
 
77
+ col1.write(topics_raw[preds.squeeze(0)])
78
+ softmax = torch.nn.Softmax(dim=1)
79
+ prob = softmax(outputs.logits).squeeze(0).detach()
80
+ chart_data = pd.DataFrame({
81
+ 'Topic': topics_raw,
82
+ 'Probability': prob
83
+ })
84
+ chart_data = chart_data.set_index('Topic')
85
+
86
+ col2.bar_chart(chart_data)