PreMode / analysis /fig.4ab.R
gzhong's picture
Upload folder using huggingface_hub
7718235 verified
library(ggplot2)
task.dic <- list("PTEN"=c("score.1"="stability", "score.2"="enzyme.activity"),
"NUDT15"=c("score.1"="stability", "score.2"="enzyme.activity"),
"VKORC1"=c("score.1"="enzyme.activity", "score.2"="stability"),
"CCR5"=c("score.1"="stability", "score.2"="binding Ab2D7", "score.3"="binding HIV-1"),
"CXCR4"=c("score.1"="stability", "score.2"="binding CXCL12", "score.3"="binding Ab12G5"),
"SNCA"=c("score.1"="enzyme.activity", "score.2"="stability"),
"CYP2C9"=c("score.1"="enzyme.activity", "score.2"="stability"),
"GCK"=c("score.1"="enzyme.activity", "score.2"="stability"),
"ASPA"=c("score.1"="stability", "score.2"="enzyme.activity")
)
alphabet_premode <- c('L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C')
genes <- c("PTEN", "NUDT15", "CCR5", "CXCR4", 'SNCA', 'CYP2C9', 'GCK', 'ASPA')
# add baseline AUC
# esm alphabets
source('./AUROC.R')
alphabet <- c('<cls>', '<pad>', '<eos>', '<unk>',
'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C',
'X', 'B', 'U', 'Z', 'O', '.', '-',
'<null_1>', '<mask>')
result <- data.frame()
for (i in 1:length(genes)) {
for (fold in 0:4) {
# REVEL, PrimateAI, ESM AUC
test.result <- read.csv(paste0('PreMode/', genes[i], '/',
'/test.fold.', fold, '.annotated.csv'), row.names = 1)
test.result.pass <- read.csv(paste0('ESM.SLP/', genes[i], '/',
'/testing.fold.', fold, '.csv'))
task.length <- length(task.dic[[genes[i]]])
# add hsu et al results
hsu.unirep_onehot.auc <- list(R2=c())
hsu.ev_onehot.auc <- list(R2=c())
hsu.gesm_onehot.auc <- list(R2=c())
hsu.eve_onehot.auc <- list(R2=c())
for (s in 1:task.length) {
test.result.hsu <- read.csv(paste0('./Hsu.et.al.git/results/',
genes[i], '.fold.', fold, '.score.', s, '/results.csv'))
hsu.unirep_onehot.auc$R2 <- c(hsu.unirep_onehot.auc$R2, test.result.hsu$spearman[match('eunirep_ll+onehot', test.result.hsu$predictor)])
hsu.ev_onehot.auc$R2 <- c(hsu.ev_onehot.auc$R2, test.result.hsu$spearman[match('ev+onehot', test.result.hsu$predictor)])
hsu.gesm_onehot.auc$R2 <- c(hsu.gesm_onehot.auc$R2, test.result.hsu$spearman[match('gesm+onehot', test.result.hsu$predictor)])
hsu.eve_onehot.auc$R2 <- c(hsu.eve_onehot.auc$R2, test.result.hsu$spearman[match('vae+onehot', test.result.hsu$predictor)])
}
PreMode.auc <- plot.R2(test.result[,names(task.dic[[genes[i]]])], test.result[,paste0("logits.", 0:(task.length-1))], bin = grepl('bin', genes[i]))
PreMode.pass.auc <- plot.R2(test.result.pass[,names(task.dic[[genes[i]]])], test.result.pass[,paste0("logits.", 0:(task.length-1))], bin = grepl('bin', genes[i]))
to.append <- data.frame(min.val.R = c(PreMode.auc$R2,
PreMode.pass.auc$R2,
hsu.unirep_onehot.auc$R2,
hsu.ev_onehot.auc$R2,
hsu.gesm_onehot.auc$R2,
hsu.eve_onehot.auc$R2),
task.name = paste0(genes[i], ":", rep(task.dic[[genes[i]]], 6)),
HGNC=genes[i],
fold=fold,
npoints=dim(test.result)[1])
to.append$model <- rep(c("PreMode",
"ESM+SLP",
"Augmented Unirep",
"Augmented EVmutation",
"Augmented ESM1b",
"Augmented EVE"), each = task.length)
result <- rbind(result, to.append)
}
}
num.models <- length(unique(result$model))
p <- ggplot(result, aes(y=min.val.R, x=task.name, col=model)) +
geom_point(alpha=0.2) +
stat_summary(data = result,
aes(x=as.numeric(factor(task.name))+0.4*(as.numeric(factor(model)))/num.models-0.2*(num.models+1)/num.models,
y = min.val.R, col=model),
fun.data = mean_se, geom = "errorbar", width = 0.2) +
stat_summary(data = result,
aes(x=as.numeric(factor(task.name))+0.4*(as.numeric(factor(model)))/num.models-0.2*(num.models+1)/num.models,
y = min.val.R, col=model),
fun.data = mean_se, geom = "point") +
labs(x = "task", y = "min.val.R", fill = "model") +
theme_bw() +
theme(axis.text.x = element_text(angle=60, vjust = 1, hjust = 1),
legend.position="bottom",
legend.direction="horizontal") +
# ylim(-1, 1) +
coord_flip() + guides(col=guide_legend(ncol=1)) + ggtitle('Transfer Learning Compare') +
ggeasy::easy_center_title() +
xlab('task: Molecular mode-of-action') + ylab('Spearman Rho')
ggsave(paste0('figs/fig.4a.pdf'), p, height = 8, width = 4)
# plot the task weighted averages as well as task size weighted error bars
uniq.result.plot <- result[result$fold==0,]
for (i in 1:dim(uniq.result.plot)[1]) {
uniq.result.plot$rho[i] = mean(result$min.val.R[result$model==uniq.result.plot$model[i] &
result$task.name==uniq.result.plot$task.name[i]], na.rm=T)
uniq.result.plot$rho.sd[i] = sd(result$min.val.R[result$model==uniq.result.plot$model[i] &
result$task.name==uniq.result.plot$task.name[i]], na.rm=T)
}
# aggregate across models
uniq.model.result.plot <- uniq.result.plot[!duplicated(uniq.result.plot[,c('model', 'HGNC')]),]
for (i in 1:dim(uniq.model.result.plot)[1]) {
uniq.model.result.plot$stab.rho[i] <- mean(uniq.result.plot$rho[uniq.result.plot$HGNC==uniq.model.result.plot$HGNC[i] &
grepl('stability', uniq.result.plot$task.name) &
uniq.result.plot$model == uniq.model.result.plot$model[i]])
uniq.model.result.plot$stab.rho.sd[i] <- mean(uniq.result.plot$rho.sd[uniq.result.plot$HGNC==uniq.model.result.plot$HGNC[i] &
grepl('stability', uniq.result.plot$task.name) &
uniq.result.plot$model == uniq.model.result.plot$model[i]])
uniq.model.result.plot$func.rho[i] <- mean(uniq.result.plot$rho[uniq.result.plot$HGNC==uniq.model.result.plot$HGNC[i] &
!grepl('stability', uniq.result.plot$task.name) &
uniq.result.plot$model == uniq.model.result.plot$model[i]])
uniq.model.result.plot$func.rho.sd[i] <- mean(uniq.result.plot$rho.sd[uniq.result.plot$HGNC==uniq.model.result.plot$HGNC[i] &
!grepl('stability', uniq.result.plot$task.name) &
uniq.result.plot$model == uniq.model.result.plot$model[i]])
}
# aggregate across models
uniq.model.result.plot.plot <- uniq.model.result.plot[!duplicated(uniq.model.result.plot$model),]
for (i in 1:dim(uniq.model.result.plot.plot)[1]) {
task.sizes <- uniq.model.result.plot$npoints[uniq.model.result.plot$model==uniq.model.result.plot$model[i]]
uniq.model.result.plot.plot$stab.rho[i] <- sum(uniq.model.result.plot$stab.rho[uniq.model.result.plot$model==uniq.model.result.plot.plot$model[i]] * task.sizes / sum(task.sizes), na.rm=T)
uniq.model.result.plot.plot$stab.rho.sd[i] <- sum(uniq.model.result.plot$stab.rho.sd[uniq.model.result.plot$model==uniq.model.result.plot.plot$model[i]] * task.sizes / sum(task.sizes), na.rm=T)
uniq.model.result.plot.plot$func.rho[i] <- sum(uniq.model.result.plot$func.rho[uniq.model.result.plot$model==uniq.model.result.plot.plot$model[i]] * task.sizes / sum(task.sizes), na.rm=T)
uniq.model.result.plot.plot$func.rho.sd[i] <- sum(uniq.model.result.plot$func.rho.sd[uniq.model.result.plot$model==uniq.model.result.plot.plot$model[i]] * task.sizes / sum(task.sizes), na.rm=T)
}
p <- ggplot(uniq.model.result.plot.plot, aes(x=stab.rho, y=func.rho, col=model)) +
geom_point() +
geom_errorbar(aes(ymin=func.rho-func.rho.sd, ymax=func.rho+func.rho.sd), width=.02) +
geom_errorbarh(aes(xmin=stab.rho-stab.rho.sd, xmax=stab.rho+stab.rho.sd), height=.02) +
# coord_flip() +guides(col=guide_legend(ncol=2)) +
labs(x = "stab.rho", y = "func.rho", fill = "model") +
geom_abline(slope = 1, intercept = 0, linetype = "dashed", alpha=0.2) +
theme_bw() + xlim(0.15, 0.7) + ylim(0.15, 0.7) +
theme(axis.text.x = element_text(angle=60, vjust = 1, hjust = 1),
legend.position="right",
legend.direction="vertical") +
ggtitle('Transfer Learning Compare\n(Weighted Average by Dataset sizes)') +
ggeasy::easy_center_title()
ggsave('figs/fig.4b.pdf', p, height=4, width=5)