# AlanTsai-0329's picture
# Upload 14 files
# 6a9c006
# raw
# history blame
# 5.87 kB
import re
import jieba
import jieba.analyse
import accelerate
import numpy as np
import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt
import torch.nn.functional as F
from pathlib import Path
from sklearn.feature_extraction.text import TfidfVectorizer
from wordcloud import WordCloud
from transformers import AutoModelForSequenceClassification, BertTokenizerFast, pipeline
# Module-level Accelerator forced to run on CPU (cpu=True); created once at import time.
# NOTE(review): `accelerator` is not referenced anywhere in this chunk — confirm it is still needed.
accelerator = accelerate.Accelerator(cpu=True)
class LoadException(Exception):
    """Raised when a model/tokenizer wrapper is used before it has been loaded."""
class LoadModelException(Exception):
    """Raised when a model wrapper is asked to run before its model has been loaded."""
class LoadTokenizerException(Exception):
    """Intended to signal that a tokenizer could not be loaded."""
class DIR:
    """Relative filesystem locations of the dictionaries and the local model weights."""

    # jieba dictionary, stop words, punctuation list and the word-cloud font live here.
    DICT_DIR = Path("pages/docs/dict")
    # Root folder holding one sub-directory per fine-tuned model.
    MODEL_DIR = Path("pages/docs/model_param")
    CLASSIFIER_MODEL_DIR = MODEL_DIR / "board_classification_model"
    SENTIMENT_MODEL_DIR = MODEL_DIR / "sentiment_analysis_model"
    SUMMARIZATION_MODEL_DIR = MODEL_DIR / "summarization_model"
class Bert_Classify_Model:
    """Classify text into one of four boards with a locally stored BERT classifier.

    Call ``load_model()`` once before ``predict()``.
    """

    def __init__(self):
        self.tokenizer_loaded = False
        self.model_loaded = False

    def load_model(self):
        """Load the tokenizer and the 4-label classifier from ``DIR.CLASSIFIER_MODEL_DIR``.

        Only local files are used (``local_files_only=True``).

        Raises:
            LoadTokenizerException: if the tokenizer cannot be loaded.
            LoadModelException: if the model cannot be loaded.
        """
        try:
            self.tokenizer = BertTokenizerFast.from_pretrained(
                pretrained_model_name_or_path=DIR.CLASSIFIER_MODEL_DIR,
                local_files_only=True,
            )
        except (OSError, ValueError) as err:
            # Raising a plain string is a TypeError in Python 3; wrap the loader's
            # real error in this module's exception type and keep it as the cause.
            raise LoadTokenizerException("Tokenizer not loaded.") from err
        self.tokenizer_loaded = True

        try:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path=DIR.CLASSIFIER_MODEL_DIR,
                local_files_only=True,
                num_labels=4,
            )
        except (OSError, ValueError) as err:
            raise LoadModelException("Model not loaded.") from err
        self.model_loaded = True

    @staticmethod
    def __make_output(outputs):
        """Turn classifier logits into a DataFrame of board probabilities.

        Only the first input row's probabilities are reported. Rows are sorted
        by probability, highest first.
        """
        # NOTE(review): "HatePolotics" looks like a misspelling of "HatePolitics";
        # kept unchanged because it is a user-visible label — confirm.
        id2label = {
            "0": "C_Chat",
            "1": "Gossiping",
            "2": "HatePolotics",
            "3": "Marginalman",
        }
        # Normalise over the label axis explicitly; calling softmax without
        # `dim` is deprecated in PyTorch (for 2-D logits the implicit dim is also -1).
        pred_prob = F.softmax(outputs.logits, dim=-1)
        pred_prob_df = (
            pd.DataFrame({
                "版面": list(id2label.values()),
                "機率": pred_prob[0, :].detach().numpy(),
            })
            .sort_values(by="機率", ascending=False)
        )
        return pred_prob_df

    def predict(self, text):
        """Return the board-probability DataFrame for *text*.

        Raises:
            LoadException: if either the tokenizer or the model is not loaded.
        """
        # Both components are required; the check must fail if EITHER is missing.
        if not (self.tokenizer_loaded and self.model_loaded):
            raise LoadException("Not loaded.")
        token_text = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        outputs = self.model(**token_text)
        return self.__make_output(outputs)
