# coding: UTF-8
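# Entry point for BERT-based Chinese text classification.
# Example invocations (script name assumed, arguments as defined below):
#   python main.py --mode train --data_dir ./data
#   python main.py --mode predict --input_file ./data_12345/test.txt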
import spaces
import os
import time
import torch
import argparse
import numpy as np
from tqdm import tqdm
from train import train
from config import Config
from preprocess import DataProcessor, get_time_dif
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
from ebart import PegasusSummarizer

parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
#parser.add_argument("--pretrained_bert_dir", type=str, default="./pretrained_bert", help="pretrained bert model path")
parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
parser.add_argument("--input_file", type=str, default="./data_12345/test.txt", help="input file to be predicted")
args = parser.parse_args()
# Always use the Hugging Face "bert-base-chinese" checkpoint as the pretrained model.
args.pretrained_bert_dir = "bert-base-chinese"

def set_seed(seed):
    """Seed all RNGs so runs are reproducible."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


@spaces.GPU
def main():
    set_seed(args.seed)
    config = Config(args.data_dir)

    # Load tokenizer and classification model from the pretrained checkpoint.
    tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
    bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
    model = BertForSequenceClassification.from_pretrained(
        args.pretrained_bert_dir,
        config=bert_config,
    )
    model.to(config.device)

    # Summarizer passed to DataProcessor for preprocessing.
    summarizerModel = PegasusSummarizer()

    if args.mode == "train":
        print("loading data...")
        start_time = time.time()
        train_iterator = DataProcessor(config.train_file, config.device, summarizerModel,
                                       tokenizer, config.batch_size, config.max_seq_len, args.seed)
        dev_iterator = DataProcessor(config.dev_file, config.device, summarizerModel,
                                     tokenizer, config.batch_size, config.max_seq_len, args.seed)
        time_dif = get_time_dif(start_time)
        print("time usage:", time_dif)
        # Train and evaluate on the dev set.
        train(model, config, train_iterator, dev_iterator)
    elif args.mode == "demo":
        # Interactive mode: classify sentences typed on stdin.
        model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
        model.eval()
        while True:
            sentence = input("please input text:\n")
            inputs = tokenizer(
                sentence,
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs[0]
            label = torch.max(logits.data, 1)[1].tolist()
            print("Classification result: " + config.label_list[label[0]])
            flag = input("continue? (y/n): ")
            if flag.lower() != "y":
                break
    else:
        # Batch prediction: read sentences from the input file and print one label per line.
        model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
        model.eval()
        text = []
        with open(args.input_file, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                sentence = line.strip()
                if not sentence:
                    continue
                text.append(sentence)
        num_samples = len(text)
        num_batches = (num_samples - 1) // config.batch_size + 1
        for i in range(num_batches):
            start = i * config.batch_size
            end = min(num_samples, (i + 1) * config.batch_size)
            inputs = tokenizer.batch_encode_plus(
                text[start:end],
                padding=True,
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs[0]
            preds = torch.max(logits.data, 1)[1].tolist()
            labels = [config.label_list[p] for p in preds]
            for j in range(start, end):
                print("%s\t%s" % (text[j], labels[j - start]))


if __name__ == "__main__":
    main()