kaveh's picture
init
ef814bf
import torch
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from models import SingleTransformer
from utils.helpers import create_multimodal_model
from data.create_dataset import MultiModalDataset
from .attentions import filter_idx
def get_latent_space(id, fold_results, labelled_dataset,
                     model_config, device, batch_size=32, common_samples=True):
    """Gather validation-set latent vectors from every fold's saved best model.

    For each fold a fresh model is built, its checkpoint is loaded from
    ``fold['best_model_path']``, and the model is run (eval mode, no gradients)
    over the samples selected by ``fold['val_idx']``.

    Args:
        id: One of ``'RNA'``, ``'ATAC'``, ``'Flux'`` or ``'Multi'``. ``'Multi'``
            builds the model with ``create_multimodal_model``; any other value
            builds a ``SingleTransformer`` with that ``id``.
        fold_results: Iterable of dicts providing ``'best_model_path'`` and
            ``'val_idx'``.
        labelled_dataset: Dataset yielding ``(x, b, y)`` samples. After batching,
            ``x`` is either a tensor or a list of three tensors (read in the
            order RNA, ATAC, Flux).
        model_config: Passed to ``create_multimodal_model`` or unpacked as
            keyword arguments into ``SingleTransformer``.
        device: Device the model and input batches are moved to.
        batch_size: Batch size of the validation ``DataLoader``.
        common_samples: When true, each fold's indices are first passed through
            ``filter_idx(labelled_dataset, val_idx)``.

    Returns:
        Tuple ``(latent_space, labels, preds)`` of NumPy arrays concatenated
        over all folds, with ``preds`` passed through ``np.round``.

    Raises:
        ValueError: If ``id`` is not one of the four accepted values, or (from
            ``np.concatenate``) if no batch was processed at all.
    """
    accepted_ids = ('RNA', 'ATAC', 'Flux', 'Multi')
    if id not in accepted_ids:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")

    latent_chunks, label_chunks, pred_chunks = [], [], []

    for fold in fold_results:
        checkpoint_path = fold['best_model_path']
        fold_indices = fold['val_idx']
        if common_samples:
            fold_indices = filter_idx(labelled_dataset, fold_indices)
        loader = DataLoader(
            Subset(labelled_dataset, fold_indices),
            batch_size=batch_size,
            shuffle=False,
        )

        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id=id, **model_config).to(device)

        # Deserialize the checkpoint onto CPU so weights saved on another
        # accelerator (e.g. CUDA) can still be loaded, then move to `device`.
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        model = model.to(device)
        model.eval()

        with torch.no_grad():
            for x, b, y in loader:
                if isinstance(x, list):
                    # Multimodal batch: exactly the first three entries are used.
                    x = (x[0].to(device), x[1].to(device), x[2].to(device))
                else:
                    x = x.to(device)
                b = b.to(device)

                embedding, prediction = model.get_latent_space(x, b)
                latent_chunks.append(embedding.cpu().numpy())
                label_chunks.append(y.numpy())
                pred_chunks.append(prediction.cpu().numpy())

    stacked_latent = np.concatenate(latent_chunks)
    stacked_labels = np.concatenate(label_chunks)
    stacked_preds = np.concatenate(pred_chunks)
    return stacked_latent, stacked_labels, np.round(stacked_preds)
