# coding: UTF-8

import spaces
import os
import time
import torch
import argparse
import numpy as np
from tqdm import tqdm
from train import train
from config import Config
from preprocess import DataProcessor, get_time_dif
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification

from ebart import PegasusSummarizer


parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
#parser.add_argument("--pretrained_bert_dir", type=str, default="./pretrained_bert", help="pretrained bert model path")
parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
parser.add_argument("--input_file", type=str, default="./data_12345/test.txt", help="input file to be predicted")
args = parser.parse_args()

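# Use the "bert-base-chinese" checkpoint from the Hugging Face hub instead of a local pretrained_bert directory.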
args.pretrained_bert_dir = "bert-base-chinese"

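# Fix all relevant RNG seeds (NumPy, PyTorch CPU and CUDA) so runs are reproducible.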
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

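# spaces.GPU requests a GPU for this call when the app runs as a Hugging Face Space (ZeroGPU).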
@spaces.GPU
def main():
    set_seed(args.seed)
    config = Config(args.data_dir)

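    # Load the Chinese BERT tokenizer and a sequence-classification head sized to config.num_labels.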
    tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
    bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
    model = BertForSequenceClassification.from_pretrained(args.pretrained_bert_dir,
        config=bert_config
    )
    model.to(config.device)

    # Pegasus-based summarizer passed to DataProcessor for preprocessing the input text
    summarizerModel = PegasusSummarizer()

    if args.mode == "train":
        print("loading data...")
        start_time = time.time()
        train_iterator = DataProcessor(config.train_file, config.device, summarizerModel, tokenizer, config.batch_size, config.max_seq_len, args.seed)
        dev_iterator = DataProcessor(config.dev_file, config.device, summarizerModel, tokenizer, config.batch_size, config.max_seq_len, args.seed)
        time_dif = get_time_dif(start_time)
        print("time usage:", time_dif)

        # train
        train(model, config, train_iterator, dev_iterator)
    
    elif args.mode == "demo":
        model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
        model.eval()
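        # Interactive demo: classify sentences typed at the prompt until the user chooses to stop.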
        while True:
            sentence = input("please input text:\n")
            inputs = tokenizer(
                sentence, 
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)
            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs[0]
                label = torch.max(logits.data, 1)[1].tolist()
                print("Classification result:" + config.label_list[label[0]])
            flag = str(input("continue? (y/n):"))
            if flag == "Y" or flag == "y":
                continue
            else:
                break
    else:
        model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
        model.eval()

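        # Batch prediction: read one non-empty sentence per line from the input file.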
        text = []
        with open(args.input_file, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                sentence = line.strip()
                if not sentence:
                    continue
                text.append(sentence)

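        # Classify the sentences in mini-batches of config.batch_size and print one label per line.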
        num_samples = len(text)
        num_batches = (num_samples - 1) // config.batch_size + 1
        for i in range(num_batches):
            start = i * config.batch_size
            end = min(num_samples, (i + 1) * config.batch_size)
            inputs = tokenizer.batch_encode_plus(
                text[start: end],
                padding=True,
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)

            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs[0]

            preds = torch.max(logits.data, 1)[1].tolist()
            labels = [config.label_list[p] for p in preds]
            for j in range(start, end):
                print("%s\t%s" % (text[j], labels[j - start]))
                

if __name__ == "__main__":
    main()