| | import pandas as pd |
| | import numpy as np |
| | import torch |
| | from sklearn.metrics import f1_score, classification_report |
| | from datasets import Dataset |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer |
| |
|
| | |
# Locations of the fine-tuned checkpoint and the labelled validation CSV.
MODEL_PATH = "./final_model_best_macro"
VALID_FILE = "/tmp/home/wzh/file/val_data.csv"

print(f"正在加载模型: {MODEL_PATH} ...")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the classifier and its tokenizer from the same checkpoint directory,
# then move the model onto the selected device.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
| |
|
print(f"正在加载验证集: {VALID_FILE} ...")
val_df = pd.read_csv(VALID_FILE)

# Encode the string labels as class ids: "real" -> 0, "fake" -> 1.
label_map = {"real": 0, "fake": 1}
encoded_labels = val_df['label'].map(label_map)

# Series.map yields NaN for every value that is not a key of label_map, which
# would silently turn the label column into floats. Fail fast and name the
# offending values instead of letting the problem surface later.
unmapped = encoded_labels.isna()
if unmapped.any():
    bad_values = sorted({repr(v) for v in val_df.loc[unmapped, 'label'].unique()})
    raise ValueError(
        f"Unexpected label values in {VALID_FILE}: {', '.join(bad_values)}; "
        f"expected one of {sorted(label_map)}"
    )

val_df['label'] = encoded_labels.astype("int64")
| |
|
| | |
val_dataset = Dataset.from_pandas(val_df)


def tokenize(examples):
    """Tokenize a batch of examples taken from their "text" field.

    Sequences are truncated to, and padded up to, a fixed length of 512.
    """
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)


val_dataset = val_dataset.map(tokenize, batched=True)

# Drop the columns the model does not consume. The check runs against the
# Dataset's own columns: "__index_level_0__" is created by
# Dataset.from_pandas when it preserves the DataFrame index, so it has to be
# looked up on the Dataset rather than on val_df. Filtering also avoids an
# error when the CSV has no "id" column.
droppable = [
    column
    for column in ("id", "text", "__index_level_0__")
    if column in val_dataset.column_names
]
if droppable:
    val_dataset = val_dataset.remove_columns(droppable)

# Rename the target column from "label" to "labels" before prediction.
val_dataset = val_dataset.rename_column("label", "labels")
| |
|
| | |
print("正在进行预测...")

# A Trainer built with default arguments is used here only to run prediction
# over the validation dataset.
trainer = Trainer(model=model)
predictions = trainer.predict(val_dataset)

# Turn the raw logits into per-class probabilities along the last axis.
logits = torch.tensor(predictions.predictions)
probs = torch.softmax(logits, dim=-1)

# Ground-truth ids as returned by the Trainer, and the probability of class
# index 1 for every example, as a NumPy array.
true_labels = predictions.label_ids
fake_probs = probs.numpy()[:, 1]
| |
|
| | |
print("\n开始搜索最佳阈值 (Threshold Search)...")

# Defaults that are kept if no candidate achieves a Macro-F1 above zero.
best_f1 = 0
best_thresh = 0.5

# Score every candidate cut-off from 0.10 to 0.90 in steps of 0.01: an example
# is predicted as class 1 when its probability is strictly above the cut-off.
candidate_thresholds = np.arange(0.1, 0.91, 0.01)
macro_f1_scores = [
    f1_score(true_labels, (fake_probs > cutoff).astype(int), average='macro')
    for cutoff in candidate_thresholds
]

# np.argmax returns the first maximum, so ties resolve to the lowest
# threshold, matching a strict ">" comparison in a sequential scan.
top_index = int(np.argmax(macro_f1_scores))
if macro_f1_scores[top_index] > best_f1:
    best_f1 = macro_f1_scores[top_index]
    best_thresh = candidate_thresholds[top_index]
| |
|
# Summary banner: Macro-F1 at the default 0.5 cut-off versus the best
# threshold found by the search.
separator = "=" * 40
print("\n" + separator)
print("🎉 搜索完成!")
default_f1 = f1_score(true_labels, (fake_probs > 0.5).astype(int), average='macro')
print(f"默认阈值 (0.50) Macro-F1: {default_f1:.4f}")
print(f"🏆 最佳阈值: {best_thresh:.2f}")
print(f"🚀 优化后 Macro-F1: {best_f1:.4f}")
print(separator)

# Detailed per-class report at the selected threshold.
final_preds = (fake_probs > best_thresh).astype(int)
print("\n最佳阈值下的详细报告:")
# labels=[0, 1] pins both classes explicitly. Without it, classification_report
# raises when only one class occurs in the labels and predictions, because the
# two target_names no longer match the number of classes found.
print(
    classification_report(
        true_labels,
        final_preds,
        labels=[0, 1],
        target_names=["real", "fake"],
    )
)