# DAminoMuta / vis / dataset_vis.py
# Uploaded by auralray using huggingface_hub (commit acbef3a).
import sys
sys.path.append("..")
from dataset import PeptidePairDataset
from torch.utils.data import Subset
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import numpy as np
def extract_label(dataset: PeptidePairDataset):
    """Return every label of *dataset* as a list, in iteration order.

    Each dataset item is unpacked as a ``(features, label)`` pair; the
    features part is discarded.
    """
    return [label for _, label in dataset]
def extract_labels_cls(dataset: PeptidePairDataset):
    """Map each regression label of *dataset* onto a three-way class id.

    Class ids:
        2 -- label >= 0.5
        1 -- label <= -0.5
        0 -- everything else (including values strictly between -0.5 and 0.5)

    Each dataset item is unpacked as a ``(features, label)`` pair. The class
    ids are returned as a list in dataset iteration order.
    """
    def _class_of(value):
        # The upper threshold is checked first, mirroring the original order.
        if value >= 0.5:
            return 2
        if value <= -0.5:
            return 1
        return 0

    return [_class_of(label) for _, label in dataset]
# Regression-task datasets, both built with pad_length=30: the 'train' split
# and the 'r2_case' split. one_way=True is only passed for the R2 split
# (its exact meaning is defined inside PeptidePairDataset -- TODO confirm).
all_set = PeptidePairDataset(
    task='reg',
    mode='train',
    pad_length=30,
)
test_set = PeptidePairDataset(
    task='reg',
    mode='r2_case',
    pad_length=30,
    one_way=True,
)
# kf = KFold(n_splits=5, shuffle=True, random_state=42)
# train_idx, val_idx = next(kf.split(all_set))
# train_set = Subset(all_set, train_idx)
# valid_set = Subset(all_set, val_idx)
# Histogram bin edges for the log2 (MIC ratio) labels: 0.5-wide bins spanning
# [-6.5, 6.5] (27 edges -> 26 bins). linspace fixes both endpoints and the
# edge count explicitly instead of relying on a float stop (6.6) placed just
# past the last wanted edge.
bins = np.linspace(-6.5, 6.5, 27)
def plot_gradient_bar_histogram(ax, data, bins, cmap_name='Blues'):
    """Draw a histogram with gradient-coloured bars and per-bar count labels.

    Every bar gets its own shade from the colormap *cmap_name*, and each
    non-empty bar is annotated with its count just above its top.

    Args:
        ax: Matplotlib Axes to draw on.
        data: Values to bin (anything ``np.histogram`` accepts).
        bins: Bin specification forwarded to ``np.histogram`` (a bin count or
            an array of edges). Non-uniform edges are drawn correctly because
            each bar uses its own bin's width.
        cmap_name: Name of the Matplotlib colormap used for the gradient.

    Returns:
        The ``(counts, bin_edges)`` pair produced by ``np.histogram``. Values
        outside the outermost edges are not counted.
    """
    counts, bin_edges = np.histogram(data, bins=bins)
    n_bins = len(counts)
    cmap = plt.get_cmap(cmap_name)
    # One width per bin rather than the first bin's width for every bar, so
    # non-uniform bin edges are not drawn overlapping or with gaps.
    widths = np.diff(bin_edges)

    for i, (left, width, count) in enumerate(zip(bin_edges[:-1], widths, counts)):
        # Map the bar index linearly onto [0.3, 0.8] of the colormap so the
        # gradient skips both extremes; a lone bar uses the midpoint 0.5.
        ratio = (i / (n_bins - 1)) * 0.5 + 0.3 if n_bins > 1 else 0.5
        ax.bar(left, count, width=width, align='edge',
               color=cmap(ratio), zorder=2)
        # Only label bars that actually contain samples.
        if count != 0:
            ax.text(
                left + width / 2, count, f'{count}',
                ha='center', va='bottom', fontweight='bold', fontsize=9,
            )
    return counts, bin_edges
def plot_gradient_bar_histogram_cls(ax, data, cmap_name='Blues'):
    """Draw a bar chart of class counts for the fixed classes 0, 1 and 2.

    Each bar gets its own shade from *cmap_name*, non-empty bars are
    annotated with their count, and the x-ticks are labelled
    'Unchanged' (0), 'Lower' (1) and 'Higher' (2) -- the same encoding that
    ``extract_labels_cls`` produces.

    Args:
        ax: Matplotlib Axes to draw on.
        data: Sequence of integer class ids (a list or a NumPy array). Ids
            other than 0, 1 and 2 are not counted.
        cmap_name: Name of the Matplotlib colormap used for the gradient.

    Returns:
        NumPy array with the count of each class, in the order [0, 1, 2].
    """
    classes = np.array([0, 1, 2])
    # Vectorised comparison works for lists and NumPy arrays alike; the
    # previous ``data.count(cls)`` only worked for Python lists.
    labels = np.asarray(data)
    counts = np.array([np.count_nonzero(labels == cls) for cls in classes])
    num_classes = len(classes)
    width = 0.6
    cmap = plt.get_cmap(cmap_name)

    for i, (cls, count) in enumerate(zip(classes, counts)):
        # Same gradient mapping as the regression histogram: index -> [0.3, 0.8].
        ratio = (i / (num_classes - 1)) * 0.5 + 0.3 if num_classes > 1 else 0.5
        ax.bar(cls, count, width=width, color=cmap(ratio), zorder=2)
        if count != 0:
            ax.text(cls, count, f'{count}', ha='center', va='bottom', fontweight='bold')

    ax.set_xticks(classes)
    ax.set_xticklabels(['Unchanged', 'Lower', 'Higher'])
    return counts
plt.style.use('seaborn-v0_8-whitegrid')

# One row with two panels: the regression-label histogram of the training
# split (left) and of the R2 case split (right).
fig, axes = plt.subplots(1, 2, figsize=(11.4, 3))

# Regression labels (log2 MIC ratio) of each split.
reg_train = extract_label(all_set)
reg_test = extract_label(test_set)

# NOTE: classification panels can be added with extract_labels_cls +
# plot_gradient_bar_histogram_cls (x-range -0.5..2.5) if needed.
panels = (
    (axes[0], reg_train, 'Train Set'),
    (axes[1], reg_test, 'R2 Set'),
)
for ax, values, set_name in panels:
    counts, _ = plot_gradient_bar_histogram(ax, values, bins, cmap_name='Blues')
    ax.set_xlabel('log2 (MIC ratio)', size=14)
    # Tie the visible x-range to the histogram bin edges.
    ax.set_xlim(bins[0], bins[-1])
    ax.set_ylabel('Counts', size=14)
    # 15% headroom above the tallest bar leaves room for the count labels;
    # max(..., 1) avoids a degenerate (0, 0) y-range when no label falls
    # inside the bins.
    ax.set_ylim(0, max(counts.max(), 1) * 1.15)
    # The title reports the total number of labels. Labels outside the bin
    # range are not drawn, so the bars may sum to less than this number.
    ax.set_title(f'{set_name} ({len(values)})', weight='bold', size=16)

plt.tight_layout(w_pad=3.2)
plt.savefig('dataset.svg')
plt.show()
# Print sample count and value range of the regression labels for each split
# ("训练集" = training set, "测试集" = test/R2 set).
for heading, values in (
    ("\n训练集回归统计信息:", reg_train),
    ("\n测试集回归统计信息:", reg_test),
):
    print(heading)
    print(f"样本总数: {len(values)}")
    print(f"范围: {min(values)} - {max(values)}")
# Per-class statistics are not printed: the classification labels are not
# extracted in this script.