Spaces:
Running
Running
| import torch | |
| from torch.utils.data import DataLoader, Subset | |
| import numpy as np | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| from models import SingleTransformer | |
| from utils.helpers import create_multimodal_model | |
| from data.create_dataset import MultiModalDataset | |
| from .attentions import filter_idx | |
def get_latent_space(id, fold_results, labelled_dataset,
                     model_config, device, batch_size=32, common_samples=True):
    """
    Load each fold's best checkpoint and compute the validation latent space.

    Parameters
    ----------
    id : str
        One of 'RNA', 'ATAC', 'Flux', 'Multi'; selects the model class.
    fold_results : list[dict]
        One entry per CV fold with keys 'best_model_path' and 'val_idx'.
    labelled_dataset : Dataset
        Dataset whose items are (x, batch_no, label) tuples.
    model_config : dict
        Keyword arguments for model construction.
    device : torch.device | str
        Target device for inference.
    batch_size : int
        DataLoader batch size.
    common_samples : bool
        If True, validation indices are filtered with filter_idx.

    Returns
    -------
    (np.ndarray, np.ndarray, np.ndarray)
        Concatenated latent vectors, labels, and rounded predictions.

    Raises
    ------
    ValueError
        If ``id`` is not a recognised modality.
    """
    if id not in ['RNA', 'ATAC', 'Flux', 'Multi']:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")
    # Build and load one model per fold, then delegate the inference loop to
    # get_latent_space_cached — it was previously duplicated here verbatim.
    models = []
    for fold in fold_results:
        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id=id, **model_config).to(device)
        # Load weights to CPU first, then move to target device
        # (handles CUDA -> MPS/CPU transfer).
        state_dict = torch.load(fold['best_model_path'], map_location='cpu')
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        models.append(model)
    return get_latent_space_cached(models, fold_results, labelled_dataset, device,
                                   batch_size=batch_size, common_samples=common_samples)
def get_latent_space_cached(models, fold_results, dataset, device, batch_size=64, common_samples=True):
    """
    Compute latent vectors, labels and rounded predictions with preloaded models.

    Each model is evaluated on its fold's validation subset (optionally
    filtered to common samples); results from all folds are concatenated.
    """
    all_latents = []
    all_labels = []
    all_preds = []
    for model, fold in zip(models, fold_results):
        idx = fold['val_idx']
        if common_samples:
            idx = filter_idx(dataset, idx)
        # Deterministic order: no shuffling so outputs align across calls.
        loader = DataLoader(Subset(dataset, idx), batch_size=batch_size, shuffle=False)
        model.eval()
        with torch.no_grad():
            for x, b, y in loader:
                if isinstance(x, list):
                    # Multimodal batch: move each modality tensor to the device.
                    x = (x[0].to(device), x[1].to(device), x[2].to(device))
                else:
                    x = x.to(device)
                ls, pred = model.get_latent_space(x, b.to(device))
                all_latents.append(ls.cpu().numpy())
                all_labels.append(y.numpy())
                all_preds.append(pred.cpu().numpy())
    return (np.concatenate(all_latents),
            np.concatenate(all_labels),
            np.round(np.concatenate(all_preds)))
def measure_shift(original_latent, perturbed_latent):
    """Average per-sample Euclidean distance between two latent matrices."""
    squared_diffs = (original_latent - perturbed_latent) ** 2
    per_sample_dist = np.sqrt(squared_diffs.sum(axis=1))
    return per_sample_dist.mean()
