--- |
license: mit |
language: |
- zh |
pipeline_tag: text-classification |
--- |
# Bert Chinese Text Classification Model |
This is a BERT model trained to classify customer-service text for logistics companies.
### Data (noisy, since it comes from ASR transcripts)
- train: 10878 rows
- dev: 2720 rows
- total: 13598 rows
### Parameters
- embed_dim: 128
- batch size: 64
- contextsize: 20
- n_head: 2
- epochs: 100
## Word Labels (word, index, number of occurrences)
```text
我 1 18719 |
个 2 12236 |
快 3 8152 |
一 4 8097 |
递 5 7295 |
那 6 7118 |
了 7 6923 |
的 8 6684 |
是 9 6632 |
到 10 6434 |
你 11 5144 |
没 12 4989 |
有 13 4664 |
下 14 4433 |
这 15 4219 |
在 16 4219 |
么 17 4010 |
查 18 3964 |
就 19 3570 |
好 20 3524 |
``` |
## Tokenizer |
```python |
label_dict, label_n2w = read_labelFile(labelFile)
word2ind, ind2word = get_worddict(wordLabelFile)
stoplist = read_stopword(stopwordFile)
stop_set = set(stoplist)  # built once: O(1) membership test per character
cla_dict = {}  # label -> number of training rows carrying that label
# Each output row holds the class index followed by maxLen character indices.
seq_width = maxLen + 1
# train data to vec
with open(trainFile, 'r', encoding='utf_8') as src:
    # Lines keep their trailing '\n', so a plain truthiness filter never removed
    # blank lines (they later crashed on the label lookup). Strip before testing.
    datas = [row for row in src if row.strip()]
random.shuffle(datas)
with open(trainDataVecFile, 'w') as traindataTxt:
    for line in tqdm(datas, desc="traindata to vec"):
        # Rows look like "<text>:<label>". Split on the LAST ':' so a colon that
        # appears inside the (noisy ASR) text does not shift the label.
        text, sep, cla = line.rstrip('\n').rpartition(':')
        if not sep:
            raise ValueError(f"missing ':' label separator in line: {line!r}")
        cla_dict[cla] = cla_dict.get(cla, 0) + 1
        cla_ind = label_dict[cla]
        # Character-level tokenization: every character of the text is a token
        # (jieba.cut(text, cut_all=False) would give word-level tokens instead).
        title_ind = [cla_ind]
        title_ind.extend(word2ind[w] for w in text if w not in stop_set)
        # Truncate, then right-pad with index 0, to exactly maxLen + 1 entries.
        title_ind = title_ind[:seq_width]
        title_ind.extend([0] * (seq_width - len(title_ind)))
        # Same on-disk format as before: every value followed by ',' then newline.
        traindataTxt.write(''.join(str(n) + ',' for n in title_ind) + '\n')
``` |
## Trainer |
```python |
# Seed the RNGs used during training so runs are reproducible. Previously
# `seed` was assigned but never applied, so it had no effect.
import random

seed = 3407
random.seed(seed)
torch.manual_seed(seed)  # seeds the generators on all devices (CPU and CUDA)
num_epochs = 100
# init net
print('init net...')
model = my_model()
model.to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
criterion = nn.CrossEntropyLoss()
print("training...")
best_dev_acc = 0
for epoch in range(num_epochs):
    model.train()
    for i, (clas, sentences) in enumerate(train_dataLoader):
        # NOTE(review): with batch size 64 and context size 20, the original
        # comment described a batch as 64 x 20 x 128 (one 128-dim vector per
        # character). Confirm whether the loader yields token ids that are
        # embedded inside my_model, or already-embedded vectors.
        out = model(sentences.to(device))  # presumably (batch, num_classes) logits
        try:
            loss = criterion(out, clas.to(device))
        except Exception:
            # Show the mismatching shapes, then propagate. Swallowing the error
            # left `loss` undefined (first batch) or stale from the previous
            # batch, so backward() below would fail or step on a wrong loss.
            print(out.size(), out)
            print(clas.size(), clas)
            raise
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())
    model.eval()
    dev_acc = validation(model=model, val_dataLoader=val_dataLoader,
                         device=device)
    # Keep only the checkpoint with the best dev accuracy seen so far.
    if best_dev_acc < dev_acc:
        best_dev_acc = dev_acc
        print("save model...")
        torch.save(model.state_dict(), "model.bin")
    print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())
    print("best dev acc %.4f dev acc %.4f" % (best_dev_acc, dev_acc))
``` |
## Testing |
```python |
def validation(model, val_dataLoader, device):
    """Compute the classification accuracy of ``model`` over ``val_dataLoader``.

    Args:
        model: callable with an ``eval()`` method; ``model(sentences)`` must
            return per-class scores with the class axis at dim 1 (presumably
            (batch, num_classes) logits — confirm against my_model).
        val_dataLoader: iterable of ``(clas, sentences)`` tensor pairs, where
            ``clas`` holds the target class index for each sample.
        device: device each batch is moved to before the forward pass.

    Returns:
        0-dim tensor: fraction of samples whose argmax prediction equals ``clas``.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
        IndexError: re-raised after printing the offending batch.
    """
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for i, (clas, sentences) in enumerate(val_dataLoader):
            try:
                out = model(sentences.to(device))
                labels = clas.to(device)
                pred = torch.argmax(out, dim=1)  # (batch, num_classes) -> (batch,)
                correct += (pred == labels).sum()
                total += clas.size()[0]
            except IndexError:
                # Dump the failing batch for debugging, then propagate. The old
                # bare exit() killed the whole process with exit status 0,
                # i.e. it reported success on failure.
                print(i)
                print('clas', clas)
                print('clas size', clas.size())
                print('sentence', sentences)
                print('sentences size', sentences.size())
                raise
    if total == 0:
        raise ValueError("val_dataLoader yielded no samples")
    acc = correct / total
    return acc
``` |