# Page header captured with this file: Spaces status "Runtime error";
# file size 5,184 bytes; revision 8274db5.
# (The viewer's line-number gutter, 1-153, was dropped; it is not part of the script.)
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, confusion_matrix, roc_auc_score
from sklearn.preprocessing import label_binarize
import json
import sys
sys.path.append('.')
import pmc_clip
from training.params import parse_args
from training.data import PmcDataset
from training.fusion_method import convert_model_to_cls
# Label mapping: class name -> integer class index (index = position below).
LABEL_MAP = {
    class_name: class_index
    for class_index, class_name in enumerate((
        "Irritant dermatitis",
        "Allergic contact dermatitis",
        "Mechanical injury",
        "Folliculitis",
        "Fungal infection",
        "Skin hyperplasia",
        "Parastomal varices",
        "Urate crystals",
        "Cancerous metastasis",
        "Pyoderma gangrenosum",
        "Normal",
    ))
}
# Inverse lookup: integer class index -> class name.
REVERSE_LABEL_MAP = dict(zip(LABEL_MAP.values(), LABEL_MAP.keys()))
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key.

    Only a true prefix is stripped; a ``module.`` substring appearing later in
    a key is left untouched.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` and return ``(probs, labels)`` as numpy arrays.

    ``probs`` holds the softmax (over dim 1) of the model outputs for every
    sample; ``labels`` holds the ``cls_label`` entry of every batch.
    """
    all_probs = []
    all_labels = []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(loader):
            labels = batch["cls_label"].to(device)
            # Forward pass: the whole batch dict is handed to the model.
            outputs = model(batch)
            probs = torch.softmax(outputs, dim=1)
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return np.array(all_probs), np.array(all_labels)


def _compute_overall_roc(labels, probs, num_classes):
    """Return ``(fpr, tpr, auc)`` for the overall ROC curve.

    For two classes the score of class 1 is evaluated against the raw labels.
    Otherwise the labels are one-hot encoded and the curve is micro-averaged
    over the flattened (sample, class) pairs.

    Raises:
        ValueError: propagated from scikit-learn when the inputs cannot be
            scored (for example, no samples).
    """
    if num_classes == 2:
        # label_binarize yields a single column for two classes, so indexing
        # column 1 of its output fails; use the integer labels directly.
        positive_scores = probs[:, 1]
        fpr, tpr, _ = roc_curve(labels, positive_scores)
        auc_value = roc_auc_score(labels, positive_scores)
    else:
        y_true_bin = label_binarize(labels, classes=list(range(num_classes)))
        fpr, tpr, _ = roc_curve(y_true_bin.ravel(), probs.ravel())
        auc_value = roc_auc_score(y_true_bin, probs, multi_class='ovr', average='micro')
    return fpr, tpr, float(auc_value)


def _plot_overall_roc(fpr, tpr, auc_value, image_path):
    """Draw the overall ROC curve with a chance diagonal and save it to ``image_path``."""
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f"Overall (AUC = {auc_value:.4f})")
    plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
    plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12)
    plt.title('Overall ROC Curve', fontsize=14)
    plt.legend(loc="lower right", fontsize=10)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(image_path, dpi=300, bbox_inches='tight')
    # Release the figure so it does not stay registered with pyplot.
    plt.close()


def main():
    """Evaluate the classification checkpoint and save the overall ROC curve.

    Steps: build the model from the hard-coded configuration, load the
    checkpoint weights, run inference over the test set, compute the overall
    ROC curve and AUC, then write ``overall_roc_data.json`` and
    ``overall_roc_curve.png`` into the output directory.

    If the ROC/AUC computation fails with a ``ValueError`` or ``IndexError``,
    the error is printed and the function returns without writing any output.
    """
    # Create the output directory.
    output_dir = './evaluation_results_pmc_clip_cat'
    os.makedirs(output_dir, exist_ok=True)

    # Select the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Model configuration.
    model_path = "logs/0321-Stoma-clip-train-cls/2025_03_21-23_45_18-model_RN50_fusion4-lr_1e-05-b_256-j_8-p_amp/checkpoints/epoch_150.pt"
    model_name = "RN50_fusion4"
    args = parse_args()
    args.model = model_name
    args.pretrained = model_path
    args.device = device
    args.mlm = True
    args.train_data = "data/single_symptoms_test.jsonl"
    args.image_dir = "./data/cleaned_data"
    args.csv_img_key = "image"
    args.csv_caption_key = "caption"
    args.context_length = 77
    args.num_classes = len(LABEL_MAP)
    args.output_dir = output_dir

    # Build the model and the preprocessing transform.
    model, _, preprocess = pmc_clip.create_model_and_transforms(args)
    model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='concat')

    # Load the checkpoint weights.
    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    print(model.load_state_dict(_strip_module_prefix(checkpoint['state_dict'])))
    model.to(device=device)

    # Prepare the dataset.
    dataset = PmcDataset(
        args,
        input_filename=args.train_data,
        transforms=preprocess,
        is_train=False,
    )
    test_loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)
    print(f"测试集样本数: {len(dataset)}")

    # Collect predictions.
    print("开始评估...")
    all_probs, all_labels = _collect_predictions(model, test_loader, device)

    # Overall ROC / AUC.
    try:
        overall_fpr, overall_tpr, overall_auc = _compute_overall_roc(
            all_labels, all_probs, args.num_classes
        )
    except (ValueError, IndexError) as e:
        print(f"计算整体AUC时出错: {e}")
        return

    # Save the overall ROC curve data.
    roc_data = {
        "fpr": overall_fpr.tolist(),
        "tpr": overall_tpr.tolist(),
        "auc": overall_auc,
    }
    roc_file = os.path.join(output_dir, "overall_roc_data.json")
    with open(roc_file, "w", encoding="utf-8") as f:
        json.dump(roc_data, f)
    print(f"整体ROC曲线数据已保存至: {roc_file}")

    # Plot the ROC curve.
    roc_image = os.path.join(output_dir, 'overall_roc_curve.png')
    _plot_overall_roc(overall_fpr, overall_tpr, overall_auc, roc_image)
    print(f"整体ROC曲线图已保存至: {roc_image}")
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()