def perturb_feature(data, feature_idx, perturbation_type='additive', scale=0.1, min_samples_threshold=10):
    """
    Return a perturbed copy of ``data`` with one feature column modified.

    Parameters
    ----------
    data : torch.Tensor
        2-D (samples x features) tensor; float, or int32 counts.
    feature_idx : int
        Column index to perturb.
    perturbation_type : str
        'shuffle' (permute non-zero entries only), 'shuffle_all' (permute the
        whole column, zeros included), 'additive' (Gaussian noise on non-zero
        entries) or 'multiplicative' (uniform rescaling of non-zero entries).
    scale : float
        Noise magnitude for 'additive' / 'multiplicative'.
    min_samples_threshold : int
        Minimum number of non-zero entries required to perturb the feature.

    Returns
    -------
    (torch.Tensor | None, bool)
        The perturbed copy and False, or (None, True) when the column has
        fewer than ``min_samples_threshold`` non-zero entries.

    Raises
    ------
    ValueError
        If ``perturbation_type`` is not recognised (previously this fell
        through silently and returned an unperturbed clone).
    """
    perturbed_data = data.clone()
    non_zero_rows_mask = data[:, feature_idx] != 0
    # Skip near-empty (mostly zero) columns: perturbing a handful of values
    # gives an unreliable importance estimate downstream.
    if non_zero_rows_mask.sum() < min_samples_threshold:
        return None, True
    if perturbation_type == 'shuffle':
        # Permute only non-zero values so the sparsity pattern is preserved.
        non_zero_values = perturbed_data[non_zero_rows_mask, feature_idx].clone()
        shuffled_idx = torch.randperm(non_zero_values.size(0), device=perturbed_data.device)
        perturbed_data[non_zero_rows_mask, feature_idx] = non_zero_values[shuffled_idx]
    elif perturbation_type == 'shuffle_all':
        # Permute the whole column, zeros included.
        shuffled_idx = torch.randperm(perturbed_data.size(0), device=perturbed_data.device)
        perturbed_data[:, feature_idx] = data[shuffled_idx, feature_idx]
    elif perturbation_type == 'additive':
        # Gaussian noise scaled by the column's standard deviation.
        column = perturbed_data[:, feature_idx].float()
        noise = torch.randn_like(column) * scale * torch.std(column)
        if data.dtype == torch.int32:
            # .to(torch.int32) truncates toward zero, matching integer storage
            # (replaces torch.tensor(tensor, ...) which warns and re-copies).
            perturbed_data[non_zero_rows_mask, feature_idx] += noise[non_zero_rows_mask].to(torch.int32)
        else:
            perturbed_data[non_zero_rows_mask, feature_idx] += noise[non_zero_rows_mask]
    elif perturbation_type == 'multiplicative':
        # Uniform factor in [1 - scale/2, 1 + scale/2) per sample.
        factor = 1 + scale * (torch.rand(perturbed_data.shape[0], device=perturbed_data.device) - 0.5)
        if data.dtype == torch.int32:
            scaled = perturbed_data[non_zero_rows_mask, feature_idx].float() * factor[non_zero_rows_mask]
            perturbed_data[non_zero_rows_mask, feature_idx] = scaled.to(torch.int32)
        else:
            perturbed_data[non_zero_rows_mask, feature_idx] *= factor[non_zero_rows_mask]
    else:
        raise ValueError(f"Unknown perturbation_type: {perturbation_type!r}")
    return perturbed_data, False
