yangfan committed on
Commit 6bce1f7 · 1 Parent(s): f62480b

feat(*): add all files for BERT-like text classification

Files changed (11)
  1. .DS_Store +0 -0
  2. config.py +35 -0
  3. data/dev.txt +0 -0
  4. data/input.txt +3 -0
  5. data/label.txt +10 -0
  6. data/test.txt +0 -0
  7. data/train.txt +0 -0
  8. main.py +106 -0
  9. preprocess.py +85 -0
  10. pretrained_bert/README.md +3 -0
  11. train.py +122 -0
.DS_Store ADDED
Binary file (6.15 kB)
 
config.py ADDED
@@ -0,0 +1,35 @@
+ # coding: UTF-8
+
+ import os
+ import torch
+
+ class Config(object):
+     def __init__(self, data_dir):
+         assert os.path.exists(data_dir)
+         self.train_file = os.path.join(data_dir, "train.txt")
+         self.dev_file = os.path.join(data_dir, "dev.txt")
+         self.label_file = os.path.join(data_dir, "label.txt")
+         assert os.path.isfile(self.train_file)
+         assert os.path.isfile(self.dev_file)
+         assert os.path.isfile(self.label_file)
+
+         self.saved_model_dir = os.path.join(data_dir, "model")
+         self.saved_model = os.path.join(self.saved_model_dir, "bert_model.pth")
+         os.makedirs(self.saved_model_dir, exist_ok=True)
+
+         with open(self.label_file, "r", encoding="UTF-8") as f:
+             self.label_list = [label.strip() for label in f]
+         self.num_labels = len(self.label_list)
+
+         self.num_epochs = 3
+         self.log_batch = 100
+         self.batch_size = 128
+         self.max_seq_len = 32
+         self.require_improvement = 1000
+
+         self.warmup_steps = 0
+         self.weight_decay = 0.01
+         self.max_grad_norm = 1.0
+         self.learning_rate = 5e-5
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
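Config bundles every path and hyperparameter the other modules use. As a quick sanity check it can be constructed directly — a minimal sketch, assuming the ./data directory from this commit is present:

    # sketch: inspect the resolved configuration (assumes ./data exists as above)
    from config import Config
    config = Config("./data")
    print(config.num_labels, config.label_list)   # 10, ["finance", "realty", ...]
    print(config.device)                          # cuda if available, else cpu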
data/dev.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/input.txt ADDED
@@ -0,0 +1,3 @@
+ 调查显示:29.5%的人不满意当年所选高考专业
+ 广汽今日整体上市 最大短板在过度依赖丰田本田
+ 梦游之王再现湖人大滑坡 金州小快枪刷分气懵科比
data/label.txt ADDED
@@ -0,0 +1,10 @@
+ finance
+ realty
+ stocks
+ education
+ science
+ society
+ politics
+ sports
+ game
+ entertainment
data/test.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/train.txt ADDED
The diff for this file is too large to render. See raw diff
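The train/dev/test files are not rendered here, but preprocess.py's load() fixes their format: one sample per line, raw text and an integer label separated by a tab, where the label is the 0-based line index into data/label.txt. An illustrative (hypothetical) line, with 3 mapping to "education":

    调查显示:29.5%的人不满意当年所选高考专业	3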
 
main.py ADDED
@@ -0,0 +1,106 @@
+ # coding: UTF-8
+
+ import time
+ import torch
+ import argparse
+ import numpy as np
+ from tqdm import tqdm
+ from train import train
+ from config import Config
+ from preprocess import DataProcessor, get_time_dif
+ from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
+
+ parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
+ parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
+ parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
+ parser.add_argument("--pretrained_bert_dir", type=str, default="./pretrained_bert", help="pretrained bert model path")
+ parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
+ parser.add_argument("--input_file", type=str, default="./data/input.txt", help="input file to be predicted")
+ args = parser.parse_args()
+
+ def set_seed(seed):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+
+ def main():
+     set_seed(args.seed)
+     config = Config(args.data_dir)
+
+     tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
+     bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
+     model = BertForSequenceClassification.from_pretrained(
+         args.pretrained_bert_dir,
+         config=bert_config
+     )
+     model.to(config.device)
+
+     if args.mode == "train":
+         print("loading data...")
+         start_time = time.time()
+         train_iterator = DataProcessor(config.train_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
+         dev_iterator = DataProcessor(config.dev_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
+         time_dif = get_time_dif(start_time)
+         print("time usage:", time_dif)
+
+         # train
+         train(model, config, train_iterator, dev_iterator)
+
+     elif args.mode == "demo":
+         model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
+         model.eval()
+         while True:
+             sentence = input("请输入文本:\n")
+             inputs = tokenizer(
+                 sentence,
+                 max_length=config.max_seq_len,
+                 truncation="longest_first",
+                 return_tensors="pt")
+             inputs = inputs.to(config.device)
+             with torch.no_grad():
+                 outputs = model(**inputs)
+                 logits = outputs[0]
+             label = torch.max(logits.data, 1)[1].tolist()
+             print("分类结果:" + config.label_list[label[0]])
+             flag = input("continue? (y/n):")
+             if flag == "Y" or flag == "y":
+                 continue
+             else:
+                 break
+     else:
+         model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
+         model.eval()
+
+         text = []
+         with open(args.input_file, mode="r", encoding="UTF-8") as f:
+             for line in tqdm(f):
+                 sentence = line.strip()
+                 if not sentence: continue
+                 text.append(sentence)
+
+         num_samples = len(text)
+         num_batches = (num_samples - 1) // config.batch_size + 1
+         for i in range(num_batches):
+             start = i * config.batch_size
+             end = min(num_samples, (i + 1) * config.batch_size)
+             inputs = tokenizer.batch_encode_plus(
+                 text[start: end],
+                 padding=True,
+                 max_length=config.max_seq_len,
+                 truncation="longest_first",
+                 return_tensors="pt")
+             inputs = inputs.to(config.device)
+
+             with torch.no_grad():
+                 outputs = model(**inputs)
+                 logits = outputs[0]
+
+             preds = torch.max(logits.data, 1)[1].tolist()
+             labels = [config.label_list[_] for _ in preds]
+             for j in range(start, end):
+                 print("%s\t%s" % (text[j], labels[j - start]))
+
+
+ if __name__ == "__main__":
+     main()
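Based on the argparse flags above, typical invocations would be:

    python main.py --mode train --data_dir ./data --pretrained_bert_dir ./pretrained_bert
    python main.py --mode demo
    python main.py --mode predict --input_file ./data/input.txt

Note that any --mode value other than train or demo falls through to the batch-prediction branch.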
preprocess.py ADDED
@@ -0,0 +1,85 @@
+ # coding: UTF-8
+
+ import time
+ import torch
+ import random
+ from tqdm import tqdm
+ from datetime import timedelta
+
+ def get_time_dif(start_time):
+     end_time = time.time()
+     time_dif = end_time - start_time
+     return timedelta(seconds=int(round(time_dif)))
+
+ class DataProcessor(object):
+     def __init__(self, path, device, tokenizer, batch_size, max_seq_len, seed):
+         self.seed = seed
+         self.device = device
+         self.tokenizer = tokenizer
+         self.batch_size = batch_size
+         self.max_seq_len = max_seq_len
+
+         self.data = self.load(path)
+
+         self.index = 0
+         self.residue = False
+         self.num_samples = len(self.data[0])
+         self.num_batches = self.num_samples // self.batch_size
+         if self.num_samples % self.batch_size != 0:
+             self.residue = True
+
+     def load(self, path):
+         contents = []
+         labels = []
+         with open(path, mode="r", encoding="UTF-8") as f:
+             for line in tqdm(f):
+                 line = line.strip()
+                 if not line: continue
+                 if "\t" not in line: continue
+                 content, label = line.rsplit("\t", 1)
+                 contents.append(content)
+                 labels.append(int(label))
+         # shuffle deterministically with the configured seed
+         index = list(range(len(labels)))
+         random.seed(self.seed)
+         random.shuffle(index)
+         contents = [contents[_] for _ in index]
+         labels = [labels[_] for _ in index]
+         return (contents, labels)
+
+     def __next__(self):
+         if self.residue and self.index == self.num_batches:
+             batch_x = self.data[0][self.index * self.batch_size: self.num_samples]
+             batch_y = self.data[1][self.index * self.batch_size: self.num_samples]
+             batch = self._to_tensor(batch_x, batch_y)
+             self.index += 1
+             return batch
+         elif self.index >= self.num_batches:
+             self.index = 0
+             raise StopIteration
+         else:
+             batch_x = self.data[0][self.index * self.batch_size: (self.index + 1) * self.batch_size]
+             batch_y = self.data[1][self.index * self.batch_size: (self.index + 1) * self.batch_size]
+             batch = self._to_tensor(batch_x, batch_y)
+             self.index += 1
+             return batch
+
+     def _to_tensor(self, batch_x, batch_y):
+         inputs = self.tokenizer.batch_encode_plus(
+             batch_x,
+             padding="max_length",
+             max_length=self.max_seq_len,
+             truncation="longest_first",
+             return_tensors="pt")
+         inputs = inputs.to(self.device)
+         labels = torch.LongTensor(batch_y).to(self.device)
+         return (inputs, labels)
+
+     def __iter__(self):
+         return self
+
+     def __len__(self):
+         if self.residue:
+             return self.num_batches + 1
+         else:
+             return self.num_batches
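DataProcessor is a self-resetting batch iterator: __next__ hands out tokenized batches, serves the final partial batch when one exists, then raises StopIteration and rewinds. A minimal usage sketch, assuming the tokenizer and data files from this commit:

    import torch
    from transformers import BertTokenizer
    from preprocess import DataProcessor

    device = torch.device("cpu")
    tokenizer = BertTokenizer.from_pretrained("./pretrained_bert")
    it = DataProcessor("./data/dev.txt", device, tokenizer, batch_size=128, max_seq_len=32, seed=1)
    print(len(it))                    # number of batches, +1 if a partial batch remains
    inputs, labels = next(iter(it))   # BatchEncoding on device, LongTensor of labels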
pretrained_bert/README.md ADDED
@@ -0,0 +1,3 @@
+ Place the Hugging Face bert-base-chinese model weights (pytorch_model.bin), the config file (config.json), and the vocabulary (vocab.txt) in the pretrained_bert folder.
+
+ Download bert-base-chinese from Hugging Face: https://huggingface.co/bert-base-chinese/tree/main
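One way to fetch these files programmatically — a sketch, assuming the huggingface_hub package is installed (snapshot_download and its local_dir argument exist in recent versions):

    from huggingface_hub import snapshot_download

    # downloads pytorch_model.bin, config.json, vocab.txt, etc. into ./pretrained_bert
    snapshot_download(repo_id="bert-base-chinese", local_dir="./pretrained_bert")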
train.py ADDED
@@ -0,0 +1,122 @@
+ # coding: UTF-8
+
+ from transformers import AdamW, get_linear_schedule_with_warmup
+ from preprocess import get_time_dif
+ from sklearn import metrics
+ import time
+ import torch
+ import numpy as np
+
+ def eval(model, config, iterator, flag=False):
+     model.eval()
+
+     total_loss = 0
+     all_preds = np.array([], dtype=int)
+     all_labels = np.array([], dtype=int)
+     with torch.no_grad():
+         for batch, labels in iterator:
+             outputs = model(
+                 input_ids=batch["input_ids"],
+                 attention_mask=batch["attention_mask"],
+                 token_type_ids=batch["token_type_ids"],
+                 labels=labels)
+
+             loss = outputs[0]
+             logits = outputs[1]
+
+             total_loss += loss.item()
+             true = labels.data.cpu().numpy()
+             pred = torch.max(logits.data, 1)[1].cpu().numpy()
+             all_labels = np.append(all_labels, true)
+             all_preds = np.append(all_preds, pred)
+
+     acc = metrics.accuracy_score(all_labels, all_preds)
+     if flag:
+         report = metrics.classification_report(all_labels, all_preds, target_names=config.label_list, digits=4)
+         confusion = metrics.confusion_matrix(all_labels, all_preds)
+         return acc, total_loss / len(iterator), report, confusion
+     return acc, total_loss / len(iterator)
+
+
+ def test(model, config, iterator):
+     model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
+     start_time = time.time()
+     acc, loss, report, confusion = eval(model, config, iterator, flag=True)
+     msg = "Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}"
+     print(msg.format(loss, acc))
+     print("Precision, Recall and F1-Score...")
+     print(report)
+     print("Confusion Matrix...")
+     print(confusion)
+     time_dif = get_time_dif(start_time)
+     print("Time usage:", time_dif)
+
+
+ def train(model, config, train_iterator, dev_iterator):
+     model.train()
+     start_time = time.time()
+
+     no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
+     # materialize the generator so both comprehensions below can iterate it
+     param_optimizer = list(model.named_parameters())
+     optimizer_grouped_parameters = [
+         {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": config.weight_decay},
+         {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
+     ]
+
+     t_total = len(train_iterator) * config.num_epochs
+     optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
+     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=t_total)
+
+     total_batch = 0
+     last_improve = 0
+     break_flag = False
+     best_dev_loss = float("inf")
+     for epoch in range(config.num_epochs):
+         print("Epoch [{}/{}]".format(epoch + 1, config.num_epochs))
+         for batch, labels in train_iterator:
+
+             outputs = model(
+                 input_ids=batch["input_ids"],
+                 attention_mask=batch["attention_mask"],
+                 token_type_ids=batch["token_type_ids"],
+                 labels=labels)
+
+             loss = outputs[0]
+             logits = outputs[1]
+
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
+
+             optimizer.step()
+             scheduler.step()
+             optimizer.zero_grad()
+
+             if total_batch % config.log_batch == 0:
+                 true = labels.data.cpu()
+                 pred = torch.max(logits.data, 1)[1].cpu()
+                 acc = metrics.accuracy_score(true, pred)
+                 dev_acc, dev_loss = eval(model, config, dev_iterator)
+                 if dev_loss < best_dev_loss:
+                     best_dev_loss = dev_loss
+                     torch.save(model.state_dict(), config.saved_model)
+                     improve = "*"
+                     last_improve = total_batch
+                 else:
+                     improve = ""
+
+                 time_dif = get_time_dif(start_time)
+                 msg = "Iter: {0:>6}, Batch Train Loss: {1:>5.2}, Batch Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}"
+                 print(msg.format(total_batch, loss.item(), acc, dev_loss, dev_acc, time_dif, improve))
+                 model.train()
+
+             total_batch += 1
+             if total_batch - last_improve > config.require_improvement:
+                 print("No improvement for a long time, auto-stopping...")
+                 break_flag = True
+                 break
+         if break_flag:
+             break
+
+     test(model, config, dev_iterator)
+
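The optimizer pairs AdamW with a linear warmup/decay schedule over t_total = len(train_iterator) * num_epochs steps; with warmup_steps = 0 the learning rate simply decays linearly from learning_rate toward 0. A standalone sketch of that schedule, assuming transformers' AdamW and get_linear_schedule_with_warmup as imported above:

    import torch
    from transformers import AdamW, get_linear_schedule_with_warmup

    # a single dummy parameter stands in for the model here
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = AdamW([param], lr=5e-5)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=10)
    for step in range(10):
        optimizer.step()
        scheduler.step()
        print(step, scheduler.get_last_lr())   # lr falls linearly toward 0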