class Sentiment_Model:
    """Wrap a local Hugging Face ``sentiment-analysis`` pipeline.

    Call ``load_model()`` once before ``run_sentiment()``.
    """

    def __init__(self):
        self.model_loaded = False

    def load_model(self):
        """Build the pipeline from ``DIR.SENTIMENT_MODEL_DIR``.

        Raises:
            LoadModelException: if the pipeline cannot be created.
        """
        try:
            self.model = pipeline(
                "sentiment-analysis",
                # Pass the location as a plain string for compatibility with
                # transformers versions that do not accept pathlib.Path here.
                str(DIR.SENTIMENT_MODEL_DIR),
            )
        except (OSError, ValueError) as err:
            # A bare string cannot be raised in Python 3; chain the real cause.
            raise LoadModelException("Model not loaded.") from err
        self.model_loaded = True

    def run_sentiment(self, text):
        """Return the raw pipeline output for *text*.

        Raises:
            LoadModelException: if ``load_model()`` has not completed successfully.
        """
        if not self.model_loaded:
            raise LoadModelException("model not loaded.")
        return self.model(text)
class Summarization_Model:
    """Wrap a local Hugging Face ``summarization`` pipeline.

    Call ``load_model()`` once before ``run_summarize()``.
    """

    def __init__(self):
        self.model_loaded = False

    def load_model(self):
        """Build the pipeline from ``DIR.SUMMARIZATION_MODEL_DIR``.

        Raises:
            LoadModelException: if the pipeline cannot be created.
        """
        try:
            self.model = pipeline(
                "summarization",
                # Plain string for compatibility with transformers versions
                # that do not accept pathlib.Path as the model argument.
                str(DIR.SUMMARIZATION_MODEL_DIR),
            )
        except (OSError, ValueError) as err:
            # A bare string cannot be raised in Python 3; chain the real cause.
            raise LoadModelException("Model not loaded.") from err
        self.model_loaded = True

    @staticmethod
    def __make_output(outputs):
        """Extract the summary text of the first pipeline result."""
        return outputs[0]["summary_text"]

    def run_summarize(self, text, max_length=1024):
        """Summarise *text* and return the summary string.

        Args:
            text: Input text passed to the pipeline.
            max_length: Forwarded to the pipeline as ``max_length`` (default 1024,
                the previously hard-coded value).

        Raises:
            LoadModelException: if ``load_model()`` has not completed successfully.
        """
        if not self.model_loaded:
            raise LoadModelException("model not loaded.")
        outputs = self.model(text, max_length=max_length)
        return self.__make_output(outputs)
class WordCloudDrawer:
    """Build a word cloud from (Traditional Chinese) text using jieba keyword weights.

    Construction configures jieba's global dictionary and stop-word list and
    loads the punctuation list, all from ``DIR.DICT_DIR``.
    """

    def __init__(self):
        jieba.set_dictionary(f'{DIR.DICT_DIR}/dict.txt')  # Traditional Chinese dictionary
        jieba.analyse.set_stop_words(f'{DIR.DICT_DIR}/stopdict.txt')  # stop-word dictionary
        # One punctuation token per line; 'utf-8-sig' drops a leading BOM if present.
        with open(f'{DIR.DICT_DIR}/punctuations.txt', 'r', encoding='utf-8-sig') as f:
            self.punctuation_list = [line.strip() for line in f]

    def __filter(self, word):
        """Return True if *word* is not a listed punctuation/special symbol."""
        return word not in self.punctuation_list

    def __preprocess(self, text):
        """Return ``(keyword, weight)`` pairs for *text*, highest weight first.

        jieba segments the text, removes stop words and weights each keyword;
        listed punctuation is filtered out afterwards (order is preserved).
        """
        tagged = jieba.analyse.extract_tags(text, topK=None, withWeight=True, allowPOS=())
        return [(word, weight) for word, weight in tagged if self.__filter(word)]

    def word_cloud(self, text, num_words):
        """Return a WordCloud of the *num_words* highest-weighted keywords in *text*.

        Raises:
            ValueError: if no keywords remain after stop-word/punctuation filtering.
        """
        # Use jieba's TF-IDF weights directly. The extracted keywords are unique,
        # so re-scoring them with TfidfVectorizer (one keyword per "document")
        # would give every word the same score, and its default token pattern
        # would also drop single-character words.
        keywords = self.__preprocess(text)[:num_words]  # already sorted descending
        if not keywords:
            raise ValueError("No keywords could be extracted from the text.")
        wc = WordCloud(
            background_color='white',
            collocations=False,
            font_path=f'{DIR.DICT_DIR}/SimHei.ttf',
            max_font_size=48,
        )
        return wc.generate_from_frequencies(dict(keywords))
# fig, ax = plt.subplots(figsize = (12, 8))
# plt.imshow(wc)
# plt.axis("off")
# st.pyplot(plt.gcf())
# plt.imshow(wc)
# plt.show()
# st.pyplot()