|
|
import torch |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from umap import UMAP |
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib.colors import LinearSegmentedColormap |
|
|
|
|
|
|
|
|
|
|
|
# Run directory that holds features.pth and feature_preds.csv.
base_path = "../run-cls/rn18-diff-16-mamba-pcs-768-ce-32-0.001-50/uda_r2"

# Destination file for the rendered UMAP scatter plot.
img_path = "umap_after.png"
|
|
|
|
|
|
|
|
def get_label_dict(path='../dataset/r2_case.xlsx'):
    """
    Build a ``{description: label}`` mapping from the annotated case sheet.

    Reads the Excel sheet at *path* and keeps only rows whose ``TAG`` cell is
    non-empty. The ``'SEQUENCE - D-type amino acid substitution'`` column is
    mapped to an integer label:

    * ``'improved'``     -> 1
    * ``'not improved'`` -> 0

    Tag matching ignores case and surrounding whitespace. Rows with any other
    tag, including non-string cells, are skipped. If a description appears in
    several matching rows, the last one wins.

    Example return value::

        {
            "image_0001": 1,
            "image_0234": 0,
            # ...
        }

    Args:
        path: Location of the Excel sheet. Defaults to the original
            hard-coded dataset path.

    Returns:
        dict mapping description -> 0 or 1.
    """
    df = pd.read_excel(path)
    df = df[df['TAG'].notna()]

    # Normalise each tag before lookup so that stray whitespace or
    # capitalisation in the spreadsheet does not silently drop a labelled row.
    label_of = {'improved': 1, 'not improved': 0}

    tags = {}
    for desc, tag in zip(df['SEQUENCE - D-type amino acid substitution'], df['TAG']):
        label = label_of.get(str(tag).strip().lower())
        if label is not None:
            tags[desc] = label
    return tags
|
|
|
|
|
|
|
|
def get_score_dict(path=None):
    """
    Return ``{description: score}`` with prediction scores min-max scaled to [0, 1].

    The scores are used to colour the background points of the plot.

    Reads a CSV with a ``seq`` column (description) and a ``pred`` column
    (raw prediction score).

    Args:
        path: Optional explicit CSV location. ``None`` keeps the original
            default of ``f'{base_path}/feature_preds.csv'``.

    Returns:
        dict mapping each ``seq`` value to its scaled score.

        * If all predictions are identical (zero range), every score is 0.5
          instead of the NaN that 0/0 would produce.
        * A CSV with no rows yields an empty dict.
    """
    if path is None:
        path = f'{base_path}/feature_preds.csv'
    df = pd.read_csv(path)

    seqs = df['seq'].to_numpy()
    raw = df['pred'].to_numpy(dtype=float)
    if raw.size == 0:
        # np.min / np.max raise on empty input; there is nothing to scale.
        return {}

    lo, hi = np.min(raw), np.max(raw)
    span = hi - lo
    if span == 0:
        # Constant predictions: avoid 0/0 = NaN, which a colormap would render
        # with its "bad" colour. Use the midpoint of [0, 1] instead.
        scaled = np.full(raw.shape, 0.5)
    else:
        scaled = (raw - lo) / span

    return dict(zip(seqs, scaled))
|
|
|
|
|
|
|
|
# Load the saved {description: feature} mapping. map_location='cpu' remaps
# any tensors onto the CPU.
# NOTE: torch.load unpickles its input; only use it on trusted files.
data_dict = torch.load(f'{base_path}/features.pth', map_location='cpu')

# Fix the sample order once: row i of `features` belongs to descs[i].
descs = list(data_dict)

# Turn every stored value into a NumPy array (tensors via .cpu().numpy(),
# anything else via np.array) and stack the results row-wise.
features = np.vstack([
    (value.cpu().numpy() if isinstance(value, torch.Tensor) else np.array(value))
    for value in map(data_dict.__getitem__, descs)
])

print(f"共 {features.shape[0]} 个样本,特征维度 {features.shape[1]}")
|
|
|
|
|
|
|
|
# Annotated labels (1 = improved, 0 = not improved) and min-max scaled
# prediction scores, both keyed by description.
label_dict, score_dict = get_label_dict(), get_score_dict()

# Embed the stacked feature matrix in two dimensions with UMAP using the
# Euclidean metric. Rows of points_2d follow the order of `descs`.
umap = UMAP(metric='euclidean', n_components=2)
points_2d = umap.fit_transform(features)
|
|
|
|
|
|
|
|
# Partition row positions by annotation status:
#   idx1     -> labelled 1 ("improved")
#   idx0     -> labelled 0 ("not improved")
#   idx_rest -> absent from label_dict (background points)
idx1 = [pos for pos, name in enumerate(descs) if label_dict.get(name) == 1]
idx0 = [pos for pos, name in enumerate(descs) if label_dict.get(name) == 0]
idx_rest = [pos for pos, name in enumerate(descs) if name not in label_dict]
|
|
|
|
|
|
|
|
|
|
|
# Two-colour gradient for background points: a scaled score of 0 maps to
# "#6EB1EC" and a score of 1 maps to "#E69C98".
cmap = LinearSegmentedColormap.from_list('bg_cmap', ["#6EB1EC", "#E69C98"])

# Look up each background sample's score; samples missing from score_dict
# fall back to the midpoint 0.5. Then map the scores to RGBA colours.
bg_scores = np.array([
    score_dict.get(name, 0.5)
    for name in (descs[pos] for pos in idx_rest)
])
bg_colors = cmap(bg_scores)
|
|
|
|
|
|
|
|
# Single 3x3-inch axes. Draw order sets the layering: background first, then
# "Not Improved", then "Improved", so labelled points end up on top.
fig, ax = plt.subplots(figsize=(3, 3))

# Unlabelled samples, coloured by their scaled prediction score.
ax.scatter(
    points_2d[idx_rest, 0], points_2d[idx_rest, 1],
    c=bg_colors, s=20, alpha=0.8, label='Background',
)

# Samples tagged "not improved".
ax.scatter(
    points_2d[idx0, 0], points_2d[idx0, 1],
    c="#218BE7", s=60, alpha=1.0, label='Not Improved',
)

# Samples tagged "improved".
ax.scatter(
    points_2d[idx1, 0], points_2d[idx1, 1],
    c="#E76760", s=60, alpha=1.0, label='Improved',
)

# NOTE: the label= values only take effect if a legend is added; none is
# drawn here.
ax.axis('off')
fig.tight_layout()
fig.savefig(img_path, bbox_inches='tight', pad_inches=0, dpi=300)
|
|
|
|
|
|