def get_latent_space_cached(models, fold_results, dataset, device, batch_size=64, common_samples=True):
    """
    Compute latent space using preloaded models.

    ``models`` and ``fold_results`` are paired positionally with ``zip`` (so
    iteration stops at the shorter of the two). Each model is switched to eval
    mode and run without gradients over the samples selected by its fold's
    ``'val_idx'``.

    Args:
        models: Already-constructed models exposing ``eval()`` and
            ``get_latent_space(x, b)``; they are not moved to ``device`` here.
        fold_results: Iterable of dicts providing ``'val_idx'``.
        dataset: Dataset yielding ``(x, b, y)`` samples. After batching, ``x``
            is either a tensor or a list of three tensors (read in the order
            RNA, ATAC, Flux).
        device: Device the input batches are moved to.
        batch_size: Batch size of the validation ``DataLoader``.
        common_samples: When true, each fold's indices are first passed through
            ``filter_idx(dataset, val_idx)``.

    Returns:
        Tuple ``(latent_space, labels, preds)`` of NumPy arrays concatenated
        over all folds, with ``preds`` passed through ``np.round``.
    """
    latent_chunks, label_chunks, pred_chunks = [], [], []

    for fold_model, fold in zip(models, fold_results):
        fold_indices = fold['val_idx']
        if common_samples:
            fold_indices = filter_idx(dataset, fold_indices)
        loader = DataLoader(
            Subset(dataset, fold_indices),
            batch_size=batch_size,
            shuffle=False,
        )

        fold_model.eval()
        with torch.no_grad():
            for x, b, y in loader:
                if isinstance(x, list):
                    # Multimodal batch: move each of the three modalities.
                    x = (x[0].to(device), x[1].to(device), x[2].to(device))
                else:
                    x = x.to(device)
                b = b.to(device)

                embedding, prediction = fold_model.get_latent_space(x, b)
                latent_chunks.append(embedding.cpu().numpy())
                label_chunks.append(y.numpy())
                pred_chunks.append(prediction.cpu().numpy())

    stacked_latent = np.concatenate(latent_chunks)
    stacked_labels = np.concatenate(label_chunks)
    stacked_preds = np.concatenate(pred_chunks)
    return stacked_latent, stacked_labels, np.round(stacked_preds)
def measure_shift(original_latent, perturbed_latent):
    """Return the mean L2 distance between two latent arrays, taken along axis 1.

    The element-wise difference is reduced with ``np.linalg.norm(..., axis=1)``
    and the resulting distances are averaged. Presumably both inputs are
    ``(n_samples, latent_dim)`` arrays with matching row order — confirm at the
    call site, since mismatched shapes would broadcast or raise.
    """
    displacement = original_latent - perturbed_latent
    per_sample_distance = np.linalg.norm(displacement, axis=1)
    return np.mean(per_sample_distance)
