|
---
license: mit
language:
- zh
pipeline_tag: text-classification
---
|
# BERT Chinese Text Classification Model
|
This is a BERT model trained for the customer service domain of logistics companies.
|
### Data (noisy, since it comes from ASR transcripts)
|
train: 10878 rows |
|
dev: 2720 rows
|
total: 13598 rows |
|
### Parameters
|
embed_dim: 128

batch_size: 64

context_size: 20

n_head: 2

epochs: 100
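
The model definition (`my_model`, used in the trainer below) is not included in this card. As a rough sketch of how these hyperparameters could fit together, here is a minimal PyTorch module; the vocabulary size, the single `nn.TransformerEncoder` layer, and folding the embedding into the model are assumptions, while the 4-class output matches the shape comments in the trainer.

```python
import torch.nn as nn

class my_model(nn.Module):
    # hypothetical reconstruction: embedding + one Transformer encoder layer + linear head
    def __init__(self, vocab_size=4000, embed_dim=128, n_head=2, n_classes=4):
        super().__init__()
        # index 0 is reserved for padding, matching the zero-padding in the tokenizer
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=n_head,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.my_linear = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        # x: (batch, 20) word indices -> (batch, 20, 128) after embedding
        h = self.encoder(self.embed(x))
        # mean-pool over the sequence, then project to class logits
        return self.my_linear(h.mean(dim=1))
```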
|
|
|
## Word Labels (word, index, number of occurrences)
|
```
我 1 18719
个 2 12236
快 3 8152
一 4 8097
递 5 7295
那 6 7118
了 7 6923
的 8 6684
是 9 6632
到 10 6434
你 11 5144
没 12 4989
有 13 4664
下 14 4433
这 15 4219
在 16 4219
么 17 4010
查 18 3964
就 19 3570
好 20 3524
```
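
The tokenizer code below reads this file through `get_worddict`, whose implementation is not shown in this card. A minimal sketch consistent with the `word index count` lines above might look like this:

```python
def get_worddict(wordLabelFile):
    # hypothetical reader: each line is `word index count`; the count is not needed here
    word2ind, ind2word = {}, {}
    with open(wordLabelFile, 'r', encoding='utf_8') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            word, ind = parts[0], int(parts[1])
            word2ind[word] = ind
            ind2word[ind] = word
    return word2ind, ind2word
```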
|
|
|
## Tokenizer |
|
```python
import random
from tqdm import tqdm

label_dict, label_n2w = read_labelFile(labelFile)
word2ind, ind2word = get_worddict(wordLabelFile)
stoplist = read_stopword(stopwordFile)
cla_dict = {}

# convert the training data to index vectors; maxLen is the context size (20)
traindataTxt = open(trainDataVecFile, 'w')
datas = open(trainFile, 'r', encoding='utf_8').readlines()
datas = list(filter(None, datas))
random.shuffle(datas)
for line in tqdm(datas, desc="traindata to vec"):
    # each line is `text:label`
    line = line.replace('\n', '').split(':')
    cla = line[1]
    # count samples per class
    if cla in cla_dict:
        cla_dict[cla] += 1
    else:
        cla_dict[cla] = 1

    cla_ind = label_dict[cla]
    # character-level tokenization
    title_seg = [i for i in line[0]]
    # title_seg = jieba.cut(line[0], cut_all=False)  # word-level alternative
    title_ind = [cla_ind]
    for w in title_seg:
        if w in stoplist:
            continue
        title_ind.append(word2ind[w])
    # truncate or zero-pad so each sample is the label plus maxLen word indices
    length = len(title_ind)
    if length > maxLen + 1:
        title_ind = title_ind[0:maxLen + 1]
    if length < maxLen + 1:
        title_ind.extend([0] * (maxLen - length + 1))

    for n in title_ind:
        traindataTxt.write(str(n) + ',')
    traindataTxt.write('\n')
traindataTxt.close()
```
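
The trainer below iterates over `train_dataLoader`, whose construction is not shown in this card. Assuming it is built from the vector file written above (one line per sample: the class index followed by 20 word indices, comma-separated), a minimal sketch could be the following. The `VecDataset` name is hypothetical; note that the trainer's shape comments describe `sentences` as already-embedded 64x20x128 tensors, whereas this sketch yields raw indices and leaves the embedding to the model.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class VecDataset(Dataset):  # hypothetical name
    def __init__(self, vecFile):
        self.samples = []
        with open(vecFile, 'r') as f:
            for line in f:
                vals = [int(v) for v in line.strip().strip(',').split(',') if v]
                if len(vals) >= 2:
                    # first value is the class index, the rest are word indices
                    self.samples.append((vals[0], vals[1:]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        cla, sent = self.samples[i]
        return cla, torch.tensor(sent, dtype=torch.long)

train_dataLoader = DataLoader(VecDataset(trainDataVecFile),
                              batch_size=64, shuffle=True)
```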
|
|
|
## Trainer |
|
```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# set the seed to ensure reproducibility
seed = 3407
torch.manual_seed(seed)

# init net
print('init net...')
model = my_model()
model.to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
criterion = nn.CrossEntropyLoss()

print("training...")

best_dev_acc = 0
for epoch in range(100):
    model.train()
    for i, (clas, sentences) in enumerate(train_dataLoader):
        # sentences: batch size 64 x sentence length 20 x embed dimension 128
        # (each character is a 128-dim vector, a sentence is a 20x128 2-D tensor,
        #  and a batch of 64 sentences is a 64x20x128 3-D tensor)
        out = model(sentences.to(device))  # out: batch size 64 x num classes 4 (after my_linear)
        try:
            loss = criterion(out, clas.to(device))
        except Exception:
            # dump the offending shapes for debugging, then re-raise
            print(out.size(), out)
            print(clas.size(), clas)
            raise
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())
    model.eval()
    dev_acc = validation(model=model, val_dataLoader=val_dataLoader,
                         device=device)
    if best_dev_acc < dev_acc:
        best_dev_acc = dev_acc
        print("save model...")
        torch.save(model.state_dict(), "model.bin")
    print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())
    print("best dev acc %.4f dev acc %.4f" % (best_dev_acc, dev_acc))
```
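
For completeness, a hedged inference sketch that loads the saved weights and classifies a single utterance. It mirrors the tokenizer's character-level preprocessing; falling back to index 0 for out-of-vocabulary characters is an assumption (the tokenizer above would raise a `KeyError` instead).

```python
import torch

model = my_model()
model.load_state_dict(torch.load("model.bin", map_location=device))
model.to(device)
model.eval()

text = "我的快递到哪里了"  # "where has my parcel gotten to"
ids = [word2ind.get(ch, 0) for ch in text][:20]
ids += [0] * (20 - len(ids))  # zero-pad to the context size of 20
with torch.no_grad():
    out = model(torch.tensor([ids], dtype=torch.long).to(device))
print("predicted class index:", torch.argmax(out, dim=1).item())
```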
|
|
|
## Testing |
|
```python
import traceback
import torch

def validation(model, val_dataLoader, device):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for i, (clas, sentences) in enumerate(val_dataLoader):
            try:
                out = model(sentences.to(device))  # out: batch size 64 x num classes 4 (after my_linear)
                pred = torch.argmax(out, dim=1)  # 64x4 -> 64
                correct += (pred == clas.to(device)).sum().item()
                total += clas.size()[0]
            except IndexError:
                # dump the offending batch for debugging, then stop
                print(i)
                print('clas', clas)
                print('clas size', clas.size())
                print('sentence', sentences)
                print('sentences size', sentences.size())
                traceback.print_exc()
                exit()

    acc = correct / total
    return acc
```
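
A dev-set loader can be built the same way as the training loader; `devDataVecFile` here is a hypothetical counterpart of `trainDataVecFile`:

```python
val_dataLoader = DataLoader(VecDataset(devDataVecFile), batch_size=64)
print("dev acc %.4f" % validation(model=model, val_dataLoader=val_dataLoader, device=device))
```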