Commit 6bce1f7
yangfan committed
Parent(s): f62480b

feat(*): add all for like bert

Files changed:
- .DS_Store +0 -0
- config.py +35 -0
- data/dev.txt +0 -0
- data/input.txt +3 -0
- data/label.txt +10 -0
- data/test.txt +0 -0
- data/train.txt +0 -0
- main.py +106 -0
- preprocess.py +85 -0
- pretrained_bert/README.md +3 -0
- train.py +122 -0
.DS_Store
ADDED
Binary file (6.15 kB)
config.py
ADDED
@@ -0,0 +1,35 @@
# coding: UTF-8

import os
import torch

class Config(object):
    def __init__(self, data_dir):
        assert os.path.exists(data_dir)
        self.train_file = os.path.join(data_dir, "train.txt")
        self.dev_file = os.path.join(data_dir, "dev.txt")
        self.label_file = os.path.join(data_dir, "label.txt")
        assert os.path.isfile(self.train_file)
        assert os.path.isfile(self.dev_file)
        assert os.path.isfile(self.label_file)

        self.saved_model_dir = os.path.join(data_dir, "model")
        self.saved_model = os.path.join(self.saved_model_dir, "bert_model.pth")
        if not os.path.exists(self.saved_model_dir):
            os.mkdir(self.saved_model_dir)

        self.label_list = [label.strip() for label in open(self.label_file, "r", encoding="UTF-8").readlines()]
        self.num_labels = len(self.label_list)

        self.num_epochs = 3
        self.log_batch = 100
        self.batch_size = 128
        self.max_seq_len = 32
        self.require_improvement = 1000

        self.warmup_steps = 0
        self.weight_decay = 0.01
        self.max_grad_norm = 1.0
        self.learning_rate = 5e-5
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

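For orientation, a minimal usage sketch (hypothetical paths) of what Config asserts and produces; this mirrors how main.py constructs it:

    from config import Config

    cfg = Config("./data")   # requires ./data/train.txt, ./data/dev.txt and ./data/label.txt to exist
    print(cfg.num_labels)    # number of lines in label.txt
    print(cfg.saved_model)   # ./data/model/bert_model.pth (the model/ directory is created if missing)
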
data/dev.txt
ADDED
The diff for this file is too large to render.
data/input.txt
ADDED
@@ -0,0 +1,3 @@
调查显示:29.5%的人不满意当年所选高考专业
广汽今日整体上市 最大短板在过度依赖丰田本田
梦游之王再现湖人大滑坡 金州小快枪刷分气懵科比

data/label.txt
ADDED
@@ -0,0 +1,10 @@
finance
realty
stocks
education
science
society
politics
sports
game
entertainment

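A hedged note on how this file is consumed (inferred from config.py, preprocess.py and main.py): the integer label stored in train.txt/dev.txt is the zero-based line index into label.txt, and predictions are mapped back through config.label_list. A quick sanity check:

    label_list = [line.strip() for line in open("data/label.txt", "r", encoding="UTF-8")]
    assert label_list[0] == "finance" and label_list[7] == "sports"   # a sample labeled 7 is a sports headline
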
data/test.txt
ADDED
The diff for this file is too large to render.
data/train.txt
ADDED
The diff for this file is too large to render.
main.py
ADDED
@@ -0,0 +1,106 @@
# coding: UTF-8

import os
import time
import torch
import argparse
import numpy as np
from tqdm import tqdm
from train import train
from config import Config
from preprocess import DataProcessor, get_time_dif
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification

parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
parser.add_argument("--pretrained_bert_dir", type=str, default="./pretrained_bert", help="pretrained bert model path")
parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
parser.add_argument("--input_file", type=str, default="./data/input.txt", help="input file to be predicted")
args = parser.parse_args()

def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

def main():
    set_seed(args.seed)
    config = Config(args.data_dir)

    tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
    bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
    model = BertForSequenceClassification.from_pretrained(
        os.path.join(args.pretrained_bert_dir, "pytorch_model.bin"),
        config=bert_config
    )
    model.to(config.device)

    if args.mode == "train":
        print("loading data...")
        start_time = time.time()
        train_iterator = DataProcessor(config.train_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
        dev_iterator = DataProcessor(config.dev_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
        time_dif = get_time_dif(start_time)
        print("time usage:", time_dif)

        # train
        train(model, config, train_iterator, dev_iterator)

    elif args.mode == "demo":
        model.load_state_dict(torch.load(config.saved_model))
        model.eval()
        while True:
            sentence = input("请输入文本:\n")
            inputs = tokenizer(
                sentence,
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs[0]
            label = torch.max(logits.data, 1)[1].tolist()
            print("分类结果:" + config.label_list[label[0]])
            flag = str(input("continue? (y/n):"))
            if flag == "Y" or flag == "y":
                continue
            else:
                break
    else:
        model.load_state_dict(torch.load(config.saved_model))
        model.eval()

        text = []
        with open(args.input_file, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                sentence = line.strip()
                if not sentence: continue
                text.append(sentence)

        num_samples = len(text)
        num_batches = (num_samples - 1) // config.batch_size + 1
        for i in range(num_batches):
            start = i * config.batch_size
            end = min(num_samples, (i + 1) * config.batch_size)
            inputs = tokenizer.batch_encode_plus(
                text[start: end],
                padding=True,
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)

            outputs = model(**inputs)
            logits = outputs[0]

            preds = torch.max(logits.data, 1)[1].tolist()
            labels = [config.label_list[_] for _ in preds]
            for j in range(start, end):
                print("%s\t%s" % (text[j], labels[j - start]))


if __name__ == "__main__":
    main()

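For reference, the invocations implied by the argparse options above (a sketch of the three modes using the default paths; not separately documented in this commit):

    python main.py --mode train --data_dir ./data --pretrained_bert_dir ./pretrained_bert
    python main.py --mode demo
    python main.py --mode predict --input_file ./data/input.txt
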
preprocess.py
ADDED
@@ -0,0 +1,85 @@
# coding: UTF-8

import time
import torch
import random
from tqdm import tqdm
from datetime import timedelta

def get_time_dif(start_time):
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))

class DataProcessor(object):
    def __init__(self, path, device, tokenizer, batch_size, max_seq_len, seed):
        self.seed = seed
        self.device = device
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len

        self.data = self.load(path)

        self.index = 0
        self.residue = False
        self.num_samples = len(self.data[0])
        self.num_batches = self.num_samples // self.batch_size
        if self.num_samples % self.batch_size != 0:
            self.residue = True

    def load(self, path):
        contents = []
        labels = []
        with open(path, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                line = line.strip()
                if not line: continue
                if line.find('\t') == -1: continue
                content, label = line.split("\t")
                contents.append(content)
                labels.append(int(label))
        # random shuffle
        index = list(range(len(labels)))
        random.seed(self.seed)
        random.shuffle(index)
        contents = [contents[_] for _ in index]
        labels = [labels[_] for _ in index]
        return (contents, labels)

    def __next__(self):
        if self.residue and self.index == self.num_batches:
            batch_x = self.data[0][self.index * self.batch_size: self.num_samples]
            batch_y = self.data[1][self.index * self.batch_size: self.num_samples]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch
        elif self.index >= self.num_batches:
            self.index = 0
            raise StopIteration
        else:
            batch_x = self.data[0][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch_y = self.data[1][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch

    def _to_tensor(self, batch_x, batch_y):
        inputs = self.tokenizer.batch_encode_plus(
            batch_x,
            padding="max_length",
            max_length=self.max_seq_len,
            truncation="longest_first",
            return_tensors="pt")
        inputs = inputs.to(self.device)
        labels = torch.LongTensor(batch_y).to(self.device)
        return (inputs, labels)

    def __iter__(self):
        return self

    def __len__(self):
        if self.residue:
            return self.num_batches + 1
        else:
            return self.num_batches

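A minimal sketch, assuming the data files and the ./pretrained_bert tokenizer used by main.py, of the line format DataProcessor.load expects and how the iterator is consumed:

    # Each line of train.txt / dev.txt is "<text>\t<label index>", where the index refers to label.txt.
    import torch
    from transformers import BertTokenizer
    from preprocess import DataProcessor

    tokenizer = BertTokenizer.from_pretrained("./pretrained_bert")
    iterator = DataProcessor("data/train.txt", torch.device("cpu"), tokenizer, batch_size=128, max_seq_len=32, seed=1)
    for inputs, labels in iterator:
        # inputs is a padded BatchEncoding; labels is a LongTensor of class indices
        print(inputs["input_ids"].shape, labels.shape)   # e.g. torch.Size([128, 32]) torch.Size([128])
        break
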
pretrained_bert/README.md
ADDED
@@ -0,0 +1,3 @@
Place the Hugging Face bert-base-chinese model weights pytorch_model.bin, the configuration file config.json, and the vocabulary vocab.txt in the pretrained_bert folder.

Download bert-base-chinese from Hugging Face: https://huggingface.co/bert-base-chinese/tree/main

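Putting the pieces together, the layout the scripts assume (a sketch; names taken from config.py and main.py):

    pretrained_bert/
        config.json
        pytorch_model.bin
        vocab.txt
    data/
        train.txt
        dev.txt
        label.txt
        model/            # created by Config; bert_model.pth is written here during training
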
train.py
ADDED
@@ -0,0 +1,122 @@
# coding: UTF-8

from typing import Iterator
from transformers import AdamW, get_linear_schedule_with_warmup
from preprocess import get_time_dif
from sklearn import metrics
import time
import torch
import numpy as np

def eval(model, config, iterator, flag=False):
    model.eval()

    total_loss = 0
    all_preds = np.array([], dtype=int)
    all_labels = np.array([], dtype=int)
    with torch.no_grad():
        for batch, labels in iterator:
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"],
                labels=labels)

            loss = outputs[0]
            logits = outputs[1]

            total_loss += loss
            true = labels.data.cpu().numpy()
            pred = torch.max(logits.data, 1)[1].cpu().numpy()
            all_labels = np.append(all_labels, true)
            all_preds = np.append(all_preds, pred)

    acc = metrics.accuracy_score(all_labels, all_preds)
    if flag:
        report = metrics.classification_report(all_labels, all_preds, target_names=config.label_list, digits=4)
        confusion = metrics.confusion_matrix(all_labels, all_preds)
        return acc, total_loss / len(iterator), report, confusion
    return acc, total_loss / len(iterator)


def test(model, config, iterator):
    model.load_state_dict(torch.load(config.saved_model))
    start_time = time.time()
    acc, loss, report, confusion = eval(model, config, iterator, flag=True)
    msg = "Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}"
    print(msg.format(loss, acc))
    print("Precision, Recall and F1-Score...")
    print(report)
    print("Confusion Matrix...")
    print(confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def train(model, config, train_iterator, dev_iterator):
    model.train()
    start_time = time.time()

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    param_optimizer = list(model.named_parameters())  # materialized as a list so both parameter groups below see every parameter
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": config.weight_decay},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
    ]

    t_total = len(train_iterator) * config.num_epochs
    optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=t_total)

    total_batch = 0
    last_improve = 0
    break_flag = False
    best_dev_loss = float('inf')
    for epoch in range(config.num_epochs):
        print("Epoch [{}/{}]".format(epoch + 1, config.num_epochs))
        for _, (batch, labels) in enumerate(train_iterator):

            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"],
                labels=labels)

            loss = outputs[0]
            logits = outputs[1]

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if total_batch % config.log_batch == 0:
                true = labels.data.cpu()
                pred = torch.max(logits.data, 1)[1].cpu()
                acc = metrics.accuracy_score(true, pred)
                dev_acc, dev_loss = eval(model, config, dev_iterator)
                if dev_loss < best_dev_loss:
                    best_dev_loss = dev_loss
                    torch.save(model.state_dict(), config.saved_model)
                    improve = "*"
                    last_improve = total_batch
                else:
                    improve = ""

                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Batch Train Loss: {1:>5.2}, Batch Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), acc, dev_loss, dev_acc, time_dif, improve))
                model.train()

            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                print("No improvement for a long time, auto-stopping...")
                break_flag = True
                break
        if break_flag:
            break

    test(model, config, dev_iterator)

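For completeness, a minimal sketch (reusing only the setup code already shown in main.py; nothing beyond what the files above define) of evaluating a saved checkpoint on the dev set without retraining:

    from config import Config
    from preprocess import DataProcessor
    from train import test
    from transformers import BertConfig, BertTokenizer, BertForSequenceClassification

    config = Config("./data")
    tokenizer = BertTokenizer.from_pretrained("./pretrained_bert")
    bert_config = BertConfig.from_pretrained("./pretrained_bert", num_labels=config.num_labels)
    model = BertForSequenceClassification.from_pretrained("./pretrained_bert/pytorch_model.bin", config=bert_config)
    model.to(config.device)

    dev_iterator = DataProcessor(config.dev_file, config.device, tokenizer, config.batch_size, config.max_seq_len, seed=1)
    test(model, config, dev_iterator)   # loads config.saved_model, then prints accuracy, report and confusion matrix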