def perturb_feature(data, feature_idx, perturbation_type='additive', scale=0.1, min_samples_threshold=10):
    """Return a copy of ``data`` with one column perturbed.

    ``data`` is indexed as ``data[:, feature_idx]`` and is never modified in
    place; all changes are made on a clone.

    Args:
        data: Tensor whose column ``feature_idx`` is to be perturbed.
        feature_idx: Column index of the feature.
        perturbation_type: One of
            ``'shuffle'`` – permute only the non-zero entries of the column, so
            zero entries stay where they are;
            ``'shuffle_all'`` – permute every entry of the column;
            ``'additive'`` – add Gaussian noise scaled by ``scale`` times the
            column's standard deviation to the non-zero entries;
            ``'multiplicative'`` – multiply the non-zero entries by a uniform
            factor drawn from ``[1 - scale/2, 1 + scale/2)``.
        scale: Noise magnitude for ``'additive'`` / ``'multiplicative'``;
            ignored by the shuffle modes.
        min_samples_threshold: Minimum number of non-zero entries the column
            must have (applies to every perturbation type).

    Returns:
        ``(perturbed_data, False)`` on success, or ``(None, True)`` when the
        column has fewer than ``min_samples_threshold`` non-zero entries.

    Raises:
        ValueError: If ``perturbation_type`` is not one of the supported names.
            (Previously an unknown name silently returned an unperturbed copy.)
    """
    supported_types = ('shuffle', 'shuffle_all', 'additive', 'multiplicative')
    if perturbation_type not in supported_types:
        raise ValueError(
            "perturbation_type must be one of 'shuffle', 'shuffle_all', "
            f"'additive', 'multiplicative'; got {perturbation_type!r}"
        )

    non_zero_rows_mask = data[:, feature_idx] != 0

    # Check if feature has enough non-zero samples before paying for a copy.
    if non_zero_rows_mask.sum() < min_samples_threshold:
        return None, True  # Return None and flag indicating insufficient samples

    perturbed_data = data.clone()
    # Any non-floating dtype (int32, int64, ...) needs the float noise cast back
    # before it can be written into the tensor.
    needs_cast = not (data.is_floating_point() or data.is_complex())

    if perturbation_type == 'shuffle':
        # Shuffle only non-zero values (preserves sparsity pattern)
        non_zero_values = perturbed_data[non_zero_rows_mask, feature_idx].clone()
        shuffled_idx = torch.randperm(non_zero_values.size(0), device=perturbed_data.device)
        perturbed_data[non_zero_rows_mask, feature_idx] = non_zero_values[shuffled_idx]

    elif perturbation_type == 'shuffle_all':
        # Shuffle all values (including zeros)
        shuffled_idx = torch.randperm(perturbed_data.size(0), device=perturbed_data.device)
        perturbed_data[:, feature_idx] = data[shuffled_idx, feature_idx]

    elif perturbation_type == 'additive':
        column = perturbed_data[:, feature_idx].float()
        # Noise is drawn for every row (std taken over the whole column,
        # zeros included) but only applied to the non-zero rows.
        noise = torch.randn_like(column) * scale * torch.std(column)
        delta = noise[non_zero_rows_mask]
        if needs_cast:
            delta = delta.to(data.dtype)  # truncates toward zero
        perturbed_data[non_zero_rows_mask, feature_idx] += delta

    else:  # 'multiplicative'
        factor = 1 + scale * (torch.rand(perturbed_data.shape[0], device=perturbed_data.device) - 0.5)
        if needs_cast:
            scaled = perturbed_data[non_zero_rows_mask, feature_idx].float() * factor[non_zero_rows_mask]
            perturbed_data[non_zero_rows_mask, feature_idx] = scaled.to(data.dtype)
        else:
            perturbed_data[non_zero_rows_mask, feature_idx] *= factor[non_zero_rows_mask]

    return perturbed_data, False  # Return perturbed data and flag indicating sufficient samples