def analyze_feature_importance_multi(id, model_config, fold_results, dataset, feature_names,
                                     device, analyse_features='all', perturbation_scale=0.1,
                                     min_samples_threshold=10, common_samples=True):
    """
    Rank features by how much perturbing them shifts the validation latent space.

    For every feature of the requested modality(ies), the feature column is
    permuted, the latent space is recomputed with the cached fold models, and
    the mean per-sample Euclidean shift from the unperturbed latent space is
    recorded as that feature's importance.

    Parameters
    ----------
    id : str
        'Multi' loads the multimodal model; any other value builds a
        SingleTransformer for that modality.
    model_config : dict
        Keyword arguments for model construction.
    fold_results : list[dict]
        One entry per CV fold with 'best_model_path' and 'val_idx'.
    dataset : MultiModalDataset
        Must expose rna_data, atac_data, flux_data, batch_no and labels.
    feature_names : sequence[str]
        Concatenated RNA + ATAC + Flux feature names, indexed via the
        offsets computed below.
    device : torch.device | str
        Device for inference.
    analyse_features : str
        'all', 'RNA', 'ATAC' or 'Flux'.
    perturbation_scale : float
        Forwarded to perturb_feature (unused by the shuffle perturbations).
    min_samples_threshold : int
        Features with fewer non-zero samples are skipped and recorded with
        importance 0.0.
    common_samples : bool
        Forwarded to get_latent_space_cached (filters validation indices).

    Returns
    -------
    list[tuple[str, float]]
        (feature name, mean latent shift) sorted by shift, descending.

    Raises
    ------
    ValueError
        If ``analyse_features`` is not a recognised option.
    """
    if analyse_features not in ['all', 'RNA', 'ATAC', 'Flux']:
        raise ValueError("analyse_features must be one of 'all', 'RNA', 'ATAC', 'Flux'")

    # Load every fold's best checkpoint once; every perturbation pass reuses them.
    models = []
    for fold in fold_results:
        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id=id, **model_config).to(device)
        # Load weights to CPU first, then move to target device
        # (handles CUDA -> MPS/CPU transfer).
        state_dict = torch.load(fold['best_model_path'], map_location='cpu')
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        models.append(model)

    # Baseline latent space, computed once with the cached models.
    original_latent, _, _ = get_latent_space_cached(
        models, fold_results, dataset, device, batch_size=64, common_samples=common_samples)

    feature_shifts = []
    skipped_features = []  # (name, modality, non-zero count) for under-sampled features

    rna_input, atac_input, flux_input = dataset.rna_data, dataset.atac_data, dataset.flux_data
    b, y = dataset.batch_no, dataset.labels

    # Offsets of each modality inside the concatenated feature_names list.
    # NOTE(review): the +1 presumably skips a separator/special token between
    # modalities — confirm against how feature_names is constructed.
    atac_start = rna_input.shape[1] + 1
    flux_start = atac_start + atac_input.shape[1] + 1
    print("atac start", atac_start, "flux start", flux_start)

    perturb_type = 'shuffle'

    def _scan_modality(modality, data, slot, name_offset, p_type):
        """Perturb each column of one modality and record the latent-space shift."""
        print(f"Analyzing {modality} features")
        print(f"Permuting {modality} features with", p_type)
        for i in tqdm(range(data.shape[1])):
            name = feature_names[name_offset + i]
            perturbed, insufficient = perturb_feature(
                data, i, p_type, scale=perturbation_scale,
                min_samples_threshold=min_samples_threshold)
            if insufficient:
                skipped_features.append((name, modality, (data[:, i] != 0).sum().item()))
                feature_shifts.append((name, 0.0))  # keep the feature, with 0 importance
                continue
            # Rebuild the dataset with only this modality's column perturbed.
            modal_inputs = [rna_input, atac_input, flux_input]
            modal_inputs[slot] = perturbed
            perturbed_dataset = MultiModalDataset(tuple(modal_inputs), b, y)
            perturbed_latent, _, _ = get_latent_space_cached(
                models, fold_results, perturbed_dataset, device,
                batch_size=64, common_samples=common_samples)
            feature_shifts.append((name, measure_shift(original_latent, perturbed_latent)))

    if analyse_features in ['RNA', 'all']:
        _scan_modality('RNA', rna_input, 0, 0, perturb_type)
    if analyse_features in ['ATAC', 'all']:
        _scan_modality('ATAC', atac_input, 1, atac_start, perturb_type)
    if analyse_features in ['Flux', 'all']:
        # Flux keeps the original choice of 'shuffle_all' (zeros included).
        _scan_modality('Flux', flux_input, 2, flux_start, 'shuffle_all')

    # Log features skipped because too few samples carried a non-zero value.
    if skipped_features:
        print(f"\nSkipped {len(skipped_features)} features due to insufficient samples (< {min_samples_threshold}):")
        for feature_name, modality, sample_count in skipped_features:
            print(f"  {feature_name} ({modality}): {sample_count} samples")

    return sorted(feature_shifts, key=lambda x: x[1], reverse=True)