|
import sys |
|
import os |
|
import traceback |
|
import json |
|
import pickle |
|
import numpy as np |
|
import scanpy as sc |
|
import pandas as pd |
|
import networkx as nx |
|
from tqdm import tqdm |
|
import logging |
|
import torch |
|
import torch.optim as optim |
|
import torch.nn as nn |
|
from sklearn.metrics import r2_score |
|
from torch.optim.lr_scheduler import StepLR |
|
from torch_geometric.nn import SGConv |
|
from copy import deepcopy |
|
from torch_geometric.data import Data, DataLoader |
|
from multiprocessing import Pool |
|
from torch.nn import Sequential, Linear, ReLU |
|
from scipy.stats import pearsonr |
|
from sklearn.metrics import mean_squared_error as mse |
|
from sklearn.metrics import mean_absolute_error as mae |
|
|
|
class MLP(torch.nn.Module): |
|
|
|
def __init__(self, sizes, batch_norm=True, last_layer_act="linear"): |
|
super(MLP, self).__init__() |
|
layers = [] |
|
for s in range(len(sizes) - 1): |
|
layers = layers + [ |
|
torch.nn.Linear(sizes[s], sizes[s + 1]), |
|
torch.nn.BatchNorm1d(sizes[s + 1]) |
|
if batch_norm and s < len(sizes) - 1 else None, |
|
torch.nn.ReLU() |
|
] |
|
|
|
layers = [l for l in layers if l is not None][:-1] |
|
self.activation = last_layer_act |
|
self.network = torch.nn.Sequential(*layers) |
|
self.relu = torch.nn.ReLU() |
|
def forward(self, x): |
|
return self.network(x) |
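
# Illustrative example (sizes are placeholders): MLP([64, 128, 64]) builds
# Linear(64, 128) -> BatchNorm1d(128) -> ReLU -> Linear(128, 64) -> BatchNorm1d(64);
# the trailing ReLU appended in the loop is stripped by the [:-1] above, so the
# final layer's output stays linear under the default last_layer_act="linear".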
|
|
|
|
|
class GEARS_Model(torch.nn.Module):
    """
    GEARS model: predicts post-perturbation gene expression by combining a
    gene co-expression graph over genes with a GO-derived similarity graph
    over perturbations.
    """
|
|
|
def __init__(self, args): |
|
""" |
|
:param args: arguments dictionary |
|
""" |
|
|
|
super(GEARS_Model, self).__init__() |
|
self.args = args |
|
self.num_genes = args['num_genes'] |
|
self.num_perts = args['num_perts'] |
|
hidden_size = args['hidden_size'] |
|
self.uncertainty = args['uncertainty'] |
|
self.num_layers = args['num_go_gnn_layers'] |
|
self.indv_out_hidden_size = args['decoder_hidden_size'] |
|
self.num_layers_gene_pos = args['num_gene_gnn_layers'] |
|
self.no_perturb = args['no_perturb'] |
|
self.pert_emb_lambda = 0.2 |
|
|
|
|
|
self.pert_w = nn.Linear(1, hidden_size) |
|
|
|
|
|
self.gene_emb = nn.Embedding(self.num_genes, hidden_size, max_norm=True) |
|
self.pert_emb = nn.Embedding(self.num_perts, hidden_size, max_norm=True) |
|
|
|
|
|
self.emb_trans = nn.ReLU() |
|
self.pert_base_trans = nn.ReLU() |
|
self.transform = nn.ReLU() |
|
self.emb_trans_v2 = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU') |
|
self.pert_fuse = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU') |
|
|
|
|
|
self.G_coexpress = args['G_coexpress'].to(args['device']) |
|
self.G_coexpress_weight = args['G_coexpress_weight'].to(args['device']) |
|
|
|
self.emb_pos = nn.Embedding(self.num_genes, hidden_size, max_norm=True) |
|
self.layers_emb_pos = torch.nn.ModuleList() |
|
for i in range(1, self.num_layers_gene_pos + 1): |
|
self.layers_emb_pos.append(SGConv(hidden_size, hidden_size, 1)) |
|
|
|
|
|
self.G_sim = args['G_go'].to(args['device']) |
|
self.G_sim_weight = args['G_go_weight'].to(args['device']) |
|
|
|
self.sim_layers = torch.nn.ModuleList() |
|
for i in range(1, self.num_layers + 1): |
|
self.sim_layers.append(SGConv(hidden_size, hidden_size, 1)) |
|
|
|
|
|
self.recovery_w = MLP([hidden_size, hidden_size*2, hidden_size], last_layer_act='linear') |
|
|
|
|
|
self.indv_w1 = nn.Parameter(torch.rand(self.num_genes, |
|
hidden_size, 1)) |
|
self.indv_b1 = nn.Parameter(torch.rand(self.num_genes, 1)) |
|
self.act = nn.ReLU() |
|
nn.init.xavier_normal_(self.indv_w1) |
|
nn.init.xavier_normal_(self.indv_b1) |
|
|
|
|
|
self.cross_gene_state = MLP([self.num_genes, hidden_size, |
|
hidden_size]) |
|
|
|
self.indv_w2 = nn.Parameter(torch.rand(1, self.num_genes, |
|
hidden_size+1)) |
|
self.indv_b2 = nn.Parameter(torch.rand(1, self.num_genes)) |
|
nn.init.xavier_normal_(self.indv_w2) |
|
nn.init.xavier_normal_(self.indv_b2) |
|
|
|
|
|
self.bn_emb = nn.BatchNorm1d(hidden_size) |
|
self.bn_pert_base = nn.BatchNorm1d(hidden_size) |
|
self.bn_pert_base_trans = nn.BatchNorm1d(hidden_size) |
|
|
|
|
|
if self.uncertainty: |
|
self.uncertainty_w = MLP([hidden_size, hidden_size*2, hidden_size, 1], last_layer_act='linear') |
|
|
|
def forward(self, data): |
|
""" |
|
Forward pass of the model |
|
""" |
|
x, pert_idx = data.x, data.pert_idx |
|
if self.no_perturb: |
|
out = x.reshape(-1,1) |
|
out = torch.split(torch.flatten(out), self.num_genes) |
|
return torch.stack(out) |
|
else: |
|
num_graphs = len(data.batch.unique()) |
|
|
|
|
|
emb = self.gene_emb(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device'])) |
|
emb = self.bn_emb(emb) |
|
base_emb = self.emb_trans(emb) |
|
|
|
pos_emb = self.emb_pos(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device'])) |
|
for idx, layer in enumerate(self.layers_emb_pos): |
|
pos_emb = layer(pos_emb, self.G_coexpress, self.G_coexpress_weight) |
|
if idx < len(self.layers_emb_pos) - 1: |
|
pos_emb = pos_emb.relu() |
|
|
|
base_emb = base_emb + 0.2 * pos_emb |
|
base_emb = self.emb_trans_v2(base_emb) |
|
|
|
|
|
|
|
pert_index = [] |
|
for idx, i in enumerate(pert_idx): |
|
for j in i: |
|
if j != -1: |
|
pert_index.append([idx, j]) |
|
pert_index = torch.tensor(pert_index).T |
|
|
|
pert_global_emb = self.pert_emb(torch.LongTensor(list(range(self.num_perts))).to(self.args['device'])) |
|
|
|
|
|
for idx, layer in enumerate(self.sim_layers): |
|
pert_global_emb = layer(pert_global_emb, self.G_sim, self.G_sim_weight) |
|
if idx < self.num_layers - 1: |
|
pert_global_emb = pert_global_emb.relu() |
|
|
|
|
|
base_emb = base_emb.reshape(num_graphs, self.num_genes, -1) |
|
|
|
if pert_index.shape[0] != 0: |
|
|
|
pert_track = {} |
|
for i, j in enumerate(pert_index[0]): |
|
if j.item() in pert_track: |
|
pert_track[j.item()] = pert_track[j.item()] + pert_global_emb[pert_index[1][i]] |
|
else: |
|
pert_track[j.item()] = pert_global_emb[pert_index[1][i]] |
|
|
|
                if len(list(pert_track.values())) > 0:
                    if len(list(pert_track.values())) == 1:
                        # Duplicate the single perturbation embedding so that the
                        # BatchNorm1d inside pert_fuse sees a batch larger than 1;
                        # only the first row is used below.
                        emb_total = self.pert_fuse(torch.stack(list(pert_track.values()) * 2))
                    else:
                        emb_total = self.pert_fuse(torch.stack(list(pert_track.values())))

                    # Add the fused perturbation signal onto the base embedding of
                    # the corresponding graph in the batch.
                    for idx, j in enumerate(pert_track.keys()):
                        base_emb[j] = base_emb[j] + emb_total[idx]
|
|
|
base_emb = base_emb.reshape(num_graphs * self.num_genes, -1) |
|
base_emb = self.bn_pert_base(base_emb) |
|
|
|
|
|
base_emb = self.transform(base_emb) |
|
out = self.recovery_w(base_emb) |
|
out = out.reshape(num_graphs, self.num_genes, -1) |
|
out = out.unsqueeze(-1) * self.indv_w1 |
|
w = torch.sum(out, axis = 2) |
|
out = w + self.indv_b1 |
|
|
|
|
|
cross_gene_embed = self.cross_gene_state(out.reshape(num_graphs, self.num_genes, -1).squeeze(2)) |
|
cross_gene_embed = cross_gene_embed.repeat(1, self.num_genes) |
|
|
|
cross_gene_embed = cross_gene_embed.reshape([num_graphs,self.num_genes, -1]) |
|
cross_gene_out = torch.cat([out, cross_gene_embed], 2) |
|
|
|
cross_gene_out = cross_gene_out * self.indv_w2 |
|
cross_gene_out = torch.sum(cross_gene_out, axis=2) |
|
out = cross_gene_out + self.indv_b2 |
|
out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1,1) |
|
out = torch.split(torch.flatten(out), self.num_genes) |
|
|
|
|
|
if self.uncertainty: |
|
out_logvar = self.uncertainty_w(base_emb) |
|
out_logvar = torch.split(torch.flatten(out_logvar), self.num_genes) |
|
return torch.stack(out), torch.stack(out_logvar) |
|
|
|
return torch.stack(out) |
|
|
|
class GEARS:
    """
    GEARS wrapper class: handles data hookup, model initialization, training,
    evaluation and checkpointing.
    """
|
|
|
def __init__(self, pert_data, |
|
device = 'cuda', |
|
weight_bias_track = True, |
|
proj_name = 'GEARS', |
|
exp_name = 'GEARS'): |
|
|
|
self.weight_bias_track = weight_bias_track |
|
|
|
if self.weight_bias_track: |
|
import wandb |
|
wandb.init(project=proj_name, name=exp_name) |
|
self.wandb = wandb |
|
else: |
|
self.wandb = None |
|
|
|
self.device = device |
|
self.config = None |
|
|
|
self.dataloader = pert_data.dataloader |
|
self.adata = pert_data.adata |
|
self.node_map = pert_data.node_map |
|
self.node_map_pert = pert_data.node_map_pert |
|
self.data_path = pert_data.data_path |
|
self.dataset_name = pert_data.dataset_name |
|
self.split = pert_data.split |
|
self.seed = pert_data.seed |
|
self.train_gene_set_size = pert_data.train_gene_set_size |
|
self.set2conditions = pert_data.set2conditions |
|
self.subgroup = pert_data.subgroup |
|
self.gene_list = pert_data.gene_names.values.tolist() |
|
self.pert_list = pert_data.pert_names.tolist() |
|
self.num_genes = len(self.gene_list) |
|
self.num_perts = len(self.pert_list) |
|
self.default_pert_graph = pert_data.default_pert_graph |
|
self.saved_pred = {} |
|
self.saved_logvar_sum = {} |
|
|
|
self.ctrl_expression = torch.tensor( |
|
np.mean(self.adata.X[self.adata.obs['condition'].values == 'ctrl'], |
|
axis=0)).reshape(-1, ).to(self.device) |
|
pert_full_id2pert = dict(self.adata.obs[['condition_name', 'condition']].values) |
|
self.dict_filter = {pert_full_id2pert[i]: j for i, j in |
|
self.adata.uns['non_zeros_gene_idx'].items() if |
|
i in pert_full_id2pert} |
|
self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl'] |
|
|
|
gene_dict = {g:i for i,g in enumerate(self.gene_list)} |
|
self.pert2gene = {p: gene_dict[pert] for p, pert in |
|
enumerate(self.pert_list) if pert in self.gene_list} |
|
|
|
def model_initialize(self, hidden_size = 64, |
|
num_go_gnn_layers = 1, |
|
num_gene_gnn_layers = 1, |
|
decoder_hidden_size = 16, |
|
num_similar_genes_go_graph = 20, |
|
num_similar_genes_co_express_graph = 20, |
|
coexpress_threshold = 0.4, |
|
uncertainty = False, |
|
uncertainty_reg = 1, |
|
direction_lambda = 1e-1, |
|
G_go = None, |
|
G_go_weight = None, |
|
G_coexpress = None, |
|
G_coexpress_weight = None, |
|
no_perturb = False, |
|
**kwargs |
|
): |
|
|
|
self.config = {'hidden_size': hidden_size, |
|
'num_go_gnn_layers' : num_go_gnn_layers, |
|
'num_gene_gnn_layers' : num_gene_gnn_layers, |
|
'decoder_hidden_size' : decoder_hidden_size, |
|
'num_similar_genes_go_graph' : num_similar_genes_go_graph, |
|
'num_similar_genes_co_express_graph' : num_similar_genes_co_express_graph, |
|
'coexpress_threshold': coexpress_threshold, |
|
'uncertainty' : uncertainty, |
|
'uncertainty_reg' : uncertainty_reg, |
|
'direction_lambda' : direction_lambda, |
|
'G_go': G_go, |
|
'G_go_weight': G_go_weight, |
|
'G_coexpress': G_coexpress, |
|
'G_coexpress_weight': G_coexpress_weight, |
|
'device': self.device, |
|
'num_genes': self.num_genes, |
|
'num_perts': self.num_perts, |
|
'no_perturb': no_perturb |
|
} |
|
|
|
if self.wandb: |
|
self.wandb.config.update(self.config) |
|
|
|
if self.config['G_coexpress'] is None: |
|
|
|
edge_list = get_similarity_network(network_type='co-express', |
|
adata=self.adata, |
|
threshold=coexpress_threshold, |
|
k=num_similar_genes_co_express_graph, |
|
data_path=self.data_path, |
|
data_name=self.dataset_name, |
|
split=self.split, seed=self.seed, |
|
train_gene_set_size=self.train_gene_set_size, |
|
set2conditions=self.set2conditions) |
|
|
|
sim_network = GeneSimNetwork(edge_list, self.gene_list, node_map = self.node_map) |
|
self.config['G_coexpress'] = sim_network.edge_index |
|
self.config['G_coexpress_weight'] = sim_network.edge_weight |
|
|
|
if self.config['G_go'] is None: |
|
|
|
edge_list = get_similarity_network(network_type='go', |
|
adata=self.adata, |
|
threshold=coexpress_threshold, |
|
k=num_similar_genes_go_graph, |
|
pert_list=self.pert_list, |
|
data_path=self.data_path, |
|
data_name=self.dataset_name, |
|
split=self.split, seed=self.seed, |
|
train_gene_set_size=self.train_gene_set_size, |
|
set2conditions=self.set2conditions, |
|
default_pert_graph=self.default_pert_graph) |
|
|
|
sim_network = GeneSimNetwork(edge_list, self.pert_list, node_map = self.node_map_pert) |
|
self.config['G_go'] = sim_network.edge_index |
|
self.config['G_go_weight'] = sim_network.edge_weight |
|
|
|
self.model = GEARS_Model(self.config).to(self.device) |
|
self.best_model = deepcopy(self.model) |
|
|
|
def load_pretrained(self, path): |
|
|
|
with open(os.path.join(path, 'config.pkl'), 'rb') as f: |
|
config = pickle.load(f) |
|
|
|
        del config['device'], config['num_genes'], config['num_perts']
        # model_initialize() rebuilds and stores the full config (including
        # device, num_genes and num_perts), so it is not overwritten here.
        self.model_initialize(**config)
        self.config = self.config
|
|
|
state_dict = torch.load(os.path.join(path, 'model.pt'), map_location = torch.device('cpu')) |
|
if next(iter(state_dict))[:7] == 'module.': |
|
|
|
from collections import OrderedDict |
|
new_state_dict = OrderedDict() |
|
for k, v in state_dict.items(): |
|
name = k[7:] |
|
new_state_dict[name] = v |
|
state_dict = new_state_dict |
|
|
|
self.model.load_state_dict(state_dict) |
|
self.model = self.model.to(self.device) |
|
self.best_model = self.model |
|
|
|
    def save_model(self, path):
        os.makedirs(path, exist_ok=True)

        if self.config is None:
            raise ValueError('No model is initialized. Run model_initialize() '
                             'or load_pretrained() first.')
|
|
|
with open(os.path.join(path, 'config.pkl'), 'wb') as f: |
|
pickle.dump(self.config, f) |
|
|
|
torch.save(self.best_model.state_dict(), os.path.join(path, 'model.pt')) |
|
|
|
|
|
def train(self, epochs = 20, |
|
lr = 1e-3, |
|
weight_decay = 5e-4 |
|
): |
|
""" |
|
Train the model |
|
|
|
Parameters |
|
---------- |
|
epochs: int |
|
number of epochs to train |
|
lr: float |
|
learning rate |
|
weight_decay: float |
|
weight decay |
|
|
|
Returns |
|
------- |
|
None |
|
|
|
""" |
|
|
|
train_loader = self.dataloader['train_loader'] |
|
val_loader = self.dataloader['val_loader'] |
|
|
|
self.model = self.model.to(self.device) |
|
best_model = deepcopy(self.model) |
|
optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay = weight_decay) |
|
scheduler = StepLR(optimizer, step_size=1, gamma=0.5) |
|
|
|
min_val = np.inf |
|
print_sys('Start Training...') |
|
|
|
for epoch in range(epochs): |
|
self.model.train() |
|
|
|
for step, batch in enumerate(train_loader): |
|
batch.to(self.device) |
|
optimizer.zero_grad() |
|
y = batch.y |
|
if self.config['uncertainty']: |
|
pred, logvar = self.model(batch) |
|
loss = uncertainty_loss_fct(pred, logvar, y, batch.pert, |
|
reg = self.config['uncertainty_reg'], |
|
ctrl = self.ctrl_expression, |
|
dict_filter = self.dict_filter, |
|
direction_lambda = self.config['direction_lambda']) |
|
else: |
|
pred = self.model(batch) |
|
loss = loss_fct(pred, y, batch.pert, |
|
ctrl = self.ctrl_expression, |
|
dict_filter = self.dict_filter, |
|
direction_lambda = self.config['direction_lambda']) |
|
loss.backward() |
|
nn.utils.clip_grad_value_(self.model.parameters(), clip_value=1.0) |
|
optimizer.step() |
|
|
|
if self.wandb: |
|
self.wandb.log({'training_loss': loss.item()}) |
|
|
|
if step % 50 == 0: |
|
log = "Epoch {} Step {} Train Loss: {:.4f}" |
|
print_sys(log.format(epoch + 1, step + 1, loss.item())) |
|
|
|
scheduler.step() |
|
|
|
train_res = evaluate(train_loader, self.model, |
|
self.config['uncertainty'], self.device) |
|
val_res = evaluate(val_loader, self.model, |
|
self.config['uncertainty'], self.device) |
|
train_metrics, _ = compute_metrics(train_res) |
|
val_metrics, _ = compute_metrics(val_res) |
|
|
|
|
|
log = "Epoch {}: Train Overall MSE: {:.4f} " \ |
|
"Validation Overall MSE: {:.4f}. " |
|
print_sys(log.format(epoch + 1, train_metrics['mse'], |
|
val_metrics['mse'])) |
|
|
|
|
|
log = "Train Top 20 DE MSE: {:.4f} " \ |
|
"Validation Top 20 DE MSE: {:.4f}. " |
|
print_sys(log.format(train_metrics['mse_de'], |
|
val_metrics['mse_de'])) |
|
|
|
if self.wandb: |
|
metrics = ['mse', 'pearson'] |
|
for m in metrics: |
|
self.wandb.log({'train_' + m: train_metrics[m], |
|
'val_'+m: val_metrics[m], |
|
'train_de_' + m: train_metrics[m + '_de'], |
|
'val_de_'+m: val_metrics[m + '_de']}) |
|
|
|
if val_metrics['mse_de'] < min_val: |
|
min_val = val_metrics['mse_de'] |
|
best_model = deepcopy(self.model) |
|
|
|
print_sys("Done!") |
|
self.best_model = best_model |
|
|
|
if 'test_loader' not in self.dataloader: |
|
print_sys('Done! No test dataloader detected.') |
|
return |
|
|
|
|
|
test_loader = self.dataloader['test_loader'] |
|
print_sys("Start Testing...") |
|
test_res = evaluate(test_loader, self.best_model, |
|
self.config['uncertainty'], self.device) |
|
test_metrics, test_pert_res = compute_metrics(test_res) |
|
log = "Best performing model: Test Top 20 DE MSE: {:.4f}" |
|
print_sys(log.format(test_metrics['mse_de'])) |
|
|
|
if self.wandb: |
|
metrics = ['mse', 'pearson'] |
|
for m in metrics: |
|
self.wandb.log({'test_' + m: test_metrics[m], |
|
'test_de_'+m: test_metrics[m + '_de'] |
|
}) |
|
|
|
print_sys('Done!') |
|
self.test_metrics = test_metrics |
|
|
|
def np_pearson_cor(x, y): |
|
xv = x - x.mean(axis=0) |
|
yv = y - y.mean(axis=0) |
|
xvss = (xv * xv).sum(axis=0) |
|
yvss = (yv * yv).sum(axis=0) |
|
result = np.matmul(xv.transpose(), yv) / np.sqrt(np.outer(xvss, yvss)) |
|
|
|
return np.maximum(np.minimum(result, 1.0), -1.0) |
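
# Illustrative note (assumed shapes): for a dense matrix X of shape
# (cells, genes), np_pearson_cor(X, X) returns a (genes, genes) matrix of
# pairwise Pearson correlations between gene columns, clipped to [-1, 1].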
|
|
|
|
|
class GeneSimNetwork(): |
|
""" |
|
GeneSimNetwork class |
|
|
|
Args: |
|
edge_list (pd.DataFrame): edge list of the network |
|
gene_list (list): list of gene names |
|
node_map (dict): dictionary mapping gene names to node indices |
|
|
|
Attributes: |
|
edge_index (torch.Tensor): edge index of the network |
|
edge_weight (torch.Tensor): edge weight of the network |
|
G (nx.DiGraph): networkx graph object |
|
""" |
|
def __init__(self, edge_list, gene_list, node_map): |
|
""" |
|
Initialize GeneSimNetwork class |
|
""" |
|
|
|
self.edge_list = edge_list |
|
self.G = nx.from_pandas_edgelist(self.edge_list, source='source', |
|
target='target', edge_attr=['importance'], |
|
create_using=nx.DiGraph()) |
|
self.gene_list = gene_list |
|
for n in self.gene_list: |
|
if n not in self.G.nodes(): |
|
self.G.add_node(n) |
|
|
|
edge_index_ = [(node_map[e[0]], node_map[e[1]]) for e in |
|
self.G.edges] |
|
self.edge_index = torch.tensor(edge_index_, dtype=torch.long).T |
|
|
|
|
|
edge_attr = nx.get_edge_attributes(self.G, 'importance') |
|
importance = np.array([edge_attr[e] for e in self.G.edges]) |
|
self.edge_weight = torch.Tensor(importance) |
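
    # Expected edge_list format (illustrative): a pandas DataFrame with columns
    # ['source', 'target', 'importance'], one row per edge; node_map maps each
    # gene or perturbation name to its integer node index in the graph.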
|
|
|
def get_GO_edge_list(args): |
|
""" |
|
Get gene ontology edge list |
|
""" |
|
g1, gene2go = args |
|
edge_list = [] |
|
for g2 in gene2go.keys(): |
|
score = len(gene2go[g1].intersection(gene2go[g2])) / len( |
|
gene2go[g1].union(gene2go[g2])) |
|
if score > 0.1: |
|
edge_list.append((g1, g2, score)) |
|
return edge_list |
|
|
|
def make_GO(data_path, pert_list, data_name, num_workers=25, save=True): |
|
""" |
|
Creates Gene Ontology graph from a custom set of genes |
|
""" |
|
|
|
    fname = os.path.join(data_path, 'go_essential_' + data_name + '.csv')
|
if os.path.exists(fname): |
|
return pd.read_csv(fname) |
|
|
|
with open(os.path.join(data_path, 'gene2go_all.pkl'), 'rb') as f: |
|
gene2go = pickle.load(f) |
|
gene2go = {i: gene2go[i] for i in pert_list} |
|
|
|
print('Creating custom GO graph, this can take a few minutes') |
|
with Pool(num_workers) as p: |
|
all_edge_list = list( |
|
tqdm(p.imap(get_GO_edge_list, ((g, gene2go) for g in gene2go.keys())), |
|
total=len(gene2go.keys()))) |
|
edge_list = [] |
|
for i in all_edge_list: |
|
edge_list = edge_list + i |
|
|
|
df_edge_list = pd.DataFrame(edge_list).rename( |
|
columns={0: 'source', 1: 'target', 2: 'importance'}) |
|
|
|
if save: |
|
print('Saving edge_list to file') |
|
df_edge_list.to_csv(fname, index=False) |
|
|
|
return df_edge_list |
|
|
|
def get_similarity_network(network_type, adata, threshold, k, |
|
data_path, data_name, split, seed, train_gene_set_size, |
|
set2conditions, default_pert_graph=True, pert_list=None): |
|
|
|
if network_type == 'co-express': |
|
df_out = get_coexpression_network_from_train(adata, threshold, k, |
|
data_path, data_name, split, |
|
seed, train_gene_set_size, |
|
set2conditions) |
|
elif network_type == 'go': |
|
        if default_pert_graph:
            # The precomputed GO-similarity table is read from a local copy under
            # data_path; the Dataverse URL is kept for reference only and no
            # download is performed here.
            server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934319'
            df_jaccard = pd.read_csv(os.path.join(data_path,
                                     'go_essential_all/go_essential_all.csv'))
|
|
|
else: |
|
df_jaccard = make_GO(data_path, pert_list, data_name) |
|
|
|
df_out = df_jaccard.groupby('target').apply(lambda x: x.nlargest(k + 1, |
|
['importance'])).reset_index(drop = True) |
|
|
|
return df_out |
|
|
|
def get_coexpression_network_from_train(adata, threshold, k, data_path, |
|
data_name, split, seed, train_gene_set_size, |
|
set2conditions): |
|
""" |
|
Infer co-expression network from training data |
|
|
|
Args: |
|
adata (anndata.AnnData): anndata object |
|
threshold (float): threshold for co-expression |
|
k (int): number of edges to keep |
|
data_path (str): path to data |
|
data_name (str): name of dataset |
|
split (str): split of dataset |
|
seed (int): seed for random number generator |
|
train_gene_set_size (int): size of training gene set |
|
set2conditions (dict): dictionary of perturbations to conditions |
|
""" |
|
|
|
fname = os.path.join(os.path.join(data_path, data_name), split + '_' + |
|
str(seed) + '_' + str(train_gene_set_size) + '_' + |
|
str(threshold) + '_' + str(k) + |
|
'_co_expression_network.csv') |
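    # e.g. <data_path>/norman/simulation_1_0.75_0.4_20_co_expression_network.csv
    # (illustrative values for split, seed, train_gene_set_size, threshold, k)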
|
|
|
if os.path.exists(fname): |
|
return pd.read_csv(fname) |
|
else: |
|
gene_list = [f for f in adata.var.gene_name.values] |
|
idx2gene = dict(zip(range(len(gene_list)), gene_list)) |
|
X = adata.X |
|
train_perts = set2conditions['train'] |
|
X_tr = X[np.isin(adata.obs.condition, [i for i in train_perts if 'ctrl' in i])] |
|
gene_list = adata.var['gene_name'].values |
|
|
|
X_tr = X_tr.toarray() |
|
out = np_pearson_cor(X_tr, X_tr) |
|
out[np.isnan(out)] = 0 |
|
out = np.abs(out) |
|
|
|
out_sort_idx = np.argsort(out)[:, -(k + 1):] |
|
out_sort_val = np.sort(out)[:, -(k + 1):] |
|
|
|
df_g = [] |
|
for i in range(out_sort_idx.shape[0]): |
|
target = idx2gene[i] |
|
for j in range(out_sort_idx.shape[1]): |
|
df_g.append((idx2gene[out_sort_idx[i, j]], target, out_sort_val[i, j])) |
|
|
|
df_g = [i for i in df_g if i[2] > threshold] |
|
df_co_expression = pd.DataFrame(df_g).rename(columns = {0: 'source', |
|
1: 'target', |
|
2: 'importance'}) |
|
df_co_expression.to_csv(fname, index = False) |
|
return df_co_expression |
|
|
|
def uncertainty_loss_fct(pred, logvar, y, perts, reg = 0.1, ctrl = None, |
|
direction_lambda = 1e-3, dict_filter = None): |
|
""" |
|
Uncertainty loss function |
|
|
|
Args: |
|
pred (torch.tensor): predicted values |
|
logvar (torch.tensor): log variance |
|
y (torch.tensor): true values |
|
perts (list): list of perturbations |
|
reg (float): regularization parameter |
|
ctrl (str): control perturbation |
|
direction_lambda (float): direction loss weight hyperparameter |
|
dict_filter (dict): dictionary of perturbations to conditions |
|
|
|
""" |
|
gamma = 2 |
|
perts = np.array(perts) |
|
losses = torch.tensor(0.0, requires_grad=True).to(pred.device) |
|
for p in set(perts): |
|
if p!= 'ctrl': |
|
retain_idx = dict_filter[p] |
|
pred_p = pred[np.where(perts==p)[0]][:, retain_idx] |
|
y_p = y[np.where(perts==p)[0]][:, retain_idx] |
|
logvar_p = logvar[np.where(perts==p)[0]][:, retain_idx] |
|
else: |
|
pred_p = pred[np.where(perts==p)[0]] |
|
y_p = y[np.where(perts==p)[0]] |
|
logvar_p = logvar[np.where(perts==p)[0]] |
|
|
|
|
|
losses += torch.sum((pred_p - y_p)**(2 + gamma) + reg * torch.exp( |
|
-logvar_p) * (pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1] |
|
|
|
|
|
if p!= 'ctrl': |
|
losses += torch.sum(direction_lambda * |
|
(torch.sign(y_p - ctrl[retain_idx]) - |
|
torch.sign(pred_p - ctrl[retain_idx]))**2)/\ |
|
pred_p.shape[0]/pred_p.shape[1] |
|
else: |
|
losses += torch.sum(direction_lambda * |
|
(torch.sign(y_p - ctrl) - |
|
torch.sign(pred_p - ctrl))**2)/\ |
|
pred_p.shape[0]/pred_p.shape[1] |
|
|
|
return losses/(len(set(perts))) |
|
|
|
|
|
def loss_fct(pred, y, perts, ctrl = None, direction_lambda = 1e-3, dict_filter = None): |
|
""" |
|
Main MSE Loss function, includes direction loss |
|
|
|
Args: |
|
pred (torch.tensor): predicted values |
|
y (torch.tensor): true values |
|
perts (list): list of perturbations |
|
ctrl (str): control perturbation |
|
direction_lambda (float): direction loss weight hyperparameter |
|
dict_filter (dict): dictionary of perturbations to conditions |
|
|
|
""" |
|
gamma = 2 |
|
mse_p = torch.nn.MSELoss() |
|
perts = np.array(perts) |
|
losses = torch.tensor(0.0, requires_grad=True).to(pred.device) |
|
|
|
for p in set(perts): |
|
pert_idx = np.where(perts == p)[0] |
|
|
|
|
|
|
|
if p!= 'ctrl': |
|
retain_idx = dict_filter[p] |
|
pred_p = pred[pert_idx][:, retain_idx] |
|
y_p = y[pert_idx][:, retain_idx] |
|
else: |
|
pred_p = pred[pert_idx] |
|
y_p = y[pert_idx] |
|
losses = losses + torch.sum((pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1] |
|
|
|
|
|
if (p!= 'ctrl'): |
|
losses = losses + torch.sum(direction_lambda * |
|
(torch.sign(y_p - ctrl[retain_idx]) - |
|
torch.sign(pred_p - ctrl[retain_idx]))**2)/\ |
|
pred_p.shape[0]/pred_p.shape[1] |
|
else: |
|
losses = losses + torch.sum(direction_lambda * (torch.sign(y_p - ctrl) - |
|
torch.sign(pred_p - ctrl))**2)/\ |
|
pred_p.shape[0]/pred_p.shape[1] |
|
return losses/(len(set(perts))) |
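
# Minimal usage sketch for loss_fct (toy, illustrative values only):
#
#   pred  = torch.zeros(2, 3)
#   y     = torch.ones(2, 3)
#   ctrl  = torch.zeros(3)
#   perts = ['KLF1+ctrl', 'KLF1+ctrl']
#   loss_fct(pred, y, perts, ctrl=ctrl,
#            dict_filter={'KLF1+ctrl': [0, 1, 2]})
#
# This returns the autofocus term mean((pred - y)**(2 + gamma)) plus the
# sign-mismatch direction term weighted by direction_lambda, averaged over
# the distinct perturbations in the batch.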
|
def evaluate(loader, model, uncertainty, device): |
|
""" |
|
Run model in inference mode using a given data loader |
|
""" |
|
|
|
model.eval() |
|
model.to(device) |
|
pert_cat = [] |
|
pred = [] |
|
truth = [] |
|
pred_de = [] |
|
truth_de = [] |
|
results = {} |
|
logvar = [] |
|
|
|
for itr, batch in enumerate(loader): |
|
|
|
batch.to(device) |
|
pert_cat.extend(batch.pert) |
|
|
|
with torch.no_grad(): |
|
if uncertainty: |
|
p, unc = model(batch) |
|
logvar.extend(unc.cpu()) |
|
else: |
|
p = model(batch) |
|
t = batch.y |
|
pred.extend(p.cpu()) |
|
truth.extend(t.cpu()) |
|
|
|
|
|
for itr, de_idx in enumerate(batch.de_idx): |
|
pred_de.append(p[itr, de_idx]) |
|
truth_de.append(t[itr, de_idx]) |
|
|
|
|
|
results['pert_cat'] = np.array(pert_cat) |
|
pred = torch.stack(pred) |
|
truth = torch.stack(truth) |
|
results['pred']= pred.detach().cpu().numpy() |
|
results['truth']= truth.detach().cpu().numpy() |
|
|
|
pred_de = torch.stack(pred_de) |
|
truth_de = torch.stack(truth_de) |
|
results['pred_de']= pred_de.detach().cpu().numpy() |
|
results['truth_de']= truth_de.detach().cpu().numpy() |
|
|
|
if uncertainty: |
|
results['logvar'] = torch.stack(logvar).detach().cpu().numpy() |
|
|
|
return results |
|
|
|
|
|
def compute_metrics(results): |
|
""" |
|
Given results from a model run and the ground truth, compute metrics |
|
|
|
""" |
|
metrics = {} |
|
metrics_pert = {} |
|
|
|
metric2fct = { |
|
'mse': mse, |
|
'pearson': pearsonr |
|
} |
|
|
|
for m in metric2fct.keys(): |
|
metrics[m] = [] |
|
metrics[m + '_de'] = [] |
|
|
|
for pert in np.unique(results['pert_cat']): |
|
|
|
metrics_pert[pert] = {} |
|
p_idx = np.where(results['pert_cat'] == pert)[0] |
|
|
|
for m, fct in metric2fct.items(): |
|
if m == 'pearson': |
|
val = fct(results['pred'][p_idx].mean(0), results['truth'][p_idx].mean(0))[0] |
|
if np.isnan(val): |
|
val = 0 |
|
else: |
|
val = fct(results['pred'][p_idx].mean(0), results['truth'][p_idx].mean(0)) |
|
|
|
metrics_pert[pert][m] = val |
|
metrics[m].append(metrics_pert[pert][m]) |
|
|
|
|
|
if pert != 'ctrl': |
|
|
|
for m, fct in metric2fct.items(): |
|
if m == 'pearson': |
|
val = fct(results['pred_de'][p_idx].mean(0), results['truth_de'][p_idx].mean(0))[0] |
|
if np.isnan(val): |
|
val = 0 |
|
else: |
|
val = fct(results['pred_de'][p_idx].mean(0), results['truth_de'][p_idx].mean(0)) |
|
|
|
metrics_pert[pert][m + '_de'] = val |
|
metrics[m + '_de'].append(metrics_pert[pert][m + '_de']) |
|
|
|
else: |
|
for m, fct in metric2fct.items(): |
|
metrics_pert[pert][m + '_de'] = 0 |
|
|
|
for m in metric2fct.keys(): |
|
|
|
metrics[m] = np.mean(metrics[m]) |
|
metrics[m + '_de'] = np.mean(metrics[m + '_de']) |
|
|
|
return metrics, metrics_pert |
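
# The returned ``metrics`` dict is flat with averaged values, e.g.
# (illustrative numbers): {'mse': 0.012, 'pearson': 0.91,
# 'mse_de': 0.18, 'pearson_de': 0.84}; ``metrics_pert`` holds the same
# metrics keyed by perturbation name.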
|
|
|
def filter_pert_in_go(condition, pert_names): |
|
""" |
|
Filter perturbations in GO graph |
|
|
|
    Args:
        condition (str): perturbation condition string, e.g. 'ctrl', 'A+ctrl' or 'A+B'
        pert_names (list): perturbation names present in the GO graph
    """
|
|
|
if condition == 'ctrl': |
|
return True |
|
else: |
|
cond1 = condition.split('+')[0] |
|
cond2 = condition.split('+')[1] |
|
num_ctrl = (cond1 == 'ctrl') + (cond2 == 'ctrl') |
|
num_in_perts = (cond1 in pert_names) + (cond2 in pert_names) |
|
if num_ctrl + num_in_perts == 2: |
|
return True |
|
else: |
|
return False |
|
|
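
# NOTE: ``DataSplitter`` and ``get_genes_from_perts`` are referenced below but
# are not defined in this file; they are assumed to be provided elsewhere
# (e.g. by the GEARS package utilities). Below is a minimal sketch of
# ``get_genes_from_perts``, under the assumption that conditions are
# '+'-joined gene symbols with 'ctrl' as a placeholder:

def get_genes_from_perts(perts):
    """Collect the unique gene symbols appearing in a list of conditions."""
    if isinstance(perts, str):
        perts = [perts]
    genes = [g for p in np.unique(perts) for g in p.split('+') if g != 'ctrl']
    return list(np.unique(genes))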
|
class PertData: |
|
def __init__(self, data_path, |
|
gene_set_path=None, |
|
default_pert_graph=True): |
|
|
|
|
|
self.data_path = data_path |
|
self.default_pert_graph = default_pert_graph |
|
self.gene_set_path = gene_set_path |
|
self.dataset_name = None |
|
self.dataset_path = None |
|
self.adata = None |
|
self.dataset_processed = None |
|
self.ctrl_adata = None |
|
self.gene_names = [] |
|
self.node_map = {} |
|
|
|
|
|
self.split = None |
|
self.seed = None |
|
self.subgroup = None |
|
self.train_gene_set_size = None |
|
|
|
        if not os.path.exists(self.data_path):
            os.mkdir(self.data_path)
        # gene2go_all.pkl is expected to already be present under data_path;
        # the Dataverse URL is kept for reference only and no download is
        # performed here.
        server_path = 'https://dataverse.harvard.edu/api/access/datafile/6153417'
        with open(os.path.join(self.data_path, 'gene2go_all.pkl'), 'rb') as f:
            self.gene2go = pickle.load(f)
|
|
|
def set_pert_genes(self): |
|
""" |
|
Set the list of genes that can be perturbed and are to be included in |
|
perturbation graph |
|
""" |
|
|
|
if self.gene_set_path is not None: |
|
|
|
path_ = self.gene_set_path |
|
self.default_pert_graph = False |
|
with open(path_, 'rb') as f: |
|
essential_genes = pickle.load(f) |
|
|
|
elif self.default_pert_graph is False: |
|
|
|
all_pert_genes = get_genes_from_perts(self.adata.obs['condition']) |
|
essential_genes = list(self.adata.var['gene_name'].values) |
|
essential_genes += all_pert_genes |
|
|
|
        else:
            # Default perturbation graph: the essential-gene set is read from a
            # local pickle; the Dataverse URL is kept for reference only.
            server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934320'
            path_ = os.path.join(self.data_path,
                                 'essential_all_data_pert_genes.pkl')
            with open(path_, 'rb') as f:
                essential_genes = pickle.load(f)
|
|
|
gene2go = {i: self.gene2go[i] for i in essential_genes if i in self.gene2go} |
|
|
|
self.pert_names = np.unique(list(gene2go.keys())) |
|
self.node_map_pert = {x: it for it, x in enumerate(self.pert_names)} |
|
|
|
def load(self, data_name = None, data_path = None): |
|
if data_name in ['norman', 'adamson', 'dixit', |
|
'replogle_k562_essential', |
|
'replogle_rpe1_essential']: |
|
data_path = os.path.join(self.data_path, data_name) |
|
|
|
self.dataset_name = data_path.split('/')[-1] |
|
self.dataset_path = data_path |
|
adata_path = os.path.join(data_path, 'perturb_processed.h5ad') |
|
self.adata = sc.read_h5ad(adata_path) |
|
|
|
elif os.path.exists(data_path): |
|
adata_path = os.path.join(data_path, 'perturb_processed.h5ad') |
|
self.adata = sc.read_h5ad(adata_path) |
|
self.dataset_name = data_path.split('/')[-1] |
|
self.dataset_path = data_path |
|
        else:
            raise ValueError("data_name must be one of 'norman', 'adamson', 'dixit', "
                             "'replogle_k562_essential', 'replogle_rpe1_essential', "
                             "or data_path must point to a directory containing "
                             "a perturb_processed.h5ad file")
|
|
|
self.set_pert_genes() |
|
        print_sys('These perturbations are not in the GO graph and their '
                  'effects can therefore not be predicted')
|
not_in_go_pert = np.array(self.adata.obs[ |
|
self.adata.obs.condition.apply( |
|
lambda x:not filter_pert_in_go(x, |
|
self.pert_names))].condition.unique()) |
|
print_sys(not_in_go_pert) |
|
|
|
filter_go = self.adata.obs[self.adata.obs.condition.apply( |
|
lambda x: filter_pert_in_go(x, self.pert_names))] |
|
self.adata = self.adata[filter_go.index.values, :] |
|
pyg_path = os.path.join(data_path, 'data_pyg') |
|
if not os.path.exists(pyg_path): |
|
os.mkdir(pyg_path) |
|
dataset_fname = os.path.join(pyg_path, 'cell_graphs.pkl') |
|
|
|
if os.path.isfile(dataset_fname): |
|
print_sys("Local copy of pyg dataset is detected. Loading...") |
|
self.dataset_processed = pickle.load(open(dataset_fname, "rb")) |
|
print_sys("Done!") |
|
else: |
|
self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl'] |
|
self.gene_names = self.adata.var.gene_name |
|
|
|
|
|
print_sys("Creating pyg object for each cell in the data...") |
|
self.create_dataset_file() |
|
print_sys("Saving new dataset pyg object at " + dataset_fname) |
|
pickle.dump(self.dataset_processed, open(dataset_fname, "wb")) |
|
print_sys("Done!") |
|
|
|
|
|
def prepare_split(self, split = 'simulation', |
|
seed = 1, |
|
train_gene_set_size = 0.75, |
|
combo_seen2_train_frac = 0.75, |
|
combo_single_split_test_set_fraction = 0.1, |
|
test_perts = None, |
|
only_test_set_perts = False, |
|
test_pert_genes = None, |
|
split_dict_path=None): |
|
|
|
""" |
|
Prepare splits for training and testing |
|
|
|
Parameters |
|
---------- |
|
split: str |
|
Type of split to use. Currently, we support 'simulation', |
|
'simulation_single', 'combo_seen0', 'combo_seen1', 'combo_seen2', |
|
'single', 'no_test', 'no_split', 'custom' |
|
seed: int |
|
Random seed |
|
train_gene_set_size: float |
|
Fraction of genes to use for training |
|
combo_seen2_train_frac: float |
|
Fraction of combo seen2 perturbations to use for training |
|
combo_single_split_test_set_fraction: float |
|
Fraction of combo single perturbations to use for testing |
|
test_perts: list |
|
List of perturbations to use for testing |
|
only_test_set_perts: bool |
|
If True, only use test set perturbations for testing |
|
test_pert_genes: list |
|
List of genes to use for testing |
|
split_dict_path: str |
|
Path to dictionary used for custom split. Sample format: |
|
{'train': [X, Y], 'val': [P, Q], 'test': [Z]} |
|
|
|
Returns |
|
------- |
|
None |
|
|
|
""" |
|
available_splits = ['simulation', 'simulation_single', 'combo_seen0', |
|
'combo_seen1', 'combo_seen2', 'single', 'no_test', |
|
'no_split', 'custom'] |
|
if split not in available_splits: |
|
raise ValueError('currently, we only support ' + ','.join(available_splits)) |
|
self.split = split |
|
self.seed = seed |
|
self.subgroup = None |
|
|
|
        if split == 'custom':
            if split_dict_path is None:
                raise ValueError('Please set split_dict_path for custom split')
            with open(split_dict_path, 'rb') as f:
                self.set2conditions = pickle.load(f)
            return
|
|
|
self.train_gene_set_size = train_gene_set_size |
|
split_folder = os.path.join(self.dataset_path, 'splits') |
|
if not os.path.exists(split_folder): |
|
os.mkdir(split_folder) |
|
split_file = self.dataset_name + '_' + split + '_' + str(seed) + '_' \ |
|
+ str(train_gene_set_size) + '.pkl' |
|
split_path = os.path.join(split_folder, split_file) |
|
|
|
if test_perts: |
|
split_path = split_path[:-4] + '_' + test_perts + '.pkl' |
|
|
|
        if os.path.exists(split_path):
            print_sys("Local copy of split is detected. Loading...")
|
set2conditions = pickle.load(open(split_path, "rb")) |
|
if split == 'simulation': |
|
subgroup_path = split_path[:-4] + '_subgroup.pkl' |
|
subgroup = pickle.load(open(subgroup_path, "rb")) |
|
self.subgroup = subgroup |
|
else: |
|
print_sys("Creating new splits....") |
|
if test_perts: |
|
test_perts = test_perts.split('_') |
|
|
|
if split in ['simulation', 'simulation_single']: |
|
|
|
DS = DataSplitter(self.adata, split_type=split) |
|
|
|
adata, subgroup = DS.split_data(train_gene_set_size = train_gene_set_size, |
|
combo_seen2_train_frac = combo_seen2_train_frac, |
|
seed=seed, |
|
test_perts = test_perts, |
|
only_test_set_perts = only_test_set_perts |
|
) |
|
subgroup_path = split_path[:-4] + '_subgroup.pkl' |
|
pickle.dump(subgroup, open(subgroup_path, "wb")) |
|
self.subgroup = subgroup |
|
|
|
elif split[:5] == 'combo': |
|
|
|
split_type = 'combo' |
|
seen = int(split[-1]) |
|
|
|
if test_pert_genes: |
|
test_pert_genes = test_pert_genes.split('_') |
|
|
|
DS = DataSplitter(self.adata, split_type=split_type, seen=int(seen)) |
|
adata = DS.split_data(test_size=combo_single_split_test_set_fraction, |
|
test_perts=test_perts, |
|
test_pert_genes=test_pert_genes, |
|
seed=seed) |
|
|
|
elif split == 'single': |
|
|
|
DS = DataSplitter(self.adata, split_type=split) |
|
adata = DS.split_data(test_size=combo_single_split_test_set_fraction, |
|
seed=seed) |
|
|
|
elif split == 'no_test': |
|
|
|
DS = DataSplitter(self.adata, split_type=split) |
|
adata = DS.split_data(seed=seed) |
|
|
|
elif split == 'no_split': |
|
|
|
adata = self.adata |
|
adata.obs['split'] = 'test' |
|
|
|
set2conditions = dict(adata.obs.groupby('split').agg({'condition': |
|
lambda x: x}).condition) |
|
set2conditions = {i: j.unique().tolist() for i,j in set2conditions.items()} |
|
pickle.dump(set2conditions, open(split_path, "wb")) |
|
print_sys("Saving new splits at " + split_path) |
|
|
|
self.set2conditions = set2conditions |
|
|
|
if split == 'simulation': |
|
print_sys('Simulation split test composition:') |
|
for i,j in subgroup['test_subgroup'].items(): |
|
print_sys(i + ':' + str(len(j))) |
|
print_sys("Done!") |
|
|
|
def get_dataloader(self, batch_size, test_batch_size = None): |
|
""" |
|
Get dataloaders for training and testing |
|
|
|
Parameters |
|
---------- |
|
batch_size: int |
|
Batch size for training |
|
test_batch_size: int |
|
Batch size for testing |
|
|
|
Returns |
|
------- |
|
dict |
|
Dictionary of dataloaders |
|
|
|
""" |
|
if test_batch_size is None: |
|
test_batch_size = batch_size |
|
|
|
self.node_map = {x: it for it, x in enumerate(self.adata.var.gene_name)} |
|
self.gene_names = self.adata.var.gene_name |
|
|
|
|
|
cell_graphs = {} |
|
if self.split == 'no_split': |
|
i = 'test' |
|
cell_graphs[i] = [] |
|
for p in self.set2conditions[i]: |
|
if p != 'ctrl': |
|
cell_graphs[i].extend(self.dataset_processed[p]) |
|
|
|
print_sys("Creating dataloaders....") |
|
|
|
            test_loader = DataLoader(cell_graphs['test'],
                                     batch_size=test_batch_size, shuffle=False)
|
|
|
print_sys("Dataloaders created...") |
|
return {'test_loader': test_loader} |
|
else: |
|
if self.split =='no_test': |
|
splits = ['train','val'] |
|
else: |
|
splits = ['train','val','test'] |
|
for i in splits: |
|
cell_graphs[i] = [] |
|
for p in self.set2conditions[i]: |
|
cell_graphs[i].extend(self.dataset_processed[p]) |
|
|
|
print_sys("Creating dataloaders....") |
|
|
|
|
|
train_loader = DataLoader(cell_graphs['train'], |
|
batch_size=batch_size, shuffle=True, drop_last = True) |
|
val_loader = DataLoader(cell_graphs['val'], |
|
batch_size=batch_size, shuffle=True) |
|
|
|
if self.split !='no_test': |
|
                test_loader = DataLoader(cell_graphs['test'],
                                         batch_size=test_batch_size, shuffle=False)
|
self.dataloader = {'train_loader': train_loader, |
|
'val_loader': val_loader, |
|
'test_loader': test_loader} |
|
|
|
else: |
|
self.dataloader = {'train_loader': train_loader, |
|
'val_loader': val_loader} |
|
print_sys("Done!") |
|
|
|
def get_pert_idx(self, pert_category): |
|
""" |
|
Get perturbation index for a given perturbation category |
|
|
|
Parameters |
|
---------- |
|
pert_category: str |
|
Perturbation category |
|
|
|
Returns |
|
------- |
|
list |
|
List of perturbation indices |
|
|
|
""" |
|
try: |
|
pert_idx = [np.where(p == self.pert_names)[0][0] |
|
for p in pert_category.split('+') |
|
if p != 'ctrl'] |
|
        except IndexError:
            print_sys('Perturbation not found in pert_names: ' + pert_category)
            pert_idx = None
|
|
|
return pert_idx |
|
|
|
def create_cell_graph(self, X, y, de_idx, pert, pert_idx=None): |
|
""" |
|
Create a cell graph from a given cell |
|
|
|
Parameters |
|
---------- |
|
X: np.ndarray |
|
Gene expression matrix |
|
y: np.ndarray |
|
Label vector |
|
de_idx: np.ndarray |
|
DE gene indices |
|
pert: str |
|
Perturbation category |
|
pert_idx: list |
|
List of perturbation indices |
|
|
|
Returns |
|
------- |
|
torch_geometric.data.Data |
|
Cell graph to be used in dataloader |
|
|
|
""" |
|
|
|
feature_mat = torch.Tensor(X).T |
|
if pert_idx is None: |
|
pert_idx = [-1] |
|
return Data(x=feature_mat, pert_idx=pert_idx, |
|
y=torch.Tensor(y), de_idx=de_idx, pert=pert) |
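
    # The resulting Data object stores, per cell: x (control-cell expression
    # reshaped to (num_genes, 1)), y (the perturbed cell's expression),
    # de_idx (indices of top differentially expressed genes), pert (condition
    # name) and pert_idx (perturbation indices, or [-1] for control).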
|
|
|
def create_cell_graph_dataset(self, split_adata, pert_category, |
|
num_samples=1): |
|
""" |
|
Combine cell graphs to create a dataset of cell graphs |
|
|
|
Parameters |
|
---------- |
|
split_adata: anndata.AnnData |
|
Annotated data matrix |
|
pert_category: str |
|
Perturbation category |
|
num_samples: int |
|
Number of samples to create per perturbed cell (i.e. number of |
|
control cells to map to each perturbed cell) |
|
|
|
Returns |
|
------- |
|
list |
|
List of cell graphs |
|
|
|
""" |
|
|
|
num_de_genes = 20 |
|
adata_ = split_adata[split_adata.obs['condition'] == pert_category] |
|
if 'rank_genes_groups_cov_all' in adata_.uns: |
|
de_genes = adata_.uns['rank_genes_groups_cov_all'] |
|
de = True |
|
else: |
|
de = False |
|
num_de_genes = 1 |
|
Xs = [] |
|
ys = [] |
|
|
|
|
|
if pert_category != 'ctrl': |
|
|
|
pert_idx = self.get_pert_idx(pert_category) |
|
|
|
|
|
pert_de_category = adata_.obs['condition_name'][0] |
|
if de: |
|
de_idx = np.where(adata_.var_names.isin( |
|
np.array(de_genes[pert_de_category][:num_de_genes])))[0] |
|
else: |
|
de_idx = [-1] * num_de_genes |
|
for cell_z in adata_.X: |
|
|
|
ctrl_samples = self.ctrl_adata[np.random.randint(0, |
|
len(self.ctrl_adata), num_samples), :] |
|
for c in ctrl_samples.X: |
|
Xs.append(c) |
|
ys.append(cell_z) |
|
|
|
|
|
else: |
|
pert_idx = None |
|
de_idx = [-1] * num_de_genes |
|
for cell_z in adata_.X: |
|
Xs.append(cell_z) |
|
ys.append(cell_z) |
|
|
|
|
|
cell_graphs = [] |
|
for X, y in zip(Xs, ys): |
|
cell_graphs.append(self.create_cell_graph(X.toarray(), |
|
y.toarray(), de_idx, pert_category, pert_idx)) |
|
|
|
return cell_graphs |
|
|
|
def create_dataset_file(self): |
|
""" |
|
Create dataset file for each perturbation condition |
|
""" |
|
print_sys("Creating dataset file...") |
|
self.dataset_processed = {} |
|
for p in tqdm(self.adata.obs['condition'].unique()): |
|
self.dataset_processed[p] = self.create_cell_graph_dataset(self.adata, p) |
|
print_sys("Done!") |
|
|
|
|
|
def main(data_path='./data', out_dir='./saved_models', device='cuda:0', epochs=20):
|
os.makedirs(data_path, exist_ok=True) |
|
os.makedirs(out_dir, exist_ok=True) |
|
|
|
os.environ["WANDB_SILENT"] = "true" |
|
os.environ["WANDB_ERROR_REPORTING"] = "false" |
|
|
|
print_sys("=== data loading ===") |
|
pert_data = PertData(data_path) |
|
|
|
pert_data.load(data_name='norman') |
|
|
|
pert_data.prepare_split(split='simulation', seed=1) |
|
pert_data.get_dataloader(batch_size=32, test_batch_size=128) |
|
|
|
print_sys("\n=== model traing ===") |
|
gears_model = GEARS( |
|
pert_data, |
|
device=device, |
|
weight_bias_track=True, |
|
proj_name='GEARS', |
|
exp_name='gears_norman' |
|
) |
|
gears_model.model_initialize(hidden_size = 64) |
|
|
|
    gears_model.train(epochs=epochs, lr=1e-3)
|
|
|
gears_model.save_model(os.path.join(out_dir, 'norman_full_model')) |
|
print_sys(f"model saved to {out_dir}") |
|
gears_model.load_pretrained(os.path.join(out_dir, 'norman_full_model')) |
|
|
|
final_infos = { |
|
"Gears":{ |
|
"means":{ |
|
"Test Top 20 DE MSE": float(gears_model.test_metrics['mse_de'].item()) |
|
} |
|
} |
|
} |
|
|
|
with open(os.path.join(out_dir, 'final_info.json'), 'w') as f: |
|
json.dump(final_infos, f, indent=4) |
|
print_sys("final info saved.") |
|
|
|
def print_sys(s):
    """system print

    Args:
        s (str): the string to print
    """
    print(s, flush=True, file=sys.stderr)
    # Mirror the message to the run log when CLI arguments are available,
    # i.e. when this file is executed as a script.
    if 'args' in globals():
        log_path = os.path.join(args.out_dir, args.log_file)
        logging.basicConfig(
            filename=log_path,
            level=logging.INFO,
        )
        logger = logging.getLogger()
        logger.info(s)
|
|
|
|
|
if __name__ == "__main__": |
|
import argparse |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--data_path', type=str, default='./data') |
|
parser.add_argument('--out_dir', type=str, default='run_1') |
|
parser.add_argument('--device', type=str, default='cuda:0') |
|
parser.add_argument('--log_file', type=str, default="training_ds.log") |
|
parser.add_argument('--epochs', type=int, default=20) |
|
args = parser.parse_args() |
|
|
|
try: |
|
        main(
            data_path=args.data_path,
            out_dir=args.out_dir,
            device=args.device,
            epochs=args.epochs
        )
|
    except Exception as e:
        print("Original error in main process:", flush=True)
        with open(os.path.join(args.out_dir, "traceback.log"), "w") as f:
            traceback.print_exc(file=f)
        raise
|
raise |
|
|
|
|