Upload 21 files
Browse files- .gitattributes +1 -0
- macbert/checkpoints/saved_checkpoints/name_checkpoint15_train8000.pth.tar +3 -0
- macbert/checkpoints/saved_checkpoints/name_checkpoint17_train9000.pth.tar +3 -0
- macbert/checkpoints/saved_checkpoints/travel_checkpoint15_train8000.pth.tar +3 -0
- macbert/checkpoints/saved_checkpoints/travel_checkpoint_17_train9000.pth.tar +3 -0
- macbert/dataset.py +89 -0
- macbert/dataset/datagame_sms_stage1(in).csv +3 -0
- macbert/dataset/name_test_9000.csv +0 -0
- macbert/dataset/name_train_9000.csv +0 -0
- macbert/dataset/name_val_9000.csv +0 -0
- macbert/dataset/travel_test_9000.csv +0 -0
- macbert/dataset/travel_train_9000.csv +0 -0
- macbert/dataset/travel_val_9000.csv +0 -0
- macbert/infer.py +138 -0
- macbert/infer_all.py +167 -0
- macbert/infer_name.py +144 -0
- macbert/infer_travel.py +144 -0
- macbert/main.py +285 -0
- macbert/models.py +42 -0
- macbert/requirements.txt +5 -0
- macbert/test.py +6 -0
- macbert/utils.py +70 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
macbert/dataset/datagame_sms_stage1(in).csv filter=lfs diff=lfs merge=lfs -text
|
macbert/checkpoints/saved_checkpoints/name_checkpoint15_train8000.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a9c216795384fe75c774e640eb2929f46b8dbc409ffd9d961e74627d6224204
|
| 3 |
+
size 1222778219
|
macbert/checkpoints/saved_checkpoints/name_checkpoint17_train9000.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e1abe8fd9b5099381277eb432db9c95f7467bb362bf7244a979bcdd691492770
|
| 3 |
+
size 1222778219
|
macbert/checkpoints/saved_checkpoints/travel_checkpoint15_train8000.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:716b58feddc2bcaae611858581b72725e1592aca1d0645dc2abf38b9f499e791
|
| 3 |
+
size 1222778219
|
macbert/checkpoints/saved_checkpoints/travel_checkpoint_17_train9000.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:027e43ef9081d786a967ffa95ca95a52a1487c8a76f2f6b4409e796dab657b53
|
| 3 |
+
size 1222778219
|
macbert/dataset.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from torchvision.transforms.functional import to_tensor, to_pil_image
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
# from transformers import AutoModel
|
| 8 |
+
# from transformers import AutoTokenizer, AutoConfig
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.autograd import Variable
|
| 12 |
+
from torch.utils.data import Dataset, DataLoader
|
| 13 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 14 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import random
|
| 18 |
+
import os
|
| 19 |
+
import copy
|
| 20 |
+
import pandas as pd
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MyDataset(Dataset):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
ann_file,
|
| 27 |
+
cfg,
|
| 28 |
+
mode='tra',
|
| 29 |
+
):
|
| 30 |
+
super().__init__()
|
| 31 |
+
|
| 32 |
+
data = np.array(pd.read_csv(ann_file))
|
| 33 |
+
self.data = data
|
| 34 |
+
self.mode = mode
|
| 35 |
+
self.cfg = cfg
|
| 36 |
+
|
| 37 |
+
def __getitem__(self, index):
|
| 38 |
+
if self.mode == 'test':
|
| 39 |
+
d = self.data[index]
|
| 40 |
+
context = d[1]
|
| 41 |
+
sms_id = d[0]
|
| 42 |
+
return context, sms_id
|
| 43 |
+
else :
|
| 44 |
+
d = self.data[index]
|
| 45 |
+
context = d[1]
|
| 46 |
+
label = d[2]
|
| 47 |
+
label = int(label)
|
| 48 |
+
return context, label
|
| 49 |
+
|
| 50 |
+
def __len__(self):
|
| 51 |
+
return len(self.data)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__ == '__main__':
|
| 55 |
+
d = 'C:/Users/u/Desktop/workspace/bs/myr/data/val.csv'
|
| 56 |
+
D = MyDataset(d, cfg={})
|
| 57 |
+
nb_1 = 0
|
| 58 |
+
for i, d in enumerate(D):
|
| 59 |
+
_, l = d
|
| 60 |
+
if l==1:
|
| 61 |
+
nb_1 += 1
|
| 62 |
+
print(nb_1/len(D))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
np.random.seed(666)
|
| 66 |
+
|
| 67 |
+
ann_file1 = 'C:/Users/u/Desktop/workspace/bs/myr/data/test_samples.csv'
|
| 68 |
+
ann_file2 = 'C:/Users/u/Desktop/workspace/bs/myr/data/train_samples.csv'
|
| 69 |
+
data1 = pd.read_csv(ann_file1)
|
| 70 |
+
data2 = pd.read_csv(ann_file2)
|
| 71 |
+
data = pd.concat([data1, data2])
|
| 72 |
+
|
| 73 |
+
data = np.array(data)
|
| 74 |
+
np.random.shuffle(data)
|
| 75 |
+
|
| 76 |
+
data_tra = data[:int(len(data)*0.7)]
|
| 77 |
+
data_val = data[int(len(data)*0.7):]
|
| 78 |
+
|
| 79 |
+
data_tra = pd.DataFrame(data_tra, columns=['content', 'label'])
|
| 80 |
+
data_val = pd.DataFrame(data_val, columns=['content', 'label'])
|
| 81 |
+
|
| 82 |
+
data_tra.to_csv('C:/Users/u/Desktop/workspace/bs/myr/data/tra.csv', index=False)
|
| 83 |
+
data_val.to_csv('C:/Users/u/Desktop/workspace/bs/myr/data/val.csv', index=False)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
macbert/dataset/datagame_sms_stage1(in).csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f83bdff3455421e19a75bb1fe947b752751046189fec28a2f865bdef32ae2e9
|
| 3 |
+
size 47370252
|
macbert/dataset/name_test_9000.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
macbert/dataset/name_train_9000.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
macbert/dataset/name_val_9000.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
macbert/dataset/travel_test_9000.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
macbert/dataset/travel_train_9000.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
macbert/dataset/travel_val_9000.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
macbert/infer.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from torchvision.transforms.functional import to_tensor, to_pil_image
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from transformers import AutoModel
|
| 11 |
+
from transformers import AutoTokenizer, AutoConfig
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader
|
| 16 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import random
|
| 21 |
+
import numpy as np
|
| 22 |
+
from collections import OrderedDict
|
| 23 |
+
from rich import print
|
| 24 |
+
import time
|
| 25 |
+
import cv2
|
| 26 |
+
from glob import glob
|
| 27 |
+
import string
|
| 28 |
+
from torch.optim import AdamW
|
| 29 |
+
from transformers import get_linear_schedule_with_warmup
|
| 30 |
+
from models import get_model
|
| 31 |
+
from dataset import MyDataset
|
| 32 |
+
from utils import save_checkpoint, AverageMeter, ProgressMeter
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_epoch(model, epoch, dataloader, tokenizer):
|
| 36 |
+
print(f"\n\n=> val")
|
| 37 |
+
data_time = AverageMeter('- data', ':4.3f')
|
| 38 |
+
batch_time = AverageMeter('- batch', ':6.3f')
|
| 39 |
+
progress = ProgressMeter(
|
| 40 |
+
len(dataloader), data_time, batch_time, prefix=f"Epoch: [{epoch}]")
|
| 41 |
+
|
| 42 |
+
end = time.time()
|
| 43 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
| 44 |
+
model.to(device)
|
| 45 |
+
model.eval()
|
| 46 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
| 47 |
+
|
| 48 |
+
predictions = []
|
| 49 |
+
|
| 50 |
+
for batch_index, data_batch in enumerate(tqdm(dataloader)):
|
| 51 |
+
context_str_batch = data_batch
|
| 52 |
+
|
| 53 |
+
# data tokenizer
|
| 54 |
+
context_token_batch = tokenizer(context_str_batch, padding=True, truncation=True, max_length=500, return_tensors='pt')
|
| 55 |
+
|
| 56 |
+
# to gpu
|
| 57 |
+
context_token_batch = {k:v.to(device) for k,v in context_token_batch.items()}
|
| 58 |
+
|
| 59 |
+
# forward
|
| 60 |
+
data_input_batch = context_token_batch
|
| 61 |
+
output_batch = model(**data_input_batch)
|
| 62 |
+
|
| 63 |
+
pred_batch = output_batch.softmax(dim=-1)
|
| 64 |
+
pred = torch.argmax(pred_batch, dim=-1)
|
| 65 |
+
predictions.extend(pred.cpu().numpy())
|
| 66 |
+
|
| 67 |
+
batch_time.update(time.time() - end)
|
| 68 |
+
end = time.time()
|
| 69 |
+
|
| 70 |
+
if batch_index % 50 == 0:
|
| 71 |
+
progress.print(batch_index)
|
| 72 |
+
|
| 73 |
+
return predictions
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def infer20221212():
|
| 77 |
+
checkpoint_file = '/home/elaine/Desktop/macbert_code/checkpoints_name/checkpoint_epoch015_acc1.0000.pth.tar'
|
| 78 |
+
output_file = r'/home/elaine/Desktop/macbert_code/output.csv'
|
| 79 |
+
|
| 80 |
+
cache_dir = '/home/elaine/Desktop/macbert_code/code/cache'
|
| 81 |
+
ann_file_test = r'/home/elaine/Desktop/macbert_code/dataset/name_test_8000.csv'
|
| 82 |
+
|
| 83 |
+
model_cfg = {
|
| 84 |
+
"pretrained_transformers": "hfl/chinese-macbert-base",
|
| 85 |
+
"cache_dir": cache_dir
|
| 86 |
+
}
|
| 87 |
+
# 模型
|
| 88 |
+
model_dict = get_model(model_cfg, mode='base')
|
| 89 |
+
model = model_dict['model']
|
| 90 |
+
tokenizer = model_dict['tokenizer']
|
| 91 |
+
print(model)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
data_loader_cfg = {}
|
| 95 |
+
test_dataset = MyDataset(ann_file_test, data_loader_cfg, mode='test')
|
| 96 |
+
test_loader = DataLoader(test_dataset, batch_size=8, num_workers=4, pin_memory=True)
|
| 97 |
+
|
| 98 |
+
# resume
|
| 99 |
+
assert checkpoint_file is not None and os.path.exists(checkpoint_file)
|
| 100 |
+
checkpoint = torch.load(checkpoint_file, map_location='cpu')
|
| 101 |
+
# model.load_state_dict(checkpoint['state_dict'])
|
| 102 |
+
model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()})
|
| 103 |
+
print(f"=> Resume: loaded checkpoint {checkpoint_file} (epoch {checkpoint['epoch']})")
|
| 104 |
+
|
| 105 |
+
#model = model.cuda()
|
| 106 |
+
pred_res = test_epoch(model, 1, test_loader, tokenizer)
|
| 107 |
+
with open(output_file, 'w') as f:
|
| 108 |
+
for pred in pred_res:
|
| 109 |
+
f.write(f"{pred}\n")
|
| 110 |
+
|
| 111 |
+
# 讀取val.csv的label
|
| 112 |
+
import csv
|
| 113 |
+
true_labels = []
|
| 114 |
+
with open(ann_file_test, 'r') as f:
|
| 115 |
+
reader = csv.reader(f)
|
| 116 |
+
next(reader) # skip header
|
| 117 |
+
for row in reader:
|
| 118 |
+
true_labels.append(int(row[3]))
|
| 119 |
+
|
| 120 |
+
# 計算confusion matrix
|
| 121 |
+
from sklearn.metrics import confusion_matrix
|
| 122 |
+
cm = confusion_matrix(true_labels, pred_res)
|
| 123 |
+
print('Confusion Matrix:')
|
| 124 |
+
print(cm)
|
| 125 |
+
|
| 126 |
+
# 印出預測錯誤的內容、預測值和正確答案
|
| 127 |
+
with open(ann_file_test) as f:
|
| 128 |
+
reader = csv.reader(f)
|
| 129 |
+
next(reader) # skip header
|
| 130 |
+
for idx, row in enumerate(reader):
|
| 131 |
+
sms, label = row[1], int(row[3])
|
| 132 |
+
pred = pred_res[idx]
|
| 133 |
+
if pred != label:
|
| 134 |
+
print(f"錯誤: sms='{sms}',預測={pred},正確={label}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == '__main__':
|
| 138 |
+
infer20221212()
|
macbert/infer_all.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from torchvision.transforms.functional import to_tensor, to_pil_image
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from transformers import AutoModel
|
| 11 |
+
from transformers import AutoTokenizer, AutoConfig
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader
|
| 16 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import random
|
| 21 |
+
import numpy as np
|
| 22 |
+
from collections import OrderedDict
|
| 23 |
+
from rich import print
|
| 24 |
+
import time
|
| 25 |
+
from glob import glob
|
| 26 |
+
import string
|
| 27 |
+
from torch.optim import AdamW
|
| 28 |
+
from transformers import get_linear_schedule_with_warmup
|
| 29 |
+
from models import get_model
|
| 30 |
+
from dataset import MyDataset
|
| 31 |
+
from utils import save_checkpoint, AverageMeter, ProgressMeter
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_epoch(travel_model, name_model, epoch, dataloader, tokenizer):
|
| 35 |
+
print(f"\n\n=> val")
|
| 36 |
+
data_time = AverageMeter('- data', ':4.3f')
|
| 37 |
+
batch_time = AverageMeter('- batch', ':6.3f')
|
| 38 |
+
progress = ProgressMeter(
|
| 39 |
+
len(dataloader), data_time, batch_time, prefix=f"Epoch: [{epoch}]")
|
| 40 |
+
|
| 41 |
+
end = time.time()
|
| 42 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
| 43 |
+
travel_model.to(device)
|
| 44 |
+
travel_model.eval()
|
| 45 |
+
|
| 46 |
+
name_model.to(device)
|
| 47 |
+
name_model.eval()
|
| 48 |
+
|
| 49 |
+
sms_ids = []
|
| 50 |
+
travel_probs = []
|
| 51 |
+
travel_predictions = []
|
| 52 |
+
name_probs = []
|
| 53 |
+
name_predictions = []
|
| 54 |
+
|
| 55 |
+
for batch_index, data_batch in enumerate(tqdm(dataloader)):
|
| 56 |
+
|
| 57 |
+
context_str_batch, sms_id = data_batch
|
| 58 |
+
sms_ids.append(sms_id.detach().cpu().numpy()[0])
|
| 59 |
+
|
| 60 |
+
# data tokenizer
|
| 61 |
+
context_token_batch = tokenizer(context_str_batch, padding=True, truncation=True, max_length=500, return_tensors='pt')
|
| 62 |
+
|
| 63 |
+
# to gpu
|
| 64 |
+
context_token_batch = {k:v.to(device) for k,v in context_token_batch.items()}
|
| 65 |
+
|
| 66 |
+
# forward travel
|
| 67 |
+
data_input_batch = context_token_batch
|
| 68 |
+
travel_output_batch = travel_model(**data_input_batch)
|
| 69 |
+
name_output_batch = name_model(**data_input_batch)
|
| 70 |
+
|
| 71 |
+
travel_pred_batch = travel_output_batch.softmax(dim=-1)
|
| 72 |
+
travel_probs.append(travel_pred_batch.detach().cpu().numpy()[0][1])
|
| 73 |
+
travel_pred = torch.argmax(travel_pred_batch, dim=-1)
|
| 74 |
+
travel_predictions.extend(travel_pred.cpu().numpy())
|
| 75 |
+
|
| 76 |
+
# forward name
|
| 77 |
+
name_pred_batch = name_output_batch.softmax(dim=-1)
|
| 78 |
+
name_probs.append(name_pred_batch.detach().cpu().numpy()[0][1])
|
| 79 |
+
name_pred = torch.argmax(name_pred_batch, dim=-1)
|
| 80 |
+
name_predictions.extend(name_pred.cpu().numpy())
|
| 81 |
+
|
| 82 |
+
batch_time.update(time.time() - end)
|
| 83 |
+
end = time.time()
|
| 84 |
+
|
| 85 |
+
if batch_index % 50 == 0:
|
| 86 |
+
progress.print(batch_index)
|
| 87 |
+
|
| 88 |
+
return travel_predictions, travel_probs, name_predictions, name_probs, sms_ids
|
| 89 |
+
|
| 90 |
+
def inference():
|
| 91 |
+
travel_checkpoint_file = '/home/jchsiao/Desktop/macbert/checkpoints/saved_checkpoints/travel_checkpoint15_train8000.pth.tar'
|
| 92 |
+
name_checkpoint_file = '/home/jchsiao/Desktop/macbert/checkpoints/saved_checkpoints/name_checkpoint17_train9000.pth.tar'
|
| 93 |
+
ann_file_test = r'/home/jchsiao/Desktop/macbert/dataset/datagame_sms_stage1(in).csv'
|
| 94 |
+
output_file = r'/home/jchsiao/Desktop/macbert/both_macbertBase_20250731_2.csv'
|
| 95 |
+
cache_dir = 'cache'
|
| 96 |
+
|
| 97 |
+
model_cfg = {
|
| 98 |
+
"pretrained_transformers": "hfl/chinese-macbert-base",
|
| 99 |
+
"cache_dir": cache_dir
|
| 100 |
+
}
|
| 101 |
+
# 模型
|
| 102 |
+
travel_model_dict = get_model(model_cfg, mode='base')
|
| 103 |
+
travel_model = travel_model_dict['model']
|
| 104 |
+
|
| 105 |
+
name_model_dict = get_model(model_cfg, mode='base')
|
| 106 |
+
name_model = name_model_dict['model']
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
tokenizer = travel_model_dict['tokenizer']
|
| 110 |
+
# print(model)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
data_loader_cfg = {}
|
| 114 |
+
test_dataset = MyDataset(ann_file_test, data_loader_cfg, mode='test')
|
| 115 |
+
test_loader = DataLoader(test_dataset, batch_size=1, pin_memory=True, shuffle=False)
|
| 116 |
+
|
| 117 |
+
# resume
|
| 118 |
+
assert travel_checkpoint_file is not None and os.path.exists(travel_checkpoint_file)
|
| 119 |
+
assert name_checkpoint_file is not None and os.path.exists(name_checkpoint_file)
|
| 120 |
+
|
| 121 |
+
travel_checkpoint = torch.load(travel_checkpoint_file, map_location='cpu')
|
| 122 |
+
name_checkpoint = torch.load(name_checkpoint_file, map_location='cpu')
|
| 123 |
+
# model.load_state_dict(checkpoint['state_dict'])
|
| 124 |
+
travel_model.load_state_dict({k.replace('module.', ''): v for k, v in travel_checkpoint['state_dict'].items()})
|
| 125 |
+
print(f"=> Resume: loaded travel checkpoint {travel_checkpoint_file} (epoch {travel_checkpoint['epoch']})")
|
| 126 |
+
|
| 127 |
+
name_model.load_state_dict({k.replace('module.', ''): v for k, v in name_checkpoint['state_dict'].items()})
|
| 128 |
+
print(f"=> Resume: loaded name checkpoint {name_checkpoint_file} (epoch {name_checkpoint['epoch']})")
|
| 129 |
+
|
| 130 |
+
#model = model.cuda()
|
| 131 |
+
travel_predictions, travel_probs, name_predictions, name_probs, sms_ids = test_epoch(travel_model, name_model, 1, test_loader, tokenizer)
|
| 132 |
+
with open(output_file, 'w') as f:
|
| 133 |
+
f.write("sms_id,travel_prob,label,name_prob,name_flg\n")
|
| 134 |
+
for travel_pred, travel_prob, name_pred, name_prob, sms_id in zip(travel_predictions, travel_probs, name_predictions, name_probs, sms_ids):
|
| 135 |
+
f.write(f"{sms_id},{travel_prob},{travel_pred},{name_prob},{name_pred}\n")
|
| 136 |
+
print('Output file saved!')
|
| 137 |
+
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
# 讀取val.csv的label
|
| 141 |
+
import csv
|
| 142 |
+
true_labels = []
|
| 143 |
+
with open(ann_file_test, 'r', encoding='utf-8') as f:
|
| 144 |
+
reader = csv.reader(f)
|
| 145 |
+
next(reader) # skip header
|
| 146 |
+
for row in reader:
|
| 147 |
+
true_labels.append(int(row[2]))
|
| 148 |
+
|
| 149 |
+
# 計算confusion matrix
|
| 150 |
+
from sklearn.metrics import confusion_matrix
|
| 151 |
+
cm = confusion_matrix(true_labels, pred_res)
|
| 152 |
+
print('Confusion Matrix:')
|
| 153 |
+
print(cm)
|
| 154 |
+
|
| 155 |
+
# 印出預測錯誤的內容、預測值和正確答案
|
| 156 |
+
with open(ann_file_test, 'r', encoding='utf-8') as f:
|
| 157 |
+
reader = csv.reader(f)
|
| 158 |
+
next(reader) # skip header
|
| 159 |
+
for idx, row in enumerate(reader):
|
| 160 |
+
id, sms, label = int(row[0]), row[1], int(row[2])
|
| 161 |
+
pred = pred_res[idx]
|
| 162 |
+
if pred != label:
|
| 163 |
+
print(f"錯誤: sms_id={id},sms='{sms}',預測={pred},正確={label}")
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
if __name__ == '__main__':
|
| 167 |
+
inference()
|
macbert/infer_name.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from torchvision.transforms.functional import to_tensor, to_pil_image
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from transformers import AutoModel
|
| 11 |
+
from transformers import AutoTokenizer, AutoConfig
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader
|
| 16 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import random
|
| 21 |
+
import numpy as np
|
| 22 |
+
from collections import OrderedDict
|
| 23 |
+
from rich import print
|
| 24 |
+
import time
|
| 25 |
+
import cv2
|
| 26 |
+
from glob import glob
|
| 27 |
+
import string
|
| 28 |
+
from torch.optim import AdamW
|
| 29 |
+
from transformers import get_linear_schedule_with_warmup
|
| 30 |
+
from models import get_model
|
| 31 |
+
from dataset import MyDataset
|
| 32 |
+
from utils import save_checkpoint, AverageMeter, ProgressMeter
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_epoch(model, epoch, dataloader, tokenizer):
|
| 36 |
+
print(f"\n\n=> val")
|
| 37 |
+
data_time = AverageMeter('- data', ':4.3f')
|
| 38 |
+
batch_time = AverageMeter('- batch', ':6.3f')
|
| 39 |
+
progress = ProgressMeter(
|
| 40 |
+
len(dataloader), data_time, batch_time, prefix=f"Epoch: [{epoch}]")
|
| 41 |
+
|
| 42 |
+
end = time.time()
|
| 43 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
| 44 |
+
model.to(device)
|
| 45 |
+
model.eval()
|
| 46 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
| 47 |
+
|
| 48 |
+
sms_ids = []
|
| 49 |
+
probs = []
|
| 50 |
+
predictions = []
|
| 51 |
+
|
| 52 |
+
for batch_index, data_batch in enumerate(tqdm(dataloader)):
|
| 53 |
+
|
| 54 |
+
context_str_batch, sms_id = data_batch
|
| 55 |
+
sms_ids.append(sms_id.detach().cpu().numpy()[0])
|
| 56 |
+
|
| 57 |
+
# data tokenizer
|
| 58 |
+
context_token_batch = tokenizer(context_str_batch, padding=True, truncation=True, max_length=500, return_tensors='pt')
|
| 59 |
+
|
| 60 |
+
# to gpu
|
| 61 |
+
context_token_batch = {k:v.to(device) for k,v in context_token_batch.items()}
|
| 62 |
+
|
| 63 |
+
# forward
|
| 64 |
+
data_input_batch = context_token_batch
|
| 65 |
+
output_batch = model(**data_input_batch)
|
| 66 |
+
|
| 67 |
+
pred_batch = output_batch.softmax(dim=-1)
|
| 68 |
+
probs.append(pred_batch.detach().cpu().numpy()[0][1])
|
| 69 |
+
pred = torch.argmax(pred_batch, dim=-1)
|
| 70 |
+
predictions.extend(pred.cpu().numpy())
|
| 71 |
+
|
| 72 |
+
batch_time.update(time.time() - end)
|
| 73 |
+
end = time.time()
|
| 74 |
+
|
| 75 |
+
if batch_index % 50 == 0:
|
| 76 |
+
progress.print(batch_index)
|
| 77 |
+
|
| 78 |
+
return predictions, probs, sms_ids
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def infer20221212():
|
| 82 |
+
checkpoint_file = '/home/elaine/Desktop/macbert_code/checkpoints_name/checkpoint_epoch017_acc0.9965.pth.tar'
|
| 83 |
+
output_file = r'/home/elaine/Desktop/macbert_code/name_v2_output.csv'
|
| 84 |
+
|
| 85 |
+
cache_dir = '/home/elaine/Desktop/macbert_code/cache'
|
| 86 |
+
ann_file_test = r'/home/elaine/Desktop/macbert_code/dataset/name_test_9000.csv'
|
| 87 |
+
|
| 88 |
+
model_cfg = {
|
| 89 |
+
"pretrained_transformers": "hfl/chinese-macbert-base",
|
| 90 |
+
"cache_dir": cache_dir
|
| 91 |
+
}
|
| 92 |
+
# 模型
|
| 93 |
+
model_dict = get_model(model_cfg, mode='base')
|
| 94 |
+
model = model_dict['model']
|
| 95 |
+
tokenizer = model_dict['tokenizer']
|
| 96 |
+
print(model)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
data_loader_cfg = {}
|
| 100 |
+
test_dataset = MyDataset(ann_file_test, data_loader_cfg, mode='test')
|
| 101 |
+
test_loader = DataLoader(test_dataset, batch_size=1, pin_memory=True, shuffle=False)
|
| 102 |
+
|
| 103 |
+
# resume
|
| 104 |
+
assert checkpoint_file is not None and os.path.exists(checkpoint_file)
|
| 105 |
+
checkpoint = torch.load(checkpoint_file, map_location='cpu')
|
| 106 |
+
# model.load_state_dict(checkpoint['state_dict'])
|
| 107 |
+
model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()})
|
| 108 |
+
print(f"=> Resume: loaded checkpoint {checkpoint_file} (epoch {checkpoint['epoch']})")
|
| 109 |
+
|
| 110 |
+
#model = model.cuda()
|
| 111 |
+
pred_res, probs, sms_ids = test_epoch(model, 1, test_loader, tokenizer)
|
| 112 |
+
with open(output_file, 'w') as f:
|
| 113 |
+
f.write("sms_id,prob,name_flg\n")
|
| 114 |
+
for pred, prob, sms_id in zip(pred_res, probs, sms_ids):
|
| 115 |
+
f.write(f"{sms_id},{prob},{pred}\n")
|
| 116 |
+
|
| 117 |
+
# 讀取val.csv的label
|
| 118 |
+
import csv
|
| 119 |
+
true_labels = []
|
| 120 |
+
with open(ann_file_test, 'r', encoding='utf-8') as f:
|
| 121 |
+
reader = csv.reader(f)
|
| 122 |
+
next(reader) # skip header
|
| 123 |
+
for row in reader:
|
| 124 |
+
true_labels.append(int(row[3]))
|
| 125 |
+
|
| 126 |
+
# 計算confusion matrix
|
| 127 |
+
from sklearn.metrics import confusion_matrix
|
| 128 |
+
cm = confusion_matrix(true_labels, pred_res)
|
| 129 |
+
print('Confusion Matrix:')
|
| 130 |
+
print(cm)
|
| 131 |
+
|
| 132 |
+
# 印出預測錯誤的內容、預測值和正確答案
|
| 133 |
+
with open(ann_file_test, 'r', encoding='utf-8') as f:
|
| 134 |
+
reader = csv.reader(f)
|
| 135 |
+
next(reader) # skip header
|
| 136 |
+
for idx, row in enumerate(reader):
|
| 137 |
+
id, sms, label = int(row[0]), row[1], int(row[3])
|
| 138 |
+
pred = pred_res[idx]
|
| 139 |
+
if pred != label:
|
| 140 |
+
print(f"錯誤: sms_id={id},sms='{sms}',預測={pred},正確={label}")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
if __name__ == '__main__':
|
| 144 |
+
infer20221212()
|
macbert/infer_travel.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from torchvision.transforms.functional import to_tensor, to_pil_image
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from transformers import AutoModel
|
| 11 |
+
from transformers import AutoTokenizer, AutoConfig
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader
|
| 16 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import random
|
| 21 |
+
import numpy as np
|
| 22 |
+
from collections import OrderedDict
|
| 23 |
+
from rich import print
|
| 24 |
+
import time
|
| 25 |
+
import cv2
|
| 26 |
+
from glob import glob
|
| 27 |
+
import string
|
| 28 |
+
from torch.optim import AdamW
|
| 29 |
+
from transformers import get_linear_schedule_with_warmup
|
| 30 |
+
from models import get_model
|
| 31 |
+
from dataset import MyDataset
|
| 32 |
+
from utils import save_checkpoint, AverageMeter, ProgressMeter
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_epoch(model, epoch, dataloader, tokenizer):
|
| 36 |
+
print(f"\n\n=> val")
|
| 37 |
+
data_time = AverageMeter('- data', ':4.3f')
|
| 38 |
+
batch_time = AverageMeter('- batch', ':6.3f')
|
| 39 |
+
progress = ProgressMeter(
|
| 40 |
+
len(dataloader), data_time, batch_time, prefix=f"Epoch: [{epoch}]")
|
| 41 |
+
|
| 42 |
+
end = time.time()
|
| 43 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
| 44 |
+
model.to(device)
|
| 45 |
+
model.eval()
|
| 46 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
| 47 |
+
|
| 48 |
+
sms_ids = []
|
| 49 |
+
probs = []
|
| 50 |
+
predictions = []
|
| 51 |
+
|
| 52 |
+
for batch_index, data_batch in enumerate(tqdm(dataloader)):
|
| 53 |
+
|
| 54 |
+
context_str_batch, sms_id = data_batch
|
| 55 |
+
sms_ids.append(sms_id.detach().cpu().numpy()[0])
|
| 56 |
+
|
| 57 |
+
# data tokenizer
|
| 58 |
+
context_token_batch = tokenizer(context_str_batch, padding=True, truncation=True, max_length=500, return_tensors='pt')
|
| 59 |
+
|
| 60 |
+
# to gpu
|
| 61 |
+
context_token_batch = {k:v.to(device) for k,v in context_token_batch.items()}
|
| 62 |
+
|
| 63 |
+
# forward
|
| 64 |
+
data_input_batch = context_token_batch
|
| 65 |
+
output_batch = model(**data_input_batch)
|
| 66 |
+
|
| 67 |
+
pred_batch = output_batch.softmax(dim=-1)
|
| 68 |
+
probs.append(pred_batch.detach().cpu().numpy()[0][1])
|
| 69 |
+
pred = torch.argmax(pred_batch, dim=-1)
|
| 70 |
+
predictions.extend(pred.cpu().numpy())
|
| 71 |
+
|
| 72 |
+
batch_time.update(time.time() - end)
|
| 73 |
+
end = time.time()
|
| 74 |
+
|
| 75 |
+
if batch_index % 50 == 0:
|
| 76 |
+
progress.print(batch_index)
|
| 77 |
+
|
| 78 |
+
return predictions, probs, sms_ids
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def infer20221212():
|
| 82 |
+
checkpoint_file = '/home/elaine/Desktop/macbert_code/checkpoints_travel/checkpoint_epoch017_acc0.9988.pth.tar'
|
| 83 |
+
output_file = r'/home/elaine/Desktop/macbert_code/travel_v2_output.csv'
|
| 84 |
+
|
| 85 |
+
cache_dir = '/home/elaine/Desktop/macbert_code/cache'
|
| 86 |
+
ann_file_test = r'/home/elaine/Desktop/macbert_code/dataset/travel_test_9000.csv'
|
| 87 |
+
|
| 88 |
+
model_cfg = {
|
| 89 |
+
"pretrained_transformers": "hfl/chinese-macbert-base",
|
| 90 |
+
"cache_dir": cache_dir
|
| 91 |
+
}
|
| 92 |
+
# 模型
|
| 93 |
+
model_dict = get_model(model_cfg, mode='base')
|
| 94 |
+
model = model_dict['model']
|
| 95 |
+
tokenizer = model_dict['tokenizer']
|
| 96 |
+
print(model)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
data_loader_cfg = {}
|
| 100 |
+
test_dataset = MyDataset(ann_file_test, data_loader_cfg, mode='test')
|
| 101 |
+
test_loader = DataLoader(test_dataset, batch_size=1, pin_memory=True, shuffle=False)
|
| 102 |
+
|
| 103 |
+
# resume
|
| 104 |
+
assert checkpoint_file is not None and os.path.exists(checkpoint_file)
|
| 105 |
+
checkpoint = torch.load(checkpoint_file, map_location='cpu')
|
| 106 |
+
# model.load_state_dict(checkpoint['state_dict'])
|
| 107 |
+
model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()})
|
| 108 |
+
print(f"=> Resume: loaded checkpoint {checkpoint_file} (epoch {checkpoint['epoch']})")
|
| 109 |
+
|
| 110 |
+
#model = model.cuda()
|
| 111 |
+
pred_res, probs, sms_ids = test_epoch(model, 1, test_loader, tokenizer)
|
| 112 |
+
with open(output_file, 'w') as f:
|
| 113 |
+
f.write("sms_id,prob,label\n")
|
| 114 |
+
for pred, prob, sms_id in zip(pred_res, probs, sms_ids):
|
| 115 |
+
f.write(f"{sms_id},{prob},{pred}\n")
|
| 116 |
+
|
| 117 |
+
# 讀取val.csv的label
|
| 118 |
+
import csv
|
| 119 |
+
true_labels = []
|
| 120 |
+
with open(ann_file_test, 'r', encoding='utf-8') as f:
|
| 121 |
+
reader = csv.reader(f)
|
| 122 |
+
next(reader) # skip header
|
| 123 |
+
for row in reader:
|
| 124 |
+
true_labels.append(int(row[2]))
|
| 125 |
+
|
| 126 |
+
# 計算confusion matrix
|
| 127 |
+
from sklearn.metrics import confusion_matrix
|
| 128 |
+
cm = confusion_matrix(true_labels, pred_res)
|
| 129 |
+
print('Confusion Matrix:')
|
| 130 |
+
print(cm)
|
| 131 |
+
|
| 132 |
+
# 印出預測錯誤的內容、預測值和正確答案
|
| 133 |
+
with open(ann_file_test, 'r', encoding='utf-8') as f:
|
| 134 |
+
reader = csv.reader(f)
|
| 135 |
+
next(reader) # skip header
|
| 136 |
+
for idx, row in enumerate(reader):
|
| 137 |
+
id, sms, label = int(row[0]), row[1], int(row[2])
|
| 138 |
+
pred = pred_res[idx]
|
| 139 |
+
if pred != label:
|
| 140 |
+
print(f"錯誤: sms_id={id},sms='{sms}',預測={pred},正確={label}")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
if __name__ == '__main__':
|
| 144 |
+
infer20221212()
|
macbert/main.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from torchvision.transforms.functional import to_tensor, to_pil_image
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from transformers import AutoModel
|
| 11 |
+
from transformers import AutoTokenizer, AutoConfig
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader
|
| 16 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
from torch.utils.data import RandomSampler, SequentialSampler
|
| 19 |
+
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
import random
|
| 22 |
+
import numpy as np
|
| 23 |
+
# from collections import OrderedDict
|
| 24 |
+
from rich import print
|
| 25 |
+
import time
|
| 26 |
+
import cv2
|
| 27 |
+
# from glob import glob
|
| 28 |
+
import string
|
| 29 |
+
from torch.optim import AdamW
|
| 30 |
+
from transformers import get_linear_schedule_with_warmup
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
from models import get_model
|
| 34 |
+
from dataset import MyDataset
|
| 35 |
+
from utils import save_checkpoint, AverageMeter, ProgressMeter
|
| 36 |
+
|
| 37 |
+
# if __name__ == '__main__':
|
| 38 |
+
# torch.distributed.init_process_group(backend="nccl")
|
| 39 |
+
# local_rank = torch.distributed.get_rank()
|
| 40 |
+
# torch.cuda.set_device(local_rank)
|
| 41 |
+
# device = torch.device("cuda", local_rank)
|
| 42 |
+
# scaler = GradScaler()
|
| 43 |
+
# else:
|
| 44 |
+
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 45 |
+
|
| 46 |
+
if __name__ == '__main__':
|
| 47 |
+
# 檢查是否為分散式訓練模式(例如 torchrun 啟動)
|
| 48 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 49 |
+
torch.distributed.init_process_group(backend="nccl")
|
| 50 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 51 |
+
torch.cuda.set_device(local_rank)
|
| 52 |
+
device = torch.device("cuda", local_rank)
|
| 53 |
+
print(f"[Distributed] Rank {os.environ['RANK']} using device {local_rank}")
|
| 54 |
+
else:
|
| 55 |
+
# 單機單卡訓練模式
|
| 56 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 57 |
+
print(f"[Single] Using device: {device}")
|
| 58 |
+
|
| 59 |
+
# AMP scaler 建議新版寫法
|
| 60 |
+
scaler = torch.amp.GradScaler(device='cuda' if torch.cuda.is_available() else 'cpu')
|
| 61 |
+
|
| 62 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 63 |
+
print_raw = print
|
| 64 |
+
def print(*info):
|
| 65 |
+
if local_rank == 0:
|
| 66 |
+
print_raw(*info)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def crossentropy(y_true, y_pred):
|
| 72 |
+
return F.cross_entropy(y_pred, y_true, label_smoothing=0.2)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def evaluate(predictions, labels):
|
| 76 |
+
nb_all = len(predictions)
|
| 77 |
+
acc = sum([int(p==l) for p, l in zip(predictions, labels)]) / (nb_all + 1e-8)
|
| 78 |
+
|
| 79 |
+
eval_results = {'acc': acc}
|
| 80 |
+
|
| 81 |
+
return eval_results
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def train_epoch(model, optimizer, epoch, dataloader, sampler, tokenizer, scheduler):
|
| 85 |
+
print(f"\n\n=> train")
|
| 86 |
+
data_time = AverageMeter('- data', ':4.3f')
|
| 87 |
+
batch_time = AverageMeter('- batch', ':6.3f')
|
| 88 |
+
losses = AverageMeter('- loss', ':.4e')
|
| 89 |
+
acces = AverageMeter('- acc', ':.4f')
|
| 90 |
+
progress = ProgressMeter(
|
| 91 |
+
len(dataloader), data_time, batch_time, losses, acces, prefix=f"Epoch: [{epoch}]")
|
| 92 |
+
|
| 93 |
+
end = time.time()
|
| 94 |
+
model.train()
|
| 95 |
+
if hasattr(sampler, "set_epoch"):
|
| 96 |
+
sampler.set_epoch(epoch)
|
| 97 |
+
|
| 98 |
+
predictions, labels = [], []
|
| 99 |
+
|
| 100 |
+
for batch_index, data_batch in enumerate(dataloader):
|
| 101 |
+
optimizer.zero_grad()
|
| 102 |
+
|
| 103 |
+
context_str_batch, target_batch = data_batch
|
| 104 |
+
|
| 105 |
+
# data tokenizer
|
| 106 |
+
context_token_batch = tokenizer(context_str_batch, padding=True, truncation=True, max_length=500, return_tensors='pt')
|
| 107 |
+
|
| 108 |
+
# to gpu
|
| 109 |
+
context_token_batch = {k:v.to(device) for k,v in context_token_batch.items()}
|
| 110 |
+
target_batch = target_batch.to(device)
|
| 111 |
+
|
| 112 |
+
# forward
|
| 113 |
+
data_input_batch = context_token_batch
|
| 114 |
+
output_batch = model(**data_input_batch)
|
| 115 |
+
|
| 116 |
+
pred_batch = output_batch.softmax(dim=-1)
|
| 117 |
+
|
| 118 |
+
loss_batch = crossentropy(target_batch, output_batch)
|
| 119 |
+
loss = torch.mean(loss_batch)
|
| 120 |
+
# print(loss)
|
| 121 |
+
loss.backward()
|
| 122 |
+
optimizer.step()
|
| 123 |
+
if scheduler is not None:
|
| 124 |
+
scheduler.step()
|
| 125 |
+
|
| 126 |
+
loss_value = loss.item()
|
| 127 |
+
losses.update(loss_value, len(target_batch))
|
| 128 |
+
pred = torch.argmax(pred_batch, dim=-1)
|
| 129 |
+
predictions.extend(pred.cpu().numpy())
|
| 130 |
+
labels.extend(target_batch.cpu().numpy())
|
| 131 |
+
acc_batch = (target_batch==pred).sum().cpu().numpy() / (len(target_batch) + 1e-8)
|
| 132 |
+
acces.update(acc_batch, len(target_batch))
|
| 133 |
+
batch_time.update(time.time() - end)
|
| 134 |
+
end = time.time()
|
| 135 |
+
|
| 136 |
+
if batch_index % 50 == 0:
|
| 137 |
+
progress.print(batch_index)
|
| 138 |
+
|
| 139 |
+
results = evaluate(predictions, labels)
|
| 140 |
+
print(results)
|
| 141 |
+
return results
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def val_epoch(model, optimizer, epoch, dataloader, sampler, tokenizer):
|
| 145 |
+
print(f"\n\n=> val")
|
| 146 |
+
data_time = AverageMeter('- data', ':4.3f')
|
| 147 |
+
batch_time = AverageMeter('- batch', ':6.3f')
|
| 148 |
+
losses = AverageMeter('- loss', ':.4e')
|
| 149 |
+
acces = AverageMeter('- acc', ':.4f')
|
| 150 |
+
progress = ProgressMeter(
|
| 151 |
+
len(dataloader), data_time, batch_time, losses, acces, prefix=f"Epoch: [{epoch}]")
|
| 152 |
+
|
| 153 |
+
end = time.time()
|
| 154 |
+
model.train()
|
| 155 |
+
if hasattr(sampler, "set_epoch"):
|
| 156 |
+
sampler.set_epoch(epoch)
|
| 157 |
+
|
| 158 |
+
predictions, labels = [], []
|
| 159 |
+
|
| 160 |
+
for batch_index, data_batch in enumerate(dataloader):
|
| 161 |
+
optimizer.zero_grad()
|
| 162 |
+
|
| 163 |
+
context_str_batch, target_batch = data_batch
|
| 164 |
+
|
| 165 |
+
# data tokenizer
|
| 166 |
+
context_token_batch = tokenizer(context_str_batch, padding=True, truncation=True, max_length=500, return_tensors='pt')
|
| 167 |
+
|
| 168 |
+
# to gpu
|
| 169 |
+
context_token_batch = {k:v.to(device) for k,v in context_token_batch.items()}
|
| 170 |
+
target_batch = target_batch.to(device)
|
| 171 |
+
|
| 172 |
+
# forward
|
| 173 |
+
data_input_batch = context_token_batch
|
| 174 |
+
output_batch = model(**data_input_batch)
|
| 175 |
+
|
| 176 |
+
pred_batch = output_batch.softmax(dim=-1)
|
| 177 |
+
|
| 178 |
+
loss_batch = crossentropy(target_batch, output_batch)
|
| 179 |
+
loss = torch.mean(loss_batch)
|
| 180 |
+
# print(pred_batch)
|
| 181 |
+
# print(target_batch)
|
| 182 |
+
# print(loss)
|
| 183 |
+
|
| 184 |
+
loss_value = loss.item()
|
| 185 |
+
losses.update(loss_value, len(target_batch))
|
| 186 |
+
pred = torch.argmax(pred_batch, dim=-1)
|
| 187 |
+
predictions.extend(pred.cpu().numpy())
|
| 188 |
+
labels.extend(target_batch.cpu().numpy())
|
| 189 |
+
acc_batch = (target_batch==pred).sum().cpu().numpy() / (len(target_batch) + 1e-8)
|
| 190 |
+
acces.update(acc_batch, len(target_batch))
|
| 191 |
+
batch_time.update(time.time() - end)
|
| 192 |
+
end = time.time()
|
| 193 |
+
|
| 194 |
+
if batch_index % 50 == 0:
|
| 195 |
+
progress.print(batch_index)
|
| 196 |
+
|
| 197 |
+
results = evaluate(predictions, labels)
|
| 198 |
+
print(results)
|
| 199 |
+
return results
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def gogogo():
|
| 204 |
+
|
| 205 |
+
output_dir = '/home/elaine/Desktop/macbert_code/checkpoints_travel'
|
| 206 |
+
ann_file_tra = '/home/elaine/Desktop/macbert_code/dataset/travel_train_9000.csv'
|
| 207 |
+
ann_file_val = '/home/elaine/Desktop/macbert_code/dataset/travel_val_9000.csv'
|
| 208 |
+
checkpoint_file = None
|
| 209 |
+
|
| 210 |
+
batch_size = 4
|
| 211 |
+
epochs = 20
|
| 212 |
+
cache_dir = ' /home/elaine/Desktop/macbert_code/cache'
|
| 213 |
+
|
| 214 |
+
model_cfg = {
|
| 215 |
+
"pretrained_transformers": "hfl/chinese-macbert-base",
|
| 216 |
+
"cache_dir": cache_dir
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
# 模型與 tokenizer
|
| 220 |
+
model_dict = get_model(model_cfg, mode='base')
|
| 221 |
+
model = model_dict['model']
|
| 222 |
+
tokenizer = model_dict['tokenizer']
|
| 223 |
+
print(model)
|
| 224 |
+
|
| 225 |
+
# 優化器參數設計
|
| 226 |
+
no_decay = ['bias', 'LayerNorm.weight']
|
| 227 |
+
optimizer_grouped_parameters = [
|
| 228 |
+
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
| 229 |
+
'weight_decay': 0.01},
|
| 230 |
+
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
| 231 |
+
'weight_decay': 0.0}
|
| 232 |
+
]
|
| 233 |
+
optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-8)
|
| 234 |
+
scheduler = None # 如果你需要可啟用
|
| 235 |
+
|
| 236 |
+
# Dataset 與 DataLoader(單卡不使用 DistributedSampler)
|
| 237 |
+
data_loader_cfg = {}
|
| 238 |
+
tra_dataset = MyDataset(ann_file_tra, data_loader_cfg, mode='tra')
|
| 239 |
+
val_dataset = MyDataset(ann_file_val, {}, mode='val')
|
| 240 |
+
|
| 241 |
+
# Sampler(單卡用 RandomSampler / SequentialSampler)
|
| 242 |
+
sampler_tra = RandomSampler(tra_dataset)
|
| 243 |
+
sampler_val = SequentialSampler(val_dataset)
|
| 244 |
+
|
| 245 |
+
tra_loader = DataLoader(tra_dataset, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=True)
|
| 246 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=False)
|
| 247 |
+
|
| 248 |
+
# checkpoint resume
|
| 249 |
+
if checkpoint_file is not None and os.path.exists(checkpoint_file):
|
| 250 |
+
checkpoint = torch.load(checkpoint_file, map_location='cpu')
|
| 251 |
+
init_epoch = checkpoint['epoch'] + 1
|
| 252 |
+
model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()})
|
| 253 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 254 |
+
if torch.cuda.is_available():
|
| 255 |
+
for state in optimizer.state.values():
|
| 256 |
+
for k, v in state.items():
|
| 257 |
+
if torch.is_tensor(v):
|
| 258 |
+
state[k] = v.cuda()
|
| 259 |
+
print(f"=> Resume: loaded checkpoint {checkpoint_file} (epoch {checkpoint['epoch']})")
|
| 260 |
+
else:
|
| 261 |
+
init_epoch = 1
|
| 262 |
+
print(f"=> No checkpoint. ")
|
| 263 |
+
|
| 264 |
+
# 將模型送上 GPU
|
| 265 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 266 |
+
model = model.to(device)
|
| 267 |
+
|
| 268 |
+
# 開始訓練
|
| 269 |
+
acc = 0.
|
| 270 |
+
for epoch in range(init_epoch, epochs + 1):
|
| 271 |
+
results_tra = train_epoch(model, optimizer, epoch, tra_loader, sampler_tra, tokenizer, scheduler)
|
| 272 |
+
results_val = val_epoch(model, optimizer, epoch, val_loader, sampler_val, tokenizer)
|
| 273 |
+
acc_val = results_val['acc']
|
| 274 |
+
if acc_val >= acc:
|
| 275 |
+
acc = acc_val
|
| 276 |
+
save_checkpoint({
|
| 277 |
+
'epoch': epoch,
|
| 278 |
+
'state_dict': model.state_dict(),
|
| 279 |
+
'best_acc': acc,
|
| 280 |
+
'optimizer': optimizer.state_dict(),
|
| 281 |
+
}, outname=f'{output_dir}/checkpoint_epoch{epoch:03d}_acc{acc:.4f}.pth.tar', local_rank=0)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
if __name__ == '__main__':
|
| 285 |
+
gogogo()
|
macbert/models.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn.functional as F
|
| 2 |
+
from transformers import AutoTokenizer, AutoModel, AutoConfig
|
| 3 |
+
from transformers import BertModel, BertTokenizer
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch
|
| 6 |
+
import os
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
import copy
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MyModel(nn.Module):
|
| 12 |
+
def __init__(self, nlp_model, mode):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.nlp_model = nlp_model
|
| 15 |
+
if mode == 'large':
|
| 16 |
+
nb_feature = 1024
|
| 17 |
+
else:
|
| 18 |
+
nb_feature = 768
|
| 19 |
+
|
| 20 |
+
self.cls = nn.Linear(nb_feature, 2)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def forward(self, **data):
|
| 24 |
+
x = self.nlp_model(**data)
|
| 25 |
+
y = self.cls(x['last_hidden_state'][:, 0])
|
| 26 |
+
return y
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_model(cfg, mode='large'):
|
| 30 |
+
# tokenizer = AutoTokenizer.from_pretrained(cfg['common_encoder_pretrained_transformers'], additional_special_tokens=added_token)
|
| 31 |
+
# tokenizer = BertTokenizer.from_pretrained(cfg['common_encoder_pretrained_transformers'], use_fast=True, cache_dir='/share/wangqixun/workspace/bs/tx_mm/code/cache', additional_special_tokens=added_token)
|
| 32 |
+
tokenizer = AutoTokenizer.from_pretrained(cfg['pretrained_transformers'], use_fast=True, cache_dir=cfg['cache_dir'])
|
| 33 |
+
|
| 34 |
+
nlp_model = AutoModel.from_pretrained(cfg['pretrained_transformers'], cache_dir=cfg['cache_dir'])
|
| 35 |
+
model = MyModel(nlp_model=nlp_model, mode=mode)
|
| 36 |
+
|
| 37 |
+
return_dict = {
|
| 38 |
+
'model': model,
|
| 39 |
+
'tokenizer': tokenizer,
|
| 40 |
+
}
|
| 41 |
+
return return_dict
|
| 42 |
+
|
macbert/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
transformers
|
| 4 |
+
rich
|
| 5 |
+
pandas
|
macbert/test.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
ann_file = '/home/jchsiao/Desktop/macbert_code/dataset/travel_train_8000.csv'
|
| 5 |
+
data = np.array(pd.read_csv(ann_file))
|
| 6 |
+
print(data[1, 1])
|
macbert/utils.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
from rich import print
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AverageMeter(object):
|
| 10 |
+
"""Computes and stores the average and current value"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, name, fmt=':f'):
|
| 13 |
+
self.name = name
|
| 14 |
+
self.fmt = fmt
|
| 15 |
+
self.val = 0
|
| 16 |
+
self.avg = 0
|
| 17 |
+
self.sum = 0
|
| 18 |
+
self.count = 0
|
| 19 |
+
self.reset()
|
| 20 |
+
|
| 21 |
+
def reset(self):
|
| 22 |
+
self.val = 0
|
| 23 |
+
self.avg = 0
|
| 24 |
+
self.sum = 0
|
| 25 |
+
self.count = 0
|
| 26 |
+
|
| 27 |
+
def update(self, val, n=1):
|
| 28 |
+
self.val = val
|
| 29 |
+
self.sum += val * n
|
| 30 |
+
self.count += n
|
| 31 |
+
self.avg = self.sum / self.count
|
| 32 |
+
|
| 33 |
+
def __str__(self):
|
| 34 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
| 35 |
+
return fmtstr.format(**self.__dict__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class ProgressMeter(object):
|
| 39 |
+
def __init__(self, num_batches, *meters, prefix=""):
|
| 40 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
| 41 |
+
self.meters = meters
|
| 42 |
+
self.prefix = prefix
|
| 43 |
+
|
| 44 |
+
def print(self, batch):
|
| 45 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
| 46 |
+
entries += [str(meter) for meter in self.meters]
|
| 47 |
+
print('\t'.join(entries))
|
| 48 |
+
|
| 49 |
+
def _get_batch_fmtstr(self, num_batches):
|
| 50 |
+
num_digits = len(str(num_batches // 1))
|
| 51 |
+
fmt = '{:' + str(num_digits) + 'd}'
|
| 52 |
+
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def save_checkpoint(state, outname, local_rank):
|
| 56 |
+
if local_rank == 0:
|
| 57 |
+
# best_acc = state['best_acc']
|
| 58 |
+
# epoch = state['epoch']
|
| 59 |
+
# filename = 'checkpoint_acc_%.4f_epoch_%02d.pth.tar' % (best_acc, epoch)
|
| 60 |
+
filename = outname
|
| 61 |
+
# filename = 'checkpoint_best_%d.pth.tar'
|
| 62 |
+
# filename = os.path.join('output/', filename)
|
| 63 |
+
dir_name = os.path.dirname(filename)
|
| 64 |
+
os.makedirs(dir_name, exist_ok=True)
|
| 65 |
+
torch.save(state, filename)
|
| 66 |
+
|
| 67 |
+
# best_filename = os.path.join(model_dir, 'checkpoint_best_%d.pth.tar' % name_no)
|
| 68 |
+
# best_filename = filename
|
| 69 |
+
# shutil.copyfile(filename, best_filename)
|
| 70 |
+
print('=> Save model to %s' % filename)
|