Committed by sundea
Commit: ac322b5
1 parent: 0af9e6c

Upload 5 files

Files changed (5):
  1. run.py (+52, -0)
  2. test.py (+120, -0)
  3. train_eval.py (+119, -0)
  4. utils.py (+156, -0)
  5. utils_fasttext.py (+169, -0)
run.py ADDED
@@ -0,0 +1,52 @@
+ # coding: UTF-8
+ import time
+ import torch
+ import numpy as np
+ from train_eval import train, init_network
+ from importlib import import_module
+ import argparse
+
+ parser = argparse.ArgumentParser(description='Chinese Text Classification')
+ parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
+ parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
+ parser.add_argument('--word', action='store_true', help='use word-level tokenization (default: char-level)')
+ args = parser.parse_args()
+
+
+ if __name__ == '__main__':
+     dataset = 'THUCNews'  # dataset directory
+
+     # Sogou News: embedding_SougouNews.npz, Tencent: embedding_Tencent.npz, random init: random
+     embedding = 'embedding_SougouNews.npz'
+     if args.embedding == 'random':
+         embedding = 'random'
+     model_name = args.model  # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer
+     if model_name == 'FastText':
+         from utils_fasttext import build_dataset, build_iterator, get_time_dif
+         embedding = 'random'
+     else:
+         from utils import build_dataset, build_iterator, get_time_dif
+
+     x = import_module('models.' + model_name)
+     config = x.Config(dataset, embedding)
+     np.random.seed(1)
+     torch.manual_seed(1)
+     torch.cuda.manual_seed_all(1)
+     torch.backends.cudnn.deterministic = True  # make results reproducible across runs
+
+     start_time = time.time()
+     print("Loading data...")
+     vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
+     train_iter = build_iterator(train_data, config)
+     dev_iter = build_iterator(dev_data, config)
+     test_iter = build_iterator(test_data, config)
+     time_dif = get_time_dif(start_time)
+     print("Time usage:", time_dif)
+
+     # train
+     config.n_vocab = len(vocab)
+     model = x.Model(config).to(config.device)
+     if model_name != 'Transformer':
+         init_network(model)
+     print(model.parameters)
+     train(config, model, train_iter, dev_iter, test_iter)
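
run.py locates each architecture dynamically via `import_module('models.' + model_name)` and only assumes that the module exposes a `Config(dataset, embedding)` object and a `Model(config)` `nn.Module`. The model modules themselves are not part of this commit; the sketch below is a minimal, illustrative stand-in inferred from the attributes that run.py, train_eval.py, and utils.py read from `config`. The hyperparameter values are placeholders, not the repository's actual settings. With a real module such as models/TextCNN.py in place, training is started with `python run.py --model TextCNN`.

# Minimal, illustrative sketch of the models.<name> contract assumed by run.py.
# Attribute names come from how config is used elsewhere; values are placeholders.
import torch
import torch.nn as nn


class Config(object):
    def __init__(self, dataset, embedding):
        self.model_name = 'ToyModel'
        self.train_path = dataset + '/data/train.txt'        # one "text\tlabel" pair per line
        self.dev_path = dataset + '/data/dev.txt'
        self.test_path = dataset + '/data/test.txt'
        self.vocab_path = dataset + '/data/vocab.pkl'
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
        self.log_path = dataset + '/log/' + self.model_name
        self.class_list = ['finance', 'realty', 'stocks', 'education', 'science',
                           'society', 'politics', 'sports', 'game', 'entertainment']
        self.num_classes = len(self.class_list)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.n_vocab = 0                  # filled in by run.py once the vocab is built
        self.embed = 300                  # embedding dimension (placeholder)
        self.pad_size = 32                # sequence length after padding/truncation
        self.batch_size = 128
        self.num_epochs = 20
        self.learning_rate = 1e-3
        self.require_improvement = 1000   # early-stop patience, in batches


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.fc = nn.Linear(config.embed, config.num_classes)

    def forward(self, x):
        tokens, seq_len = x               # the iterator yields ((token_ids, seq_len), labels)
        out = self.embedding(tokens)      # [batch, pad_size, embed]
        out = out.mean(dim=1)             # crude pooling, just to complete the sketch
        return self.fc(out)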
test.py ADDED
@@ -0,0 +1,120 @@
+ import argparse
+ import os
+ from importlib import import_module
+
+ import gradio as gr
+ from tqdm import tqdm
+ import models.TextCNN
+ import torch
+ import pickle as pkl
+ from utils import build_dataset
+
+ classes = ['finance', 'realty', 'stocks', 'education', 'science', 'society', 'politics', 'sports', 'game', 'entertainment']
+
+ MAX_VOCAB_SIZE = 10000  # vocabulary size limit
+ UNK, PAD = '<UNK>', '<PAD>'  # unknown token, padding token
+
+
+ def build_vocab(file_path, tokenizer, max_size, min_freq):
+     vocab_dic = {}
+     with open(file_path, 'r', encoding='UTF-8') as f:
+         for line in tqdm(f):
+             lin = line.strip()
+             if not lin:
+                 continue
+             content = lin.split('\t')[0]
+             for word in tokenizer(content):
+                 vocab_dic[word] = vocab_dic.get(word, 0) + 1
+         vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
+         vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
+         vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
+     return vocab_dic
+
+
+ parser = argparse.ArgumentParser(description='Chinese Text Classification')
+ parser.add_argument('--word', action='store_true', help='use word-level tokenization (default: char-level)')
+ args = parser.parse_args()
+ model_name = 'TextCNN'
+ dataset = 'THUCNews'  # dataset directory
+ embedding = 'embedding_SougouNews.npz'
+ x = import_module('models.' + model_name)
+
+ config = x.Config(dataset, embedding)
+ device = 'cuda:0'
+ model = models.TextCNN.Model(config)
+
+ # vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
+ model.load_state_dict(torch.load('THUCNews/saved_dict/TextCNN.ckpt'))
+ model.to(device)
+ model.eval()
+
+
+ tokenizer = lambda x: [y for y in x]  # char-level
+ if os.path.exists(config.vocab_path):
+     vocab = pkl.load(open(config.vocab_path, 'rb'))
+ else:
+     vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
+     pkl.dump(vocab, open(config.vocab_path, 'wb'))
+ print(f"Vocab size: {len(vocab)}")
+
+
+ # content = '时评:“国学小天才”录取缘何少佳话'  # sample input
+ content = input('Enter a sentence: ')
+
+ words_line = []
+ token = tokenizer(content)
+ seq_len = len(token)
+ pad_size = 32
+ contents = []
+
+ if pad_size:
+     if len(token) < pad_size:
+         token.extend([PAD] * (pad_size - len(token)))
+     else:
+         token = token[:pad_size]
+         seq_len = pad_size
+ # word to id
+ for word in token:
+     words_line.append(vocab.get(word, vocab.get(UNK)))
+
+ contents.append((words_line, seq_len))
+ print(words_line)
+ # input = torch.LongTensor(words_line).unsqueeze(1).to(device)  # convert words_line to a LongTensor and add a batch dimension
+ x = torch.LongTensor([_[0] for _ in contents]).to(device)
+
+ # sequence length before padding (capped at pad_size)
+ seq_len = torch.LongTensor([_[1] for _ in contents]).to(device)
+ input = (x, seq_len)
+ print(input)
+ with torch.no_grad():
+     output = model(input)
+     predic = torch.max(output.data, 1)[1].cpu().numpy()
+     print(predic)
+     print('Predicted class: {}'.format(classes[predic[0]]))
+
+
+ # Earlier experiment: batch evaluation over the test set, kept for reference.
+ # with torch.no_grad():
+ #     output = model(input)
+ #     print(output)
+ #
+ # start_time = time.time()
+ # test_iter = build_iterator(test_data, config)
+ # with torch.no_grad():
+ #     predict_all = np.array([], dtype=int)
+ #     labels_all = np.array([], dtype=int)
+ #     for texts, labels in test_iter:
+ #         # texts = texts.to(device)
+ #         print(texts)
+ #         outputs = model(texts)
+ #         loss = F.cross_entropy(outputs, labels)
+ #         labels = labels.data.cpu().numpy()
+ #         predic = torch.max(outputs.data, 1)[1].cpu().numpy()
+ #         labels_all = np.append(labels_all, labels)
+ #         predict_all = np.append(predict_all, predic)
+ #         break
+ #     print(labels_all)
+ #     print(predict_all)
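
test.py imports gradio but never uses it; a natural next step would be to fold the one-off preprocessing above into a reusable function and expose it through a small web demo. The sketch below is illustrative only (the `predict` helper is not part of this commit) and reuses the `tokenizer`, `vocab`, `model`, `device`, `pad_size`, `classes`, `UNK`, and `PAD` names defined in test.py.

# Illustrative sketch: a reusable predict() built from the same steps as test.py,
# plus an optional gradio front end. Not part of the commit.
def predict(text):
    token = tokenizer(text)
    seq_len = min(len(token), pad_size)
    if len(token) < pad_size:
        token = token + [PAD] * (pad_size - len(token))
    else:
        token = token[:pad_size]
    ids = [vocab.get(w, vocab.get(UNK)) for w in token]
    x_tensor = torch.LongTensor([ids]).to(device)
    len_tensor = torch.LongTensor([seq_len]).to(device)
    with torch.no_grad():
        logits = model((x_tensor, len_tensor))
    return classes[int(torch.max(logits.data, 1)[1].cpu().numpy()[0])]


# gr.Interface(fn=predict, inputs='text', outputs='text').launch()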
train_eval.py ADDED
@@ -0,0 +1,119 @@
+ # coding: UTF-8
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from sklearn import metrics
+ import time
+ from utils import get_time_dif
+ from tensorboardX import SummaryWriter
+
+
+ # Weight initialization (default: Xavier)
+ def init_network(model, method='xavier', exclude='embedding', seed=123):
+     for name, w in model.named_parameters():
+         if exclude not in name:
+             if 'weight' in name:
+                 if method == 'xavier':
+                     nn.init.xavier_normal_(w)
+                 elif method == 'kaiming':
+                     nn.init.kaiming_normal_(w)
+                 else:
+                     nn.init.normal_(w)
+             elif 'bias' in name:
+                 nn.init.constant_(w, 0)
+             else:
+                 pass
+
+
+ def train(config, model, train_iter, dev_iter, test_iter):
+     start_time = time.time()
+     model.train()
+     optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
+
+     # Exponential learning-rate decay: after every epoch, lr = gamma * lr
+     # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
+     total_batch = 0  # number of batches processed so far
+     dev_best_loss = float('inf')
+     last_improve = 0  # batch at which the dev loss last improved
+     flag = False  # whether training has gone too long without improvement
+     writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
+     for epoch in range(config.num_epochs):
+         print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
+         # scheduler.step()  # learning-rate decay
+         for i, (trains, labels) in enumerate(train_iter):
+             outputs = model(trains)
+             model.zero_grad()
+             loss = F.cross_entropy(outputs, labels)
+             loss.backward()
+             optimizer.step()
+             if total_batch % 100 == 0:
+                 # periodically report metrics on the training and dev sets
+                 true = labels.data.cpu()
+                 predic = torch.max(outputs.data, 1)[1].cpu()
+                 train_acc = metrics.accuracy_score(true, predic)
+                 dev_acc, dev_loss = evaluate(config, model, dev_iter)
+                 if dev_loss < dev_best_loss:
+                     dev_best_loss = dev_loss
+                     torch.save(model.state_dict(), config.save_path)
+                     improve = '*'
+                     last_improve = total_batch
+                 else:
+                     improve = ''
+                 time_dif = get_time_dif(start_time)
+                 msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
+                 print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
+                 writer.add_scalar("loss/train", loss.item(), total_batch)
+                 writer.add_scalar("loss/dev", dev_loss, total_batch)
+                 writer.add_scalar("acc/train", train_acc, total_batch)
+                 writer.add_scalar("acc/dev", dev_acc, total_batch)
+                 model.train()
+             total_batch += 1
+             if total_batch - last_improve > config.require_improvement:
+                 # dev loss has not improved for more than 1000 batches; stop training
+                 print("No optimization for a long time, auto-stopping...")
+                 flag = True
+                 break
+         if flag:
+             break
+     writer.close()
+     test(config, model, test_iter)
+
+
+ def test(config, model, test_iter):
+     # test
+     model.load_state_dict(torch.load(config.save_path))
+     model.eval()
+     start_time = time.time()
+     test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
+     msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
+     print(msg.format(test_loss, test_acc))
+     print("Precision, Recall and F1-Score...")
+     print(test_report)
+     print("Confusion Matrix...")
+     print(test_confusion)
+     time_dif = get_time_dif(start_time)
+     print("Time usage:", time_dif)
+
+
+ def evaluate(config, model, data_iter, test=False):
+     model.eval()
+     loss_total = 0
+     predict_all = np.array([], dtype=int)
+     labels_all = np.array([], dtype=int)
+     with torch.no_grad():
+         for texts, labels in data_iter:
+             outputs = model(texts)
+             loss = F.cross_entropy(outputs, labels)
+             loss_total += loss
+             labels = labels.data.cpu().numpy()
+             predic = torch.max(outputs.data, 1)[1].cpu().numpy()
+             labels_all = np.append(labels_all, labels)
+             predict_all = np.append(predict_all, predic)
+
+     acc = metrics.accuracy_score(labels_all, predict_all)
+     if test:
+         report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
+         confusion = metrics.confusion_matrix(labels_all, predict_all)
+         return acc, loss_total / len(data_iter), report, confusion
+     return acc, loss_total / len(data_iter)
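
For reference, init_network walks `named_parameters()` and applies Xavier (or Kaiming/normal) initialization to weights and zeros to biases, skipping any parameter whose name contains the excluded substring. A small illustrative check, assuming the repository's dependencies are installed so train_eval imports cleanly; the toy module names are arbitrary:

# Toy check of init_network's selection logic; module names are illustrative.
import torch.nn as nn
from train_eval import init_network

toy = nn.Sequential()
toy.add_module('embedding', nn.Embedding(100, 16))   # name contains 'embedding' -> left untouched
toy.add_module('fc', nn.Linear(16, 4))               # 'fc.weight' -> Xavier, 'fc.bias' -> zeros
init_network(toy)
print(toy.fc.bias)                                   # tensor of zeros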
utils.py ADDED
@@ -0,0 +1,156 @@
+ # coding: UTF-8
+ import os
+ import torch
+ import numpy as np
+ import pickle as pkl
+ from tqdm import tqdm
+ import time
+ from datetime import timedelta
+
+
+ MAX_VOCAB_SIZE = 10000  # vocabulary size limit
+ UNK, PAD = '<UNK>', '<PAD>'  # unknown token, padding token
+
+
+ def build_vocab(file_path, tokenizer, max_size, min_freq):
+     vocab_dic = {}
+     with open(file_path, 'r', encoding='UTF-8') as f:
+         for line in tqdm(f):
+             lin = line.strip()
+             if not lin:
+                 continue
+             content = lin.split('\t')[0]
+             for word in tokenizer(content):
+                 vocab_dic[word] = vocab_dic.get(word, 0) + 1
+         vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
+         vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
+         vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
+     return vocab_dic
+
+
+ def build_dataset(config, ues_word):
+     if ues_word:
+         tokenizer = lambda x: x.split(' ')  # split on spaces, word-level
+     else:
+         tokenizer = lambda x: [y for y in x]  # char-level
+     if os.path.exists(config.vocab_path):
+         vocab = pkl.load(open(config.vocab_path, 'rb'))
+     else:
+         vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
+         pkl.dump(vocab, open(config.vocab_path, 'wb'))
+     print(f"Vocab size: {len(vocab)}")
+
+     def load_dataset(path, pad_size=32):
+         contents = []
+         with open(path, 'r', encoding='UTF-8') as f:
+             for line in tqdm(f):
+                 lin = line.strip()
+                 if not lin:
+                     continue
+                 content, label = lin.split('\t')
+                 words_line = []
+                 token = tokenizer(content)
+                 seq_len = len(token)
+                 if pad_size:
+                     if len(token) < pad_size:
+                         token.extend([PAD] * (pad_size - len(token)))
+                     else:
+                         token = token[:pad_size]
+                         seq_len = pad_size
+                 # word to id
+                 for word in token:
+                     words_line.append(vocab.get(word, vocab.get(UNK)))
+                 contents.append((words_line, int(label), seq_len))
+         return contents  # [(words_line, label, seq_len), ...]
+     train = load_dataset(config.train_path, config.pad_size)
+     dev = load_dataset(config.dev_path, config.pad_size)
+     test = load_dataset(config.test_path, config.pad_size)
+     return vocab, train, dev, test
+
+
+ class DatasetIterater(object):
+     def __init__(self, batches, batch_size, device):
+         self.batch_size = batch_size
+         self.batches = batches
+         self.n_batches = len(batches) // batch_size
+         self.residue = False  # True if the sample count is not an exact multiple of batch_size
+         if len(batches) % self.batch_size != 0:
+             self.residue = True
+         self.index = 0
+         self.device = device
+
+     def _to_tensor(self, datas):
+         x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
+         y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
+
+         # sequence length before padding (capped at pad_size)
+         seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
+         return (x, seq_len), y
+
+     def __next__(self):
+         if self.residue and self.index == self.n_batches:
+             batches = self.batches[self.index * self.batch_size: len(self.batches)]
+             self.index += 1
+
+             batches = self._to_tensor(batches)
+             return batches
+
+         elif self.index >= self.n_batches:
+             self.index = 0
+             raise StopIteration
+         else:
+             batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
+             self.index += 1
+             batches = self._to_tensor(batches)
+             return batches
+
+     def __iter__(self):
+         return self
+
+     def __len__(self):
+         if self.residue:
+             return self.n_batches + 1
+         else:
+             return self.n_batches
+
+
+ def build_iterator(dataset, config):
+     iter = DatasetIterater(dataset, config.batch_size, config.device)
+     return iter
+
+
+ def get_time_dif(start_time):
+     """Return the elapsed time since start_time."""
+     end_time = time.time()
+     time_dif = end_time - start_time
+     return timedelta(seconds=int(round(time_dif)))
+
+
+ if __name__ == "__main__":
+     '''Extract pre-trained word vectors.'''
+     # Adjust the directories and file names below as needed.
+     train_dir = "./THUCNews/data/train.txt"
+     vocab_dir = "./THUCNews/data/vocab.pkl"
+     pretrain_dir = "./THUCNews/data/sgns.sogou.char"
+     emb_dim = 300
+     filename_trimmed_dir = "./THUCNews/data/embedding_SougouNews"
+     if os.path.exists(vocab_dir):
+         word_to_id = pkl.load(open(vocab_dir, 'rb'))
+     else:
+         # tokenizer = lambda x: x.split(' ')  # build the vocab at word level (words in the dataset are space-separated)
+         tokenizer = lambda x: [y for y in x]  # build the vocab at character level
+         word_to_id = build_vocab(train_dir, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
+         pkl.dump(word_to_id, open(vocab_dir, 'wb'))
+
+     embeddings = np.random.rand(len(word_to_id), emb_dim)
+     f = open(pretrain_dir, "r", encoding='UTF-8')
+     for i, line in enumerate(f.readlines()):
+         # if i == 0:  # skip the header line, if any
+         #     continue
+         lin = line.strip().split(" ")
+         if lin[0] in word_to_id:
+             idx = word_to_id[lin[0]]
+             emb = [float(x) for x in lin[1:301]]
+             embeddings[idx] = np.asarray(emb, dtype='float32')
+     f.close()
+     np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)
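
DatasetIterater simply slices the preprocessed list into batch_size chunks and moves each chunk to the configured device, emitting one extra, smaller batch when the sample count does not divide evenly. A quick illustrative check with toy data on the CPU:

# Toy check of DatasetIterater's batching; data and sizes are illustrative.
import torch
from utils import DatasetIterater

toy = [([i, i + 1, i + 2], i % 2, 3) for i in range(10)]   # (token_ids, label, seq_len)
it = DatasetIterater(toy, batch_size=3, device=torch.device('cpu'))
print(len(it))                     # 4: three full batches plus one residual batch of 1 sample
for (x, seq_len), y in it:
    print(x.shape, y.shape)        # torch.Size([3, 3]) torch.Size([3]) ... then [1, 3] / [1]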
utils_fasttext.py ADDED
@@ -0,0 +1,169 @@
+ # coding: UTF-8
+ import os
+ import torch
+ import numpy as np
+ import pickle as pkl
+ from tqdm import tqdm
+ import time
+ from datetime import timedelta
+
+
+ MAX_VOCAB_SIZE = 10000
+ UNK, PAD = '<UNK>', '<PAD>'
+
+
+ def build_vocab(file_path, tokenizer, max_size, min_freq):
+     vocab_dic = {}
+     with open(file_path, 'r', encoding='UTF-8') as f:
+         for line in tqdm(f):
+             lin = line.strip()
+             if not lin:
+                 continue
+             content = lin.split('\t')[0]
+             for word in tokenizer(content):
+                 vocab_dic[word] = vocab_dic.get(word, 0) + 1
+         vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
+         vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
+         vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
+     return vocab_dic
+
+
+ def build_dataset(config, ues_word):
+     if ues_word:
+         tokenizer = lambda x: x.split(' ')  # split on spaces, word-level
+     else:
+         tokenizer = lambda x: [y for y in x]  # char-level
+     if os.path.exists(config.vocab_path):
+         vocab = pkl.load(open(config.vocab_path, 'rb'))
+     else:
+         vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
+         pkl.dump(vocab, open(config.vocab_path, 'wb'))
+     print(f"Vocab size: {len(vocab)}")
+
+     def biGramHash(sequence, t, buckets):
+         t1 = sequence[t - 1] if t - 1 >= 0 else 0
+         return (t1 * 14918087) % buckets
+
+     def triGramHash(sequence, t, buckets):
+         t1 = sequence[t - 1] if t - 1 >= 0 else 0
+         t2 = sequence[t - 2] if t - 2 >= 0 else 0
+         return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets
+
+     def load_dataset(path, pad_size=32):
+         contents = []
+         with open(path, 'r', encoding='UTF-8') as f:
+             for line in tqdm(f):
+                 lin = line.strip()
+                 if not lin:
+                     continue
+                 content, label = lin.split('\t')
+                 words_line = []
+                 token = tokenizer(content)
+                 seq_len = len(token)
+                 if pad_size:
+                     if len(token) < pad_size:
+                         token.extend([PAD] * (pad_size - len(token)))
+                     else:
+                         token = token[:pad_size]
+                         seq_len = pad_size
+                 # word to id
+                 for word in token:
+                     words_line.append(vocab.get(word, vocab.get(UNK)))
+
+                 # fasttext ngram
+                 buckets = config.n_gram_vocab
+                 bigram = []
+                 trigram = []
+                 # ------ngram------
+                 for i in range(pad_size):
+                     bigram.append(biGramHash(words_line, i, buckets))
+                     trigram.append(triGramHash(words_line, i, buckets))
+                 # -----------------
+                 contents.append((words_line, int(label), seq_len, bigram, trigram))
+         return contents  # [(words_line, label, seq_len, bigram, trigram), ...]
+     train = load_dataset(config.train_path, config.pad_size)
+     dev = load_dataset(config.dev_path, config.pad_size)
+     test = load_dataset(config.test_path, config.pad_size)
+     return vocab, train, dev, test
+
+
+ class DatasetIterater(object):
+     def __init__(self, batches, batch_size, device):
+         self.batch_size = batch_size
+         self.batches = batches
+         self.n_batches = len(batches) // batch_size
+         self.residue = False  # True if the sample count is not an exact multiple of batch_size
+         if len(batches) % self.batch_size != 0:
+             self.residue = True
+         self.index = 0
+         self.device = device
+
+     def _to_tensor(self, datas):
+         # xx = [xxx[2] for xxx in datas]
+         # indexx = np.argsort(xx)[::-1]
+         # datas = np.array(datas)[indexx]
+         x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
+         y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
+         bigram = torch.LongTensor([_[3] for _ in datas]).to(self.device)
+         trigram = torch.LongTensor([_[4] for _ in datas]).to(self.device)
+
+         # sequence length before padding (capped at pad_size)
+         seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
+         return (x, seq_len, bigram, trigram), y
+
+     def __next__(self):
+         if self.residue and self.index == self.n_batches:
+             batches = self.batches[self.index * self.batch_size: len(self.batches)]
+             self.index += 1
+             batches = self._to_tensor(batches)
+             return batches
+
+         elif self.index >= self.n_batches:
+             self.index = 0
+             raise StopIteration
+         else:
+             batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
+             self.index += 1
+             batches = self._to_tensor(batches)
+             return batches
+
+     def __iter__(self):
+         return self
+
+     def __len__(self):
+         if self.residue:
+             return self.n_batches + 1
+         else:
+             return self.n_batches
+
+
+ def build_iterator(dataset, config):
+     iter = DatasetIterater(dataset, config.batch_size, config.device)
+     return iter
+
+
+ def get_time_dif(start_time):
+     """Return the elapsed time since start_time."""
+     end_time = time.time()
+     time_dif = end_time - start_time
+     return timedelta(seconds=int(round(time_dif)))
+
+
+ if __name__ == "__main__":
+     '''Extract pre-trained word vectors.'''
+     vocab_dir = "./THUCNews/data/vocab.pkl"
+     pretrain_dir = "./THUCNews/data/sgns.sogou.char"
+     emb_dim = 300
+     filename_trimmed_dir = "./THUCNews/data/vocab.embedding.sougou"
+     word_to_id = pkl.load(open(vocab_dir, 'rb'))
+     embeddings = np.random.rand(len(word_to_id), emb_dim)
+     f = open(pretrain_dir, "r", encoding='UTF-8')
+     for i, line in enumerate(f.readlines()):
+         # if i == 0:  # skip the header line, if any
+         #     continue
+         lin = line.strip().split(" ")
+         if lin[0] in word_to_id:
+             idx = word_to_id[lin[0]]
+             emb = [float(x) for x in lin[1:301]]
+             embeddings[idx] = np.asarray(emb, dtype='float32')
+     f.close()
+     np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)
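
The bigram and trigram features above hash the one or two preceding token ids into a fixed number of buckets (`config.n_gram_vocab`), so every position gets extra context features without building an explicit n-gram vocabulary. A standalone restatement for illustration; the helpers are nested inside `build_dataset` and not importable, and the bucket count below is a placeholder, not the repository's configured value:

# Standalone restatement of the n-gram hashing used above; bucket count is illustrative.
def bi_gram_hash(sequence, t, buckets):
    t1 = sequence[t - 1] if t - 1 >= 0 else 0
    return (t1 * 14918087) % buckets


def tri_gram_hash(sequence, t, buckets):
    t1 = sequence[t - 1] if t - 1 >= 0 else 0
    t2 = sequence[t - 2] if t - 2 >= 0 else 0
    return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets


tokens = [5, 42, 7, 19]          # toy token ids
buckets = 250000                 # placeholder for config.n_gram_vocab
print([bi_gram_hash(tokens, t, buckets) for t in range(len(tokens))])
print([tri_gram_hash(tokens, t, buckets) for t in range(len(tokens))])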