Spaces:
Runtime error
Runtime error
Upload 5 files
Browse files- run.py +52 -0
- test.py +120 -0
- train_eval.py +119 -0
- utils.py +156 -0
- utils_fasttext.py +169 -0
run.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: UTF-8
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from train_eval import train, init_network
|
6 |
+
from importlib import import_module
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
parser = argparse.ArgumentParser(description='Chinese Text Classification')
|
10 |
+
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
|
11 |
+
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
|
12 |
+
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
|
13 |
+
args = parser.parse_args()
|
14 |
+
|
15 |
+
|
16 |
+
if __name__ == '__main__':
|
17 |
+
dataset = 'THUCNews' # 数据集
|
18 |
+
|
19 |
+
# 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
|
20 |
+
embedding = 'embedding_SougouNews.npz'
|
21 |
+
if args.embedding == 'random':
|
22 |
+
embedding = 'random'
|
23 |
+
model_name = args.model # 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer
|
24 |
+
if model_name == 'FastText':
|
25 |
+
from utils_fasttext import build_dataset, build_iterator, get_time_dif
|
26 |
+
embedding = 'random'
|
27 |
+
else:
|
28 |
+
from utils import build_dataset, build_iterator, get_time_dif
|
29 |
+
|
30 |
+
x = import_module('models.' + model_name)
|
31 |
+
config = x.Config(dataset, embedding)
|
32 |
+
np.random.seed(1)
|
33 |
+
torch.manual_seed(1)
|
34 |
+
torch.cuda.manual_seed_all(1)
|
35 |
+
torch.backends.cudnn.deterministic = True # 保证每次结果一样
|
36 |
+
|
37 |
+
start_time = time.time()
|
38 |
+
print("Loading data...")
|
39 |
+
vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
|
40 |
+
train_iter = build_iterator(train_data, config)
|
41 |
+
dev_iter = build_iterator(dev_data, config)
|
42 |
+
test_iter = build_iterator(test_data, config)
|
43 |
+
time_dif = get_time_dif(start_time)
|
44 |
+
print("Time usage:", time_dif)
|
45 |
+
|
46 |
+
# train
|
47 |
+
config.n_vocab = len(vocab)
|
48 |
+
model = x.Model(config).to(config.device)
|
49 |
+
if model_name != 'Transformer':
|
50 |
+
init_network(model)
|
51 |
+
print(model.parameters)
|
52 |
+
train(config, model, train_iter, dev_iter, test_iter)
|
test.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from importlib import import_module
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
from tqdm import tqdm
|
7 |
+
import models.TextCNN
|
8 |
+
import torch
|
9 |
+
import pickle as pkl
|
10 |
+
from utils import build_dataset
|
11 |
+
classes=['finance','realty','stocks','education','science','society','politics','sports','game','entertainment']
|
12 |
+
|
13 |
+
MAX_VOCAB_SIZE = 10000 # 词表长度限制
|
14 |
+
UNK, PAD = '<UNK>', '<PAD>' # 未知字,padding符号
|
15 |
+
def build_vocab(file_path, tokenizer, max_size, min_freq):
|
16 |
+
vocab_dic = {}
|
17 |
+
with open(file_path, 'r', encoding='UTF-8') as f:
|
18 |
+
for line in tqdm(f):
|
19 |
+
lin = line.strip()
|
20 |
+
if not lin:
|
21 |
+
continue
|
22 |
+
content = lin.split('\t')[0]
|
23 |
+
for word in tokenizer(content):
|
24 |
+
vocab_dic[word] = vocab_dic.get(word, 0) + 1
|
25 |
+
vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
|
26 |
+
vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
|
27 |
+
vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
|
28 |
+
return vocab_dic
|
29 |
+
|
30 |
+
parser = argparse.ArgumentParser(description='Chinese Text Classification')
|
31 |
+
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
|
32 |
+
args = parser.parse_args()
|
33 |
+
model_name='TextCNN'
|
34 |
+
dataset = 'THUCNews' # 数据集
|
35 |
+
embedding = 'embedding_SougouNews.npz'
|
36 |
+
x = import_module('models.' + model_name)
|
37 |
+
|
38 |
+
config = x.Config(dataset, embedding)
|
39 |
+
device='cuda:0'
|
40 |
+
model=models.TextCNN.Model(config)
|
41 |
+
|
42 |
+
# vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
|
43 |
+
model.load_state_dict(torch.load('THUCNews/saved_dict/TextCNN.ckpt'))
|
44 |
+
model.to(device)
|
45 |
+
model.eval()
|
46 |
+
|
47 |
+
|
48 |
+
tokenizer = lambda x: [y for y in x] # char-level
|
49 |
+
if os.path.exists(config.vocab_path):
|
50 |
+
vocab = pkl.load(open(config.vocab_path, 'rb'))
|
51 |
+
else:
|
52 |
+
vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
|
53 |
+
pkl.dump(vocab, open(config.vocab_path, 'wb'))
|
54 |
+
print(f"Vocab size: {len(vocab)}")
|
55 |
+
|
56 |
+
|
57 |
+
# content='时评:“国学小天才”录取缘何少佳话'
|
58 |
+
content=input('输入语句:')
|
59 |
+
|
60 |
+
words_line = []
|
61 |
+
token = tokenizer(content)
|
62 |
+
seq_len = len(token)
|
63 |
+
pad_size=32
|
64 |
+
contents=[]
|
65 |
+
|
66 |
+
if pad_size:
|
67 |
+
if len(token) < pad_size:
|
68 |
+
token.extend([PAD] * (pad_size - len(token)))
|
69 |
+
else:
|
70 |
+
token = token[:pad_size]
|
71 |
+
seq_len = pad_size
|
72 |
+
# word to id
|
73 |
+
for word in token:
|
74 |
+
words_line.append(vocab.get(word, vocab.get(UNK)))
|
75 |
+
|
76 |
+
contents.append((words_line, seq_len))
|
77 |
+
print(words_line)
|
78 |
+
# input = torch.LongTensor(words_line).unsqueeze(1).to(device) # convert words_line to LongTensor and add batch dimension
|
79 |
+
x = torch.LongTensor([_[0] for _ in contents]).to(device)
|
80 |
+
|
81 |
+
# pad前的长度(超过pad_size的设为pad_size)
|
82 |
+
seq_len = torch.LongTensor([_[1] for _ in contents]).to(device)
|
83 |
+
input=(x,seq_len)
|
84 |
+
print(input)
|
85 |
+
with torch.no_grad():
|
86 |
+
output = model(input)
|
87 |
+
predic = torch.max(output.data, 1)[1].cpu().numpy()
|
88 |
+
print(predic)
|
89 |
+
print('类别为:{}'.format(classes[predic[0]]))
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
# with torch.no_grad():
|
96 |
+
# output=model(input)
|
97 |
+
# print(output)
|
98 |
+
|
99 |
+
#
|
100 |
+
# start_time = time.time()
|
101 |
+
# test_iter = build_iterator(test_data, config)
|
102 |
+
# with torch.no_grad():
|
103 |
+
# predict_all = np.array([], dtype=int)
|
104 |
+
# labels_all = np.array([], dtype=int)
|
105 |
+
# for texts, labels in test_iter:
|
106 |
+
# # texts=texts.to(device)
|
107 |
+
# print(texts)
|
108 |
+
# outputs = model(texts)
|
109 |
+
# loss = F.cross_entropy(outputs, labels)
|
110 |
+
# labels = labels.data.cpu().numpy()
|
111 |
+
# predic = torch.max(outputs.data, 1)[1].cpu().numpy()
|
112 |
+
# labels_all = np.append(labels_all, labels)
|
113 |
+
# predict_all = np.append(predict_all, predic)
|
114 |
+
# break
|
115 |
+
# print(labels_all)
|
116 |
+
# print(predict_all)
|
117 |
+
#
|
118 |
+
#
|
119 |
+
|
120 |
+
|
train_eval.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: UTF-8
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from sklearn import metrics
|
7 |
+
import time
|
8 |
+
from utils import get_time_dif
|
9 |
+
from tensorboardX import SummaryWriter
|
10 |
+
|
11 |
+
|
12 |
+
# 权重初始化,默认xavier
|
13 |
+
def init_network(model, method='xavier', exclude='embedding', seed=123):
|
14 |
+
for name, w in model.named_parameters():
|
15 |
+
if exclude not in name:
|
16 |
+
if 'weight' in name:
|
17 |
+
if method == 'xavier':
|
18 |
+
nn.init.xavier_normal_(w)
|
19 |
+
elif method == 'kaiming':
|
20 |
+
nn.init.kaiming_normal_(w)
|
21 |
+
else:
|
22 |
+
nn.init.normal_(w)
|
23 |
+
elif 'bias' in name:
|
24 |
+
nn.init.constant_(w, 0)
|
25 |
+
else:
|
26 |
+
pass
|
27 |
+
|
28 |
+
|
29 |
+
def train(config, model, train_iter, dev_iter, test_iter):
|
30 |
+
start_time = time.time()
|
31 |
+
model.train()
|
32 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
|
33 |
+
|
34 |
+
# 学习率指数衰减,每次epoch:学习率 = gamma * 学习率
|
35 |
+
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
|
36 |
+
total_batch = 0 # 记录进行到多少batch
|
37 |
+
dev_best_loss = float('inf')
|
38 |
+
last_improve = 0 # 记录上次验证集loss下降的batch数
|
39 |
+
flag = False # 记录是否很久没有效果提升
|
40 |
+
writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
|
41 |
+
for epoch in range(config.num_epochs):
|
42 |
+
print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
|
43 |
+
# scheduler.step() # 学习率衰减
|
44 |
+
for i, (trains, labels) in enumerate(train_iter):
|
45 |
+
outputs = model(trains)
|
46 |
+
model.zero_grad()
|
47 |
+
loss = F.cross_entropy(outputs, labels)
|
48 |
+
loss.backward()
|
49 |
+
optimizer.step()
|
50 |
+
if total_batch % 100 == 0:
|
51 |
+
# 每多少轮输出在训练集和验证集上的效果
|
52 |
+
true = labels.data.cpu()
|
53 |
+
predic = torch.max(outputs.data, 1)[1].cpu()
|
54 |
+
train_acc = metrics.accuracy_score(true, predic)
|
55 |
+
dev_acc, dev_loss = evaluate(config, model, dev_iter)
|
56 |
+
if dev_loss < dev_best_loss:
|
57 |
+
dev_best_loss = dev_loss
|
58 |
+
torch.save(model.state_dict(), config.save_path)
|
59 |
+
improve = '*'
|
60 |
+
last_improve = total_batch
|
61 |
+
else:
|
62 |
+
improve = ''
|
63 |
+
time_dif = get_time_dif(start_time)
|
64 |
+
msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
|
65 |
+
print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
|
66 |
+
writer.add_scalar("loss/train", loss.item(), total_batch)
|
67 |
+
writer.add_scalar("loss/dev", dev_loss, total_batch)
|
68 |
+
writer.add_scalar("acc/train", train_acc, total_batch)
|
69 |
+
writer.add_scalar("acc/dev", dev_acc, total_batch)
|
70 |
+
model.train()
|
71 |
+
total_batch += 1
|
72 |
+
if total_batch - last_improve > config.require_improvement:
|
73 |
+
# 验证集loss超过1000batch没下降,结束训练
|
74 |
+
print("No optimization for a long time, auto-stopping...")
|
75 |
+
flag = True
|
76 |
+
break
|
77 |
+
if flag:
|
78 |
+
break
|
79 |
+
writer.close()
|
80 |
+
test(config, model, test_iter)
|
81 |
+
|
82 |
+
|
83 |
+
def test(config, model, test_iter):
|
84 |
+
# test
|
85 |
+
model.load_state_dict(torch.load(config.save_path))
|
86 |
+
model.eval()
|
87 |
+
start_time = time.time()
|
88 |
+
test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
|
89 |
+
msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
|
90 |
+
print(msg.format(test_loss, test_acc))
|
91 |
+
print("Precision, Recall and F1-Score...")
|
92 |
+
print(test_report)
|
93 |
+
print("Confusion Matrix...")
|
94 |
+
print(test_confusion)
|
95 |
+
time_dif = get_time_dif(start_time)
|
96 |
+
print("Time usage:", time_dif)
|
97 |
+
|
98 |
+
|
99 |
+
def evaluate(config, model, data_iter, test=False):
|
100 |
+
model.eval()
|
101 |
+
loss_total = 0
|
102 |
+
predict_all = np.array([], dtype=int)
|
103 |
+
labels_all = np.array([], dtype=int)
|
104 |
+
with torch.no_grad():
|
105 |
+
for texts, labels in data_iter:
|
106 |
+
outputs = model(texts)
|
107 |
+
loss = F.cross_entropy(outputs, labels)
|
108 |
+
loss_total += loss
|
109 |
+
labels = labels.data.cpu().numpy()
|
110 |
+
predic = torch.max(outputs.data, 1)[1].cpu().numpy()
|
111 |
+
labels_all = np.append(labels_all, labels)
|
112 |
+
predict_all = np.append(predict_all, predic)
|
113 |
+
|
114 |
+
acc = metrics.accuracy_score(labels_all, predict_all)
|
115 |
+
if test:
|
116 |
+
report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
|
117 |
+
confusion = metrics.confusion_matrix(labels_all, predict_all)
|
118 |
+
return acc, loss_total / len(data_iter), report, confusion
|
119 |
+
return acc, loss_total / len(data_iter)
|
utils.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: UTF-8
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import pickle as pkl
|
6 |
+
from tqdm import tqdm
|
7 |
+
import time
|
8 |
+
from datetime import timedelta
|
9 |
+
|
10 |
+
|
11 |
+
MAX_VOCAB_SIZE = 10000 # 词表长度限制
|
12 |
+
UNK, PAD = '<UNK>', '<PAD>' # 未知字,padding符号
|
13 |
+
|
14 |
+
|
15 |
+
def build_vocab(file_path, tokenizer, max_size, min_freq):
|
16 |
+
vocab_dic = {}
|
17 |
+
with open(file_path, 'r', encoding='UTF-8') as f:
|
18 |
+
for line in tqdm(f):
|
19 |
+
lin = line.strip()
|
20 |
+
if not lin:
|
21 |
+
continue
|
22 |
+
content = lin.split('\t')[0]
|
23 |
+
for word in tokenizer(content):
|
24 |
+
vocab_dic[word] = vocab_dic.get(word, 0) + 1
|
25 |
+
vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
|
26 |
+
vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
|
27 |
+
vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
|
28 |
+
return vocab_dic
|
29 |
+
|
30 |
+
|
31 |
+
def build_dataset(config, ues_word):
|
32 |
+
if ues_word:
|
33 |
+
tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level
|
34 |
+
else:
|
35 |
+
tokenizer = lambda x: [y for y in x] # char-level
|
36 |
+
if os.path.exists(config.vocab_path):
|
37 |
+
vocab = pkl.load(open(config.vocab_path, 'rb'))
|
38 |
+
else:
|
39 |
+
vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
|
40 |
+
pkl.dump(vocab, open(config.vocab_path, 'wb'))
|
41 |
+
print(f"Vocab size: {len(vocab)}")
|
42 |
+
|
43 |
+
def load_dataset(path, pad_size=32):
|
44 |
+
contents = []
|
45 |
+
with open(path, 'r', encoding='UTF-8') as f:
|
46 |
+
for line in tqdm(f):
|
47 |
+
lin = line.strip()
|
48 |
+
if not lin:
|
49 |
+
continue
|
50 |
+
content, label = lin.split('\t')
|
51 |
+
words_line = []
|
52 |
+
token = tokenizer(content)
|
53 |
+
seq_len = len(token)
|
54 |
+
if pad_size:
|
55 |
+
if len(token) < pad_size:
|
56 |
+
token.extend([PAD] * (pad_size - len(token)))
|
57 |
+
else:
|
58 |
+
token = token[:pad_size]
|
59 |
+
seq_len = pad_size
|
60 |
+
# word to id
|
61 |
+
for word in token:
|
62 |
+
words_line.append(vocab.get(word, vocab.get(UNK)))
|
63 |
+
contents.append((words_line, int(label), seq_len))
|
64 |
+
return contents # [([...], 0), ([...], 1), ...]
|
65 |
+
train = load_dataset(config.train_path, config.pad_size)
|
66 |
+
dev = load_dataset(config.dev_path, config.pad_size)
|
67 |
+
test = load_dataset(config.test_path, config.pad_size)
|
68 |
+
return vocab, train, dev, test
|
69 |
+
|
70 |
+
|
71 |
+
class DatasetIterater(object):
|
72 |
+
def __init__(self, batches, batch_size, device):
|
73 |
+
self.batch_size = batch_size
|
74 |
+
self.batches = batches
|
75 |
+
self.n_batches = len(batches) // batch_size
|
76 |
+
self.residue = False # 记录batch数量是否为整数
|
77 |
+
if len(batches) % self.n_batches != 0:
|
78 |
+
self.residue = True
|
79 |
+
self.index = 0
|
80 |
+
self.device = device
|
81 |
+
|
82 |
+
def _to_tensor(self, datas):
|
83 |
+
x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
|
84 |
+
y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
|
85 |
+
|
86 |
+
# pad前的长度(超过pad_size的设为pad_size)
|
87 |
+
seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
|
88 |
+
return (x, seq_len), y
|
89 |
+
|
90 |
+
def __next__(self):
|
91 |
+
if self.residue and self.index == self.n_batches:
|
92 |
+
batches = self.batches[self.index * self.batch_size: len(self.batches)]
|
93 |
+
self.index += 1
|
94 |
+
|
95 |
+
batches = self._to_tensor(batches)
|
96 |
+
return batches
|
97 |
+
|
98 |
+
elif self.index >= self.n_batches:
|
99 |
+
self.index = 0
|
100 |
+
raise StopIteration
|
101 |
+
else:
|
102 |
+
batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
|
103 |
+
self.index += 1
|
104 |
+
batches = self._to_tensor(batches)
|
105 |
+
return batches
|
106 |
+
|
107 |
+
def __iter__(self):
|
108 |
+
return self
|
109 |
+
|
110 |
+
def __len__(self):
|
111 |
+
if self.residue:
|
112 |
+
return self.n_batches + 1
|
113 |
+
else:
|
114 |
+
return self.n_batches
|
115 |
+
|
116 |
+
|
117 |
+
def build_iterator(dataset, config):
|
118 |
+
iter = DatasetIterater(dataset, config.batch_size, config.device)
|
119 |
+
return iter
|
120 |
+
|
121 |
+
|
122 |
+
def get_time_dif(start_time):
|
123 |
+
"""获取已使用时间"""
|
124 |
+
end_time = time.time()
|
125 |
+
time_dif = end_time - start_time
|
126 |
+
return timedelta(seconds=int(round(time_dif)))
|
127 |
+
|
128 |
+
|
129 |
+
if __name__ == "__main__":
|
130 |
+
'''提取预训练词向量'''
|
131 |
+
# 下面的目录、文件名按需更改。
|
132 |
+
train_dir = "./THUCNews/data/train.txt"
|
133 |
+
vocab_dir = "./THUCNews/data/vocab.pkl"
|
134 |
+
pretrain_dir = "./THUCNews/data/sgns.sogou.char"
|
135 |
+
emb_dim = 300
|
136 |
+
filename_trimmed_dir = "./THUCNews/data/embedding_SougouNews"
|
137 |
+
if os.path.exists(vocab_dir):
|
138 |
+
word_to_id = pkl.load(open(vocab_dir, 'rb'))
|
139 |
+
else:
|
140 |
+
# tokenizer = lambda x: x.split(' ') # 以词为单位构建词表(数据集中词之间以空格隔开)
|
141 |
+
tokenizer = lambda x: [y for y in x] # 以字为单位构建词表
|
142 |
+
word_to_id = build_vocab(train_dir, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
|
143 |
+
pkl.dump(word_to_id, open(vocab_dir, 'wb'))
|
144 |
+
|
145 |
+
embeddings = np.random.rand(len(word_to_id), emb_dim)
|
146 |
+
f = open(pretrain_dir, "r", encoding='UTF-8')
|
147 |
+
for i, line in enumerate(f.readlines()):
|
148 |
+
# if i == 0: # 若第一行是标题,则跳过
|
149 |
+
# continue
|
150 |
+
lin = line.strip().split(" ")
|
151 |
+
if lin[0] in word_to_id:
|
152 |
+
idx = word_to_id[lin[0]]
|
153 |
+
emb = [float(x) for x in lin[1:301]]
|
154 |
+
embeddings[idx] = np.asarray(emb, dtype='float32')
|
155 |
+
f.close()
|
156 |
+
np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)
|
utils_fasttext.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: UTF-8
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import pickle as pkl
|
6 |
+
from tqdm import tqdm
|
7 |
+
import time
|
8 |
+
from datetime import timedelta
|
9 |
+
|
10 |
+
|
11 |
+
MAX_VOCAB_SIZE = 10000
|
12 |
+
UNK, PAD = '<UNK>', '<PAD>'
|
13 |
+
|
14 |
+
|
15 |
+
def build_vocab(file_path, tokenizer, max_size, min_freq):
|
16 |
+
vocab_dic = {}
|
17 |
+
with open(file_path, 'r', encoding='UTF-8') as f:
|
18 |
+
for line in tqdm(f):
|
19 |
+
lin = line.strip()
|
20 |
+
if not lin:
|
21 |
+
continue
|
22 |
+
content = lin.split('\t')[0]
|
23 |
+
for word in tokenizer(content):
|
24 |
+
vocab_dic[word] = vocab_dic.get(word, 0) + 1
|
25 |
+
vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
|
26 |
+
vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
|
27 |
+
vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
|
28 |
+
return vocab_dic
|
29 |
+
|
30 |
+
|
31 |
+
def build_dataset(config, ues_word):
|
32 |
+
if ues_word:
|
33 |
+
tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level
|
34 |
+
else:
|
35 |
+
tokenizer = lambda x: [y for y in x] # char-level
|
36 |
+
if os.path.exists(config.vocab_path):
|
37 |
+
vocab = pkl.load(open(config.vocab_path, 'rb'))
|
38 |
+
else:
|
39 |
+
vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
|
40 |
+
pkl.dump(vocab, open(config.vocab_path, 'wb'))
|
41 |
+
print(f"Vocab size: {len(vocab)}")
|
42 |
+
|
43 |
+
def biGramHash(sequence, t, buckets):
|
44 |
+
t1 = sequence[t - 1] if t - 1 >= 0 else 0
|
45 |
+
return (t1 * 14918087) % buckets
|
46 |
+
|
47 |
+
def triGramHash(sequence, t, buckets):
|
48 |
+
t1 = sequence[t - 1] if t - 1 >= 0 else 0
|
49 |
+
t2 = sequence[t - 2] if t - 2 >= 0 else 0
|
50 |
+
return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets
|
51 |
+
|
52 |
+
def load_dataset(path, pad_size=32):
|
53 |
+
contents = []
|
54 |
+
with open(path, 'r', encoding='UTF-8') as f:
|
55 |
+
for line in tqdm(f):
|
56 |
+
lin = line.strip()
|
57 |
+
if not lin:
|
58 |
+
continue
|
59 |
+
content, label = lin.split('\t')
|
60 |
+
words_line = []
|
61 |
+
token = tokenizer(content)
|
62 |
+
seq_len = len(token)
|
63 |
+
if pad_size:
|
64 |
+
if len(token) < pad_size:
|
65 |
+
token.extend([PAD] * (pad_size - len(token)))
|
66 |
+
else:
|
67 |
+
token = token[:pad_size]
|
68 |
+
seq_len = pad_size
|
69 |
+
# word to id
|
70 |
+
for word in token:
|
71 |
+
words_line.append(vocab.get(word, vocab.get(UNK)))
|
72 |
+
|
73 |
+
# fasttext ngram
|
74 |
+
buckets = config.n_gram_vocab
|
75 |
+
bigram = []
|
76 |
+
trigram = []
|
77 |
+
# ------ngram------
|
78 |
+
for i in range(pad_size):
|
79 |
+
bigram.append(biGramHash(words_line, i, buckets))
|
80 |
+
trigram.append(triGramHash(words_line, i, buckets))
|
81 |
+
# -----------------
|
82 |
+
contents.append((words_line, int(label), seq_len, bigram, trigram))
|
83 |
+
return contents # [([...], 0), ([...], 1), ...]
|
84 |
+
train = load_dataset(config.train_path, config.pad_size)
|
85 |
+
dev = load_dataset(config.dev_path, config.pad_size)
|
86 |
+
test = load_dataset(config.test_path, config.pad_size)
|
87 |
+
return vocab, train, dev, test
|
88 |
+
|
89 |
+
|
90 |
+
class DatasetIterater(object):
|
91 |
+
def __init__(self, batches, batch_size, device):
|
92 |
+
self.batch_size = batch_size
|
93 |
+
self.batches = batches
|
94 |
+
self.n_batches = len(batches) // batch_size
|
95 |
+
self.residue = False # 记录batch数量是否为整数
|
96 |
+
if len(batches) % self.n_batches != 0:
|
97 |
+
self.residue = True
|
98 |
+
self.index = 0
|
99 |
+
self.device = device
|
100 |
+
|
101 |
+
def _to_tensor(self, datas):
|
102 |
+
# xx = [xxx[2] for xxx in datas]
|
103 |
+
# indexx = np.argsort(xx)[::-1]
|
104 |
+
# datas = np.array(datas)[indexx]
|
105 |
+
x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
|
106 |
+
y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
|
107 |
+
bigram = torch.LongTensor([_[3] for _ in datas]).to(self.device)
|
108 |
+
trigram = torch.LongTensor([_[4] for _ in datas]).to(self.device)
|
109 |
+
|
110 |
+
# pad前的长度(超过pad_size的设为pad_size)
|
111 |
+
seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
|
112 |
+
return (x, seq_len, bigram, trigram), y
|
113 |
+
|
114 |
+
def __next__(self):
|
115 |
+
if self.residue and self.index == self.n_batches:
|
116 |
+
batches = self.batches[self.index * self.batch_size: len(self.batches)]
|
117 |
+
self.index += 1
|
118 |
+
batches = self._to_tensor(batches)
|
119 |
+
return batches
|
120 |
+
|
121 |
+
elif self.index >= self.n_batches:
|
122 |
+
self.index = 0
|
123 |
+
raise StopIteration
|
124 |
+
else:
|
125 |
+
batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
|
126 |
+
self.index += 1
|
127 |
+
batches = self._to_tensor(batches)
|
128 |
+
return batches
|
129 |
+
|
130 |
+
def __iter__(self):
|
131 |
+
return self
|
132 |
+
|
133 |
+
def __len__(self):
|
134 |
+
if self.residue:
|
135 |
+
return self.n_batches + 1
|
136 |
+
else:
|
137 |
+
return self.n_batches
|
138 |
+
|
139 |
+
|
140 |
+
def build_iterator(dataset, config):
|
141 |
+
iter = DatasetIterater(dataset, config.batch_size, config.device)
|
142 |
+
return iter
|
143 |
+
|
144 |
+
|
145 |
+
def get_time_dif(start_time):
|
146 |
+
"""获取已使用时间"""
|
147 |
+
end_time = time.time()
|
148 |
+
time_dif = end_time - start_time
|
149 |
+
return timedelta(seconds=int(round(time_dif)))
|
150 |
+
|
151 |
+
if __name__ == "__main__":
|
152 |
+
'''提取预训练词向量'''
|
153 |
+
vocab_dir = "./THUCNews/data/vocab.pkl"
|
154 |
+
pretrain_dir = "./THUCNews/data/sgns.sogou.char"
|
155 |
+
emb_dim = 300
|
156 |
+
filename_trimmed_dir = "./THUCNews/data/vocab.embedding.sougou"
|
157 |
+
word_to_id = pkl.load(open(vocab_dir, 'rb'))
|
158 |
+
embeddings = np.random.rand(len(word_to_id), emb_dim)
|
159 |
+
f = open(pretrain_dir, "r", encoding='UTF-8')
|
160 |
+
for i, line in enumerate(f.readlines()):
|
161 |
+
# if i == 0: # 若第一行是标题,则跳过
|
162 |
+
# continue
|
163 |
+
lin = line.strip().split(" ")
|
164 |
+
if lin[0] in word_to_id:
|
165 |
+
idx = word_to_id[lin[0]]
|
166 |
+
emb = [float(x) for x in lin[1:301]]
|
167 |
+
embeddings[idx] = np.asarray(emb, dtype='float32')
|
168 |
+
f.close()
|
169 |
+
np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)
|