"""Fine-tune ``bert-base-chinese`` on the Toutiao category dataset.

The data file (``toutiao_cat_data.txt``) holds one record per line with five
fields separated by the literal string ``_!_``: id, category code, category
name, text and keywords. Only the text and the category code are used.
"""
import pandas as pd

DATA_PATH = "toutiao_cat_data.txt"
COLUMN_NAMES = ["id", "type", "type_text", "text", "keywords"]
LABEL_OFFSET = 100  # raw category codes are shifted down by this amount
NUM_LABELS = 18  # classifier output size; shifted codes must lie in [0, NUM_LABELS)
TEST_SIZE = 1000  # number of shuffled rows held out for evaluation
RANDOM_SEED = 42  # fixed shuffle seed so the train/test split is reproducible


def load_dataset(path=DATA_PATH, label_offset=LABEL_OFFSET):
    """Read the ``_!_``-separated data file into ``text``/``labels`` columns.

    Args:
        path: Location of the UTF-8 encoded data file.
        label_offset: Amount subtracted from the raw ``type`` code to obtain
            the integer class id used for training.

    Returns:
        A DataFrame with a ``text`` column and an integer ``labels`` column.
        Rows missing either the text or the category code are dropped.
    """
    # A multi-character separator is only supported by the python engine;
    # requesting it explicitly avoids pandas' silent fallback (and its
    # ParserWarning). That engine ignores ``lineterminator``, so it is omitted.
    raw = pd.read_csv(
        path,
        sep="_!_",
        engine="python",
        encoding="utf8",
        names=COLUMN_NAMES,
    )
    raw = raw.dropna(subset=["text", "type"])
    # simpletransformers looks for columns named exactly "text" and "labels";
    # any other names trigger a positional fallback with a warning.
    df = raw.loc[:, ["text", "type"]].rename(columns={"type": "labels"})
    df["labels"] = df["labels"].astype(int) - label_offset
    return df


def split_dataset(df, test_size=TEST_SIZE, random_state=None):
    """Shuffle *df* and hold out its last *test_size* rows as a test set.

    Args:
        df: The full dataset.
        test_size: Number of rows to put in the test set.
        random_state: Seed for the shuffle; ``None`` gives a random split.

    Returns:
        A ``(train_df, test_df)`` tuple.

    Raises:
        ValueError: If *test_size* is not strictly between 0 and ``len(df)``
            (either split would otherwise be empty).
    """
    if not 0 < test_size < len(df):
        raise ValueError(f"test_size must be in (0, {len(df)}), got {test_size}")
    shuffled = df.sample(frac=1, random_state=random_state)
    # .iloc makes the slice explicitly positional; plain df[:-n] on the
    # shuffled integer index is ambiguous between positions and labels.
    return shuffled.iloc[:-test_size], shuffled.iloc[-test_size:]


def main():
    """Train, evaluate and query the classifier end to end."""
    # Heavy third-party imports are deferred so the data helpers above can be
    # imported (e.g. by tests) without the training stack being loaded.
    import torch
    from sklearn.metrics import accuracy_score
    from simpletransformers.classification import ClassificationModel

    df = load_dataset()
    lo, hi = int(df["labels"].min()), int(df["labels"].max())
    if lo < 0 or hi >= NUM_LABELS:
        raise ValueError(
            f"labels span [{lo}, {hi}], outside [0, {NUM_LABELS}); "
            "adjust LABEL_OFFSET or NUM_LABELS"
        )
    train_df, test_df = split_dataset(df, random_state=RANDOM_SEED)

    model = ClassificationModel(
        "bert",
        "bert-base-chinese",
        num_labels=NUM_LABELS,
        args={"reprocess_input_data": True, "overwrite_output_dir": True},
        # Fall back to CPU instead of raising when no GPU is available.
        use_cuda=torch.cuda.is_available(),
    )

    model.train_model(train_df)

    # eval_model returns (metrics, model_outputs, wrong_predictions).
    metrics, _, _ = model.eval_model(test_df, acc=accuracy_score)
    print(metrics)

    # predict returns (predictions, raw_model_outputs).
    predictions, _ = model.predict(["M2处理器IPad mini7值得期待吗?"])
    print(predictions)


# Guard the entry point so importing this module (or a spawned worker process
# re-importing it) does not start training.
if __name__ == "__main__":
    main()