| # Accuracy: exact-match rate of integer answers extracted from <ANSWER> tags | |
| import json | |
| import re | |
def compute_accuracy_from_file(filepath):
    """Compute exact-match accuracy of integer answers stored in a JSONL file.

    Every non-blank line is parsed as JSON. The target is read from the
    ``reference_answer`` field and the prediction is the text inside the
    first ``<ANSWER>...</ANSWER>`` tag (case-insensitive) found in the
    ``generated_answer`` field. Both are converted to ``int`` and compared.

    A record is skipped (it counts neither as correct nor towards the total)
    when it is not a JSON object, when the reference is missing or empty,
    when the generated answer is not a string or has no answer tag, or when
    either value cannot be converted to an integer.

    Args:
        filepath: Path to a UTF-8 encoded JSONL file.

    Returns:
        A ``(accuracy, total)`` tuple, where ``total`` is the number of
        records that were scored and ``accuracy`` is ``correct / total``
        (``0.0`` when no record could be scored).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    total, correct = 0, 0
    pattern = re.compile(r"<ANSWER>(.*?)</ANSWER>", re.IGNORECASE)
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data = json.loads(line)
            if not isinstance(data, dict):
                continue

            reference = data.get("reference_answer", "")
            generated = data.get("generated_answer", "")
            # Compare against None/"" explicitly so that a numeric reference
            # of 0 is still scored instead of being dropped as falsy.
            if reference is None or reference == "":
                continue
            # A null or non-string generation has no tag to search.
            if not isinstance(generated, str):
                continue
            pred_match = pattern.search(generated)
            if pred_match is None:
                continue

            try:
                # str() lets references stored as JSON numbers be parsed the
                # same way as references stored as strings.
                tgt_val = int(str(reference).strip())
                pred_val = int(pred_match.group(1).strip())
            except ValueError:
                continue  # not convertible to int: skip this record

            total += 1
            if pred_val == tgt_val:
                correct += 1
    if total == 0:
        return 0.0, 0
    return correct / total, total
if __name__ == "__main__":
    import sys

    # Script entry point: score one JSONL file and print a summary line.
    # An optional first command-line argument overrides the default dataset.
    default_path = "/nas/shared/kilab/wangyujia/BIO/ablation/temperature_stability.jsonl"
    filepath = sys.argv[1] if len(sys.argv) > 1 else default_path
    acc, count = compute_accuracy_from_file(filepath)
    print(f"Checked {count} items. Accuracy: {acc*100:.3f}%")
| # Spearman correlation (disabled: this whole section is commented out) | |
| # import json | |
| # import re | |
| # from scipy.stats import spearmanr | |
| # # File path (replace with the path to your JSONL file) | |
| # file_path = '/nas/shared/kilab/wangyujia/BIO/ablation/temperature_stability.jsonl' | |
| # # Regex: extract the value inside <answer>...</answer> | |
| # pattern = re.compile(r"<answer>(.*?)</answer>") | |
| # y_true = [] # reference_answer 中的真实值 | |
| # y_pred = [] # generated_answer 预测值 | |
| # with open(file_path, 'r', encoding='utf-8') as f: | |
| # for line in f: | |
| # line = line.strip() | |
| # print(line) | |
| # if not line: | |
| # continue | |
| # try: | |
| # data = json.loads(line) # 每行都是一个 JSON 对象 | |
| # reference_str = data.get("reference_answer", "").strip() | |
| # generated_str = data.get("generated_answer", "").strip() | |
| # print(reference_str) | |
| # print(generated_str) | |
| # # 提取 generated_answer 中的数值 | |
| # pred_match = pattern.search(generated_str) | |
| # print(pred_match.group(1)) | |
| # if pred_match: | |
| # pred_value = float(pred_match.group(1)) # 提取 <answer> 中的值 | |
| # true_value = float(reference_str) # reference_answer 本身就是数值字符串 | |
| # y_true.append(true_value) | |
| # y_pred.append(pred_value) | |
| # else: | |
| # print(f"未找到 <answer> 标签,跳过:{generated_str}") | |
| # except Exception as e: | |
| # print(f"处理行时出错:{line}") | |
| # print(f"错误:{e}") | |
| # continue | |
| # # Compute the Spearman correlation coefficient | |
| # if len(y_true) > 1: | |
| # print(len(y_true)) | |
| # print(len(y_pred)) | |
| # rho, p_value = spearmanr(y_pred,y_true) | |
| # print(f"有效样本数:{len(y_true)}") | |
| # print(f"Spearman correlation coefficient: {rho:.5f}, p-value: {p_value:.4e}") | |
| # else: | |
| # print("有效数据不足,无法计算 Spearman 相关系数。") | |
| # # F1 score (disabled: this whole section is commented out) | |
| # import re | |
| # import json | |
| # def extract_numbers_from_generated_answer(s): | |
| # """从<answer>...</answer>中提取纯数字""" | |
| # match = re.search(r'<answer>(.*?)</answer>', s) | |
| # if match: | |
| # numbers_str = match.group(1) | |
| # numbers = re.findall(r'\d+', numbers_str) | |
| # return set(map(int, numbers)) | |
| # else: | |
| # return set() | |
| # def extract_numbers_from_reference_answer(s): | |
| # """从参考答案字符串中提取数字,忽略非数字字符""" | |
| # numbers = re.findall(r'\d+', s) | |
| # return set(map(int, numbers)) | |
| # def calculate_f1(pred_set, target_set): | |
| # if not pred_set and not target_set: | |
| # return 1.0 | |
| # if not pred_set or not target_set: | |
| # return 0.0 | |
| # tp = len(pred_set & target_set) | |
| # fp = len(pred_set - target_set) | |
| # fn = len(target_set - pred_set) | |
| # precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 | |
| # recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| # if precision + recall == 0: | |
| # return 0.0 | |
| # return 2 * precision * recall / (precision + recall) | |
| # # File path | |
| # file_path = '/nas/shared/kilab/wangyujia/BIO/ablation/enzyme_commission_number.jsonl' | |
| # all_f1_scores = [] | |
| # with open(file_path, 'r', encoding='utf-8') as f: | |
| # for line in f: | |
| # try: | |
| # data = json.loads(line) | |
| # reference = data.get('reference_answer', '') | |
| # generated = data.get('generated_answer', '') | |
| # target_set = extract_numbers_from_reference_answer(reference) | |
| # pred_set = extract_numbers_from_generated_answer(generated) | |
| # f1 = calculate_f1(pred_set, target_set) | |
| # all_f1_scores.append(f1) | |
| # except Exception as e: | |
| # print(f"处理行时出错:{line.strip()}") | |
| # print(f"错误:{e}") | |
| # # Print the results | |
| # if all_f1_scores: | |
| # avg_f1 = sum(all_f1_scores) / len(all_f1_scores) | |
| # print(f"总样本数:{len(all_f1_scores)}") | |
| # print(f"平均 F1 分数:{avg_f1:.5f}") | |
| # else: | |
| # print("没有有效的F1分数可供计算") | |