bokobring committed on
Commit
9f90d6f
1 Parent(s): a6ecb1c

Upload 4 files

Files changed (5)
  1. .gitattributes +1 -0
  2. bert.ckpt +3 -0
  3. simhei.ttf +3 -0
  4. 主题建模.py +75 -0
  5. 情感.py +190 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ simhei.ttf filter=lfs diff=lfs merge=lfs -text
bert.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:008ac6636cfb61dfa98817a179c3c91ffd9cd30ac1f64ebb82332268b50eeab5
+ size 409130040
simhei.ttf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:67d5e7ed33195ead7334237ba5b71ca140ddfd75b0c4bbae19b895c18946a114
+ size 10050870
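
Both bert.ckpt and simhei.ttf are stored through Git LFS, so the commit only adds the pointer files above (an object id plus a byte size). A minimal sketch, not part of this commit, for checking that a local clone has the real payloads rather than unresolved pointers; it relies only on the oid/size values recorded in the pointers:

import hashlib
import os

# expected (sha256 oid, size in bytes) copied from the LFS pointer files above
EXPECTED = {
    'bert.ckpt': ('008ac6636cfb61dfa98817a179c3c91ffd9cd30ac1f64ebb82332268b50eeab5', 409130040),
    'simhei.ttf': ('67d5e7ed33195ead7334237ba5b71ca140ddfd75b0c4bbae19b895c18946a114', 10050870),
}

for name, (oid, size) in EXPECTED.items():
    digest = hashlib.sha256()
    with open(name, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):  # hash in 1 MiB chunks
            digest.update(chunk)
    ok = digest.hexdigest() == oid and os.path.getsize(name) == size
    print(name, 'OK' if ok else 'mismatch - run `git lfs pull`')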
主题建模.py ADDED
@@ -0,0 +1,75 @@
+ import os
+ import jieba
+ import gensim
+ from collections import Counter
+ from wordcloud import WordCloud
+ import matplotlib.pyplot as plt
+ from termcolor import colored
+ import re
+ from tqdm import tqdm
+
+ def load_stopwords(filepath):
+     with open(filepath, 'r', encoding='utf-8') as f:
+         return {line.strip() for line in f}
+
+ def text_to_token(file_path, stopwords, user_dict=None):
+     if user_dict:
+         jieba.load_userdict(user_dict)
+     with open(file_path, 'r', encoding='utf-8') as f:
+         lines = f.readlines()
+     tokens = []
+     for i in tqdm(range(len(lines)), desc='Tokenizing'):
+         line = lines[i]
+         if line.strip():
+             line = re.sub(r'\[.*?\]|\s', '', line)  # drop bracketed emoticons and whitespace
+             token = jieba.lcut(line)
+             token = [word for word in token if word not in stopwords and len(word) > 1]
+             if token:
+                 tokens.append(token)
+     return tokens
+
+ def get_topics(tokens, num_topics, num_words):
+     dictionary = gensim.corpora.Dictionary(tokens)
+     corpus = [dictionary.doc2bow(token) for token in tokens]
+     lda = gensim.models.LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
+     topics = lda.print_topics(num_words=num_words)
+     return topics
+
+ def plot_wordcloud(tokens):
+     counter = Counter([word for token in tokens for word in token])
+     # 'simhei.ttf' is the font file shipped in this commit (case must match on case-sensitive filesystems)
+     wordcloud = WordCloud(font_path='simhei.ttf', width=800, height=600).generate_from_frequencies(counter)
+     plt.imshow(wordcloud, interpolation='bilinear')
+     plt.axis("off")
+     plt.show()
+
+ def main():
+     try:
+         file_path = input("请输入评论文件名或其路径:")
+         output_file = input("请输入输出文件名或其路径:")
+         if not os.path.exists(file_path):
+             print(colored('错误:文件不存在,请检查文件名或路径是否输入正确', 'red'))
+             return
+         num_topics = int(input("请输入生成主题的数量:"))
+         num_words = int(input("请输入每个主题下的词汇数:"))
+         user_dict = input("若有自定义词典文件,请输入其文件名或路径(没有则直接回车):")
+         stopwords = load_stopwords('stopwords.txt')
+         tokens = text_to_token(file_path, stopwords, user_dict)
+
+         print(colored("正在生成主题,请稍后...", 'green'))
+         topics = get_topics(tokens, num_topics, num_words)
+
+         with open(output_file, 'w', encoding='utf-8') as f:
+             for idx, topic in topics:
+                 topic_str = f'主题{idx + 1}:' + topic.replace("+", "\n")
+                 print(colored(f'主题{idx + 1}:', 'blue') + topic.replace("+", "\n"))
+                 f.write(topic_str + '\n')  # keep the file output free of ANSI color codes
+
+         print(colored("正在生成词云图...", 'green'))
+         plot_wordcloud(tokens)
+
+     except Exception as e:
+         print(colored("出现错误:", 'red'), str(e))
+
+ if __name__ == '__main__':
+     print(colored("欢迎使用评论文字分析工具!\n", 'cyan'))
+     main()
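
主题建模.py is an interactive LDA pipeline: it tokenizes a comment file with jieba, filters stopwords, fits a gensim LdaModel, writes the topics to an output file, and draws a word cloud with the bundled simhei.ttf. A minimal non-interactive sketch of the same flow, assuming the script sits next to a hypothetical comments.txt and the stopwords.txt it expects:

import importlib

tm = importlib.import_module('主题建模')  # load the script above as a module

stopwords = tm.load_stopwords('stopwords.txt')               # stopword list the script expects
tokens = tm.text_to_token('comments.txt', stopwords)         # jieba tokenization + filtering
topics = tm.get_topics(tokens, num_topics=5, num_words=10)   # gensim LDA
for idx, topic in topics:
    print(f'Topic {idx + 1}:', topic)
tm.plot_wordcloud(tokens)                                    # word cloud rendered with simhei.ttf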
情感.py ADDED
@@ -0,0 +1,190 @@
+ # coding=utf-8
+ import os
+ import re
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from transformers import BertModel, BertTokenizer
+ from tqdm import tqdm
+
+
+ class Config(object):
+     """Configuration parameters"""
+     def __init__(self):
+         self.model_name = 'bert'
+         self.class_list = ['中性', '积极', '消极']  # class names (neutral / positive / negative)
+         self.save_path = 'bert.ckpt'  # trained model checkpoint
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device
+
+         self.require_improvement = 1000
+         self.num_classes = len(self.class_list)  # number of classes
+         self.num_epochs = 3
+         self.batch_size = 128
+         self.pad_size = 32
+         self.learning_rate = 5e-5
+         self.bert_path = 'bert-base-chinese'
+         self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
+         self.hidden_size = 768
+
+
+ class Model(nn.Module):
+     def __init__(self, config):
+         super(Model, self).__init__()
+         self.bert = BertModel.from_pretrained(config.bert_path)
+         for param in self.bert.parameters():
+             param.requires_grad = True
+         self.fc = nn.Linear(config.hidden_size, config.num_classes)
+
+     def forward(self, x):
+         context = x[0]  # token ids
+         mask = x[2]     # attention mask
+         outputs = self.bert(context, attention_mask=mask)
+         pooled = outputs[1]  # pooled [CLS] representation
+         out = self.fc(pooled)
+         return out
+
+
+ def clean(text):
+     URL_REGEX = re.compile(
+         r'(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:\'".,<>?«»“”‘’]))',
+         re.IGNORECASE)
+     text = re.sub(URL_REGEX, "", text)
+     text = text.replace("转发微博", "")
+     text = re.sub(r"\s+", " ", text)
+     return text.strip()
+
+
+ def load_dataset(data, config):
+     pad_size = config.pad_size
+     contents = []
+     for line in data:
+         lin = clean(line)
+         token = config.tokenizer.tokenize(lin)
+         token = ['[CLS]'] + token
+         seq_len = len(token)
+         mask = []
+         token_ids = config.tokenizer.convert_tokens_to_ids(token)
+         if pad_size:
+             if len(token) < pad_size:
+                 mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
+                 token_ids += ([0] * (pad_size - len(token)))
+             else:
+                 mask = [1] * pad_size
+                 token_ids = token_ids[:pad_size]
+                 seq_len = pad_size
+         contents.append((token_ids, 0, seq_len, mask))  # label 0 is a placeholder at inference time
+     return contents
+
+
+ class DatasetIterater(object):
+     def __init__(self, batches, batch_size, device):
+         self.batch_size = batch_size
+         self.batches = batches
+         self.n_batches = len(batches) // batch_size
+         self.residue = False
+         if len(batches) % self.batch_size != 0:  # a leftover partial batch needs one extra iteration
+             self.residue = True
+         self.index = 0
+         self.device = device
+
+     def _to_tensor(self, datas):
+         x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
+         y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
+
+         seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
+         mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
+         return (x, seq_len, mask), y
+
+     def __next__(self):
+         if self.residue and self.index == self.n_batches:
+             batches = self.batches[self.index * self.batch_size: len(self.batches)]
+             self.index += 1
+             batches = self._to_tensor(batches)
+             return batches
+
+         elif self.index >= self.n_batches:
+             self.index = 0
+             raise StopIteration
+         else:
+             batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
+             self.index += 1
+             batches = self._to_tensor(batches)
+             return batches
+
+     def __iter__(self):
+         return self
+
+     def __len__(self):
+         if self.residue:
+             return self.n_batches + 1
+         else:
+             return self.n_batches
+
+
+ def build_iterator(dataset, config):
+     return DatasetIterater(dataset, 1, config.device)  # batch size 1 at inference
+
+
+ def match_label(pred, config):
+     label_list = config.class_list
+     return label_list[pred]
+
+
+ def load_data(file_path, config):
+     if not os.path.isfile(file_path):
+         raise Exception(f"File not found: {file_path}")
+
+     with open(file_path, 'r', encoding='utf-8') as f:
+         lines = f.readlines()
+     data = load_dataset(lines, config)
+     return data, lines
+
+
+ def final_predict(config, model, data_iter):
+     map_location = lambda storage, loc: storage  # load checkpoint tensors onto CPU storage
+     if not os.path.isfile(config.save_path):
+         raise Exception(f"File not found: {config.save_path}")
+
+     model.load_state_dict(torch.load(config.save_path, map_location=map_location))
+     model.eval()
+     predict_all = np.array([], dtype=str)
+
+     with torch.no_grad():
+         for texts, _ in tqdm(data_iter, desc="Predicting"):
+             outputs = model(texts)
+             pred = torch.max(outputs.data, 1)[1].cpu().numpy()
+             predict_all = np.append(predict_all, [match_label(i, config) for i in pred])
+
+     return predict_all
+
+
+ def output_results(file_path, results, lines):
+     pos = sum(1 for result in results if result == '积极')
+     neg = sum(1 for result in results if result == '消极')
+     neu = sum(1 for result in results if result == '中性')
+     with open(file_path, 'w', encoding='utf-8') as f:
+         f.write(f'评论数量:{len(results)}\n')
+         f.write(f'积极评论数量:{pos}\n')
+         f.write(f'消极评论数量:{neg}\n')
+         f.write(f'中性评论数量:{neu}\n')
+         for line, result in zip(lines, results):
+             f.write(f'评论:{line.strip()} ,情感:{result}\n')
+
+
+ def predict(input_file, output_file):
+     config = Config()
+     model = Model(config).to(config.device)
+     test_data, lines = load_data(input_file, config)
+     test_iter = build_iterator(test_data, config)
+     results = final_predict(config, model, test_iter)
+     output_results(output_file, results, lines)
+
+
+ if __name__ == '__main__':
+     input_file = input("请输入评论文件的名字(相对于程序文件的路径):")
+     output_file = input("请输入结果输出文件的名字(相对于程序文件的路径):")
+     try:
+         predict(input_file, output_file)
+     except Exception as e:
+         print(f"Error: {e}")