File size: 806 Bytes
0ccd07a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import pandas as pd

# read dataset
# Each line holds five fields separated by the literal string "_!_".
# pandas treats a multi-character `sep` as a regular expression and parses
# it with its python engine.
df = pd.read_csv('toutiao_cat_data.txt',
            sep='_!_', lineterminator='\n',
            encoding='utf8',
            names=["id", "type", "type_text", "text", "keywords"])
# Keep only the input text and its numeric category code.
# The code column is renamed to "labels" because simpletransformers looks for
# "text"/"labels" headers. Without them it silently falls back to treating the
# first column as text and the second as the label.
df = df[["text", "type"]].rename(columns={"type": "labels"})
# Shift the category codes down by 100 to get small class indices for the
# classifier (assumes every code is >= 100 -- TODO confirm against the data).
df["labels"] = df["labels"] - 100

# split dataset
# Shuffle every row with a fixed seed so the train/test split is reproducible
# between runs, then hold out the last TEST_SIZE shuffled rows for evaluation.
# With TEST_SIZE rows or fewer, the training split would be empty.
SEED = 42
TEST_SIZE = 1000
df = df.sample(frac=1, random_state=SEED)
# .iloc makes the slicing explicitly positional; after shuffling, the index
# labels are no longer in order.
train_df, test_df = df.iloc[:-TEST_SIZE], df.iloc[-TEST_SIZE:]

# create model
import torch
from simpletransformers.classification import ClassificationModel
model = ClassificationModel(
    "bert",
    "bert-base-chinese",  # pretrained checkpoint to fine-tune
    # Must be larger than the largest class index produced by the
    # "code - 100" shift applied when loading the data.
    num_labels=18,
    # Re-tokenize the inputs instead of reusing cached features, and allow
    # overwriting a previous run's output directory.
    args={"reprocess_input_data": True, "overwrite_output_dir": True},
    # ClassificationModel defaults to use_cuda=True and raises ValueError when
    # CUDA is unavailable; fall back to CPU instead of crashing.
    use_cuda=torch.cuda.is_available(),
)

# train
# Fine-tune the classifier on the training split, which supplies the texts and
# their integer class labels.
model.train_model(train_df=train_df)

# eval
# Import the submodule explicitly. A bare `import sklearn` does not load
# sklearn.metrics, so the attribute access below would only work if another
# library happened to import it first.
import sklearn.metrics
# eval_model returns (result, model_outputs, wrong_predictions). The extra
# keyword metric is computed on the held-out split and stored in result[0]
# under the "acc" key.
result = model.eval_model(test_df, acc=sklearn.metrics.accuracy_score)
# A bare `result[0]` expression only displays in a notebook. Print it so the
# metrics are visible when this runs as a script.
print(result[0])

# predict
# predict() returns (predictions, raw_outputs). Capture and print them instead
# of discarding the result. Adding 100 undoes the "code - 100" shift applied
# when loading, giving back the dataset's original category code.
predictions, raw_outputs = model.predict(["M2处理器IPad mini7值得期待吗?"])
print(predictions, raw_outputs)
print([int(p) + 100 for p in predictions])