def analyze_feature_importance_multi(id, model_config, fold_results, dataset, feature_names,
                                     device, analyse_features='all', perturbation_scale=0.1,
                                     min_samples_threshold=10, common_samples=True, batch_size=64):
    """Rank features by how far perturbing each one moves the latent space.

    Every fold's best model is loaded once. The unperturbed latent space is
    computed with ``get_latent_space_cached``; then, one feature at a time, the
    feature's column is perturbed with ``perturb_feature``, a new
    ``MultiModalDataset`` is built around the perturbed modality, the latent
    space is recomputed, and ``measure_shift`` against the original is stored
    as that feature's importance.

    RNA and ATAC features are perturbed with ``'shuffle'``; Flux features are
    always perturbed with ``'shuffle_all'``.

    Args:
        id: ``'Multi'`` builds each model with ``create_multimodal_model``; any
            other value is passed as ``id`` to ``SingleTransformer`` (it is not
            validated here).
        model_config: Passed to ``create_multimodal_model`` or unpacked as
            keyword arguments into ``SingleTransformer``.
        fold_results: Iterable of dicts providing ``'best_model_path'`` (and
            ``'val_idx'``, read by ``get_latent_space_cached``).
        dataset: Object exposing ``rna_data``, ``atac_data``, ``flux_data``,
            ``batch_no`` and ``labels``.
        feature_names: Sequence indexed by feature position to label results.
        device: Device used for the models and for inference.
        analyse_features: ``'all'``, ``'RNA'``, ``'ATAC'`` or ``'Flux'``.
        perturbation_scale: Forwarded to ``perturb_feature`` as ``scale``.
        min_samples_threshold: Forwarded to ``perturb_feature``; features with
            fewer non-zero samples are skipped and reported with shift ``0.0``.
        common_samples: Forwarded to ``get_latent_space_cached``.
        batch_size: Inference batch size forwarded to
            ``get_latent_space_cached``.

    Returns:
        List of ``(feature_name, shift)`` tuples sorted by ``shift`` in
        descending order.

    Raises:
        ValueError: If ``analyse_features`` is not one of the accepted values.
    """
    if analyse_features not in ['all', 'RNA', 'ATAC', 'Flux']:
        raise ValueError("analyse_features must be one of 'all', 'RNA', 'ATAC', 'Flux'")

    models = []
    for fold in fold_results:
        model_path = fold['best_model_path']
        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id=id, **model_config).to(device)
        # Load weights to CPU first, then move to target device (handles CUDA->MPS/CPU transfer)
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        models.append(model)

    # Compute the original latent space once using the cached models
    original_latent, _, _ = get_latent_space_cached(
        models, fold_results, dataset, device,
        batch_size=batch_size, common_samples=common_samples)

    feature_shifts = []
    skipped_features = []  # Track features skipped due to insufficient samples

    # Unpack multi-modal data
    modality_inputs = (dataset.rna_data, dataset.atac_data, dataset.flux_data)
    b, y = dataset.batch_no, dataset.labels
    rna_input, atac_input, _ = modality_inputs

    # NOTE(review): each offset skips one extra slot past the previous modality.
    # Presumably feature_names carries a separator entry between modalities;
    # if it is a plain concatenation, ATAC/Flux names are shifted by 1/2 — confirm.
    atac_start = rna_input.shape[1] + 1
    flux_start = atac_start + atac_input.shape[1] + 1
    print("atac start", atac_start, "flux start", flux_start)

    def _score_modality(position, label, perturb_type, name_offset):
        """Perturb every column of one modality and record its latent shift."""
        print(f"Analyzing {label} features")
        print(f"Permuting {label} features with", perturb_type)
        source = modality_inputs[position]
        for i in tqdm(range(source.shape[1])):
            perturbed, insufficient_samples = perturb_feature(
                source, i, perturb_type,
                scale=perturbation_scale, min_samples_threshold=min_samples_threshold)
            feature_name = feature_names[name_offset + i]
            if insufficient_samples:
                non_zero_count = (source[:, i] != 0).sum().item()
                skipped_features.append((feature_name, label, non_zero_count))
                feature_shifts.append((feature_name, 0.0))  # Add with 0 importance
                continue
            # Swap only the perturbed modality in; the other two stay original.
            candidate = list(modality_inputs)
            candidate[position] = perturbed
            perturbed_dataset = MultiModalDataset(tuple(candidate), b, y)
            perturbed_latent, _, _ = get_latent_space_cached(
                models, fold_results, perturbed_dataset, device,
                batch_size=batch_size, common_samples=common_samples)
            feature_shifts.append((feature_name, measure_shift(original_latent, perturbed_latent)))

    # (label, index into modality_inputs, perturbation type, offset into feature_names)
    modality_plan = (
        ('RNA', 0, 'shuffle', 0),
        ('ATAC', 1, 'shuffle', atac_start),
        ('Flux', 2, 'shuffle_all', flux_start),
    )
    for label, position, perturb_type, name_offset in modality_plan:
        if analyse_features in (label, 'all'):
            _score_modality(position, label, perturb_type, name_offset)

    # Log skipped features
    if skipped_features:
        print(f"\nSkipped {len(skipped_features)} features due to insufficient samples (< {min_samples_threshold}):")
        for feature_name, modality, sample_count in skipped_features:
            print(f" {feature_name} ({modality}): {sample_count} samples")

    return sorted(feature_shifts, key=lambda x: x[1], reverse=True)