|
import argparse |
|
import torch |
|
import os |
|
os.chdir('..') |
|
from dataloader import CellLoader |
|
from matplotlib import pyplot as plt |
|
from celle_main import instantiate_from_config |
|
from omegaconf import OmegaConf |
|
from celle.utils import process_image |
|
|
|
def run_model(mode, sequence,
              nucleus_image_path,
              protein_image_path,
              model_ckpt_path,
              model_config_path,
              device):
    """
    Dispatch to the image or the sequence prediction routine.

    :param mode: "image" or "sequence" (case and surrounding whitespace are ignored)
    :param sequence: Sequence string forwarded to the selected routine
    :param nucleus_image_path: Path to nucleus image
    :param protein_image_path: Path to protein image (optional)
    :param model_ckpt_path: Path to model checkpoint
    :param model_config_path: Path to model config
    :param device: Device forwarded to the selected routine
    :raises ValueError: If ``mode`` is neither "image" nor "sequence"
    """
    normalized_mode = str(mode).strip().lower()

    if normalized_mode == "image":
        predict = run_image_prediction
    elif normalized_mode == "sequence":
        predict = run_sequence_prediction
    else:
        # Fail loudly: returning silently would look like a successful run
        # that simply produced no output.
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'image' or 'sequence'."
        )

    predict(
        sequence,
        nucleus_image_path,
        protein_image_path,
        model_ckpt_path,
        model_config_path,
        device,
    )
|
|
|
def run_sequence_prediction(
    sequence_input,
    nucleus_image_path,
    protein_image_path,
    model_ckpt_path,
    model_config_path,
    device
):
    """
    Run Celle sequence prediction and print the predicted sequence.

    Positions where the prediction differs from the tokenized input are
    wrapped in ``**`` in the printed output.

    :param sequence_input: Sequence string; a warning is printed when it contains no "<mask>"
    :param nucleus_image_path: Path to nucleus image; 'images/nucleus.jpg' is used when it is None or missing
    :param protein_image_path: Path to protein image; 'images/protein.jpg' is used when it is None or missing
    :param model_ckpt_path: Path to model checkpoint, applied only when the config has no ckpt_path
    :param model_config_path: Path to model config
    :param device: Device the instantiated model is moved to
    :raises ValueError: If no sequence is provided
    """
    # Validate first so bad input fails before the dataset and model are built.
    if not sequence_input:
        raise ValueError("Sequence must be provided.")

    if "<mask>" not in sequence_input:
        print("Warning: Sequence does not contain any masked positions to predict.")

    # The dataset object is only used here for its tokenizer.
    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )
    sequence = dataset.tokenize_sequence(sequence_input)

    # os.path.exists(None) raises TypeError, so test for None explicitly.
    if nucleus_image_path is None or not os.path.exists(nucleus_image_path):
        nucleus_image_path = 'images/nucleus.jpg'
        print(
            "Warning: No nucleus image provided. Using default nucleus image from dataset."
        )
    # Load after the fallback so nucleus_image is always defined.
    nucleus_image = process_image(nucleus_image_path)

    if protein_image_path is None or not os.path.exists(protein_image_path):
        protein_image_path = 'images/protein.jpg'
        print(
            "Warning: No protein image provided. Using default protein image from dataset."
        )
    protein_image = process_image(protein_image_path)
    # torch.median with ``dim`` returns a (values, indices) namedtuple; only
    # the values can be compared against the image.
    # NOTE(review): the reduction over dim=0 is kept from the original code;
    # confirm against process_image's output shape that a global median
    # (torch.median(protein_image)) was not intended.
    protein_image = (protein_image > torch.median(protein_image, dim=0).values) * 1.0

    config = OmegaConf.load(model_config_path)
    params = config["model"]["params"]
    if params["ckpt_path"] is None:
        params["ckpt_path"] = model_ckpt_path

    # Clear the sub-model paths before instantiation.
    # NOTE(review): presumably the checkpoint already carries these weights — confirm.
    params["condition_model_path"] = None
    params["vqgan_model_path"] = None

    model = instantiate_from_config(config).to(device)

    _, predicted_sequence, _ = model.celle.sample_text(
        text=sequence,
        condition=nucleus_image,
        image=protein_image,
        force_aas=True,
        timesteps=1,
        temperature=1,
        progress=True,
    )

    # NOTE(review): the comparison is made against the tokenized ``sequence``,
    # as before; confirm sample_text returns the same representation.
    shared_len = min(len(predicted_sequence), len(sequence))
    pieces = []
    for i in range(shared_len):
        if predicted_sequence[i] != sequence[i]:
            pieces.append(f"**{predicted_sequence[i]}**")
        else:
            pieces.append(predicted_sequence[i])

    # Anything predicted beyond the input length is new by definition.
    if len(predicted_sequence) > len(sequence):
        pieces.append(f"**{predicted_sequence[len(sequence):]}**")

    formatted_predicted_sequence = "".join(pieces)
    print("predicted_sequence:", formatted_predicted_sequence)
|
|
|
|
|
def run_image_prediction(
    sequence_input,
    nucleus_image_path,
    protein_image_path,
    model_ckpt_path,
    model_config_path,
    device
):
    """
    Run Celle image prediction and display the results with matplotlib.

    The figure shows the nucleus input, the predicted threshold image, the
    protein image (only when a path is given) and the predicted heatmap.

    :param sequence_input: Sequence string; a default GFP sequence is used when empty
    :param nucleus_image_path: Path to nucleus image; the first dataset sample is used when it is None or missing
    :param protein_image_path: Path to protein image shown for comparison (optional)
    :param model_ckpt_path: Path to model checkpoint, applied only when the config has no ckpt_path
    :param model_config_path: Path to model config
    :param device: Device the instantiated model is moved to
    """
    if not sequence_input:
        # Replace sequence_input itself so the default is what gets tokenized
        # below, rather than the empty input.
        sequence_input = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
        print("Warning: No sequence provided. Using default sequence for GFP.")

    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )
    sequence = dataset.tokenize_sequence(sequence_input)

    # os.path.exists(None) raises TypeError, so test for None explicitly.
    if nucleus_image_path is None or not os.path.exists(nucleus_image_path):
        nucleus_image = dataset[0]["nucleus"]
        print(
            "Warning: No nucleus image provided. Using default nucleus image from dataset."
        )
    else:
        nucleus_image = process_image(nucleus_image_path)

    config = OmegaConf.load(model_config_path)
    params = config["model"]["params"]
    if params["ckpt_path"] is None:
        params["ckpt_path"] = model_ckpt_path

    # Clear the sub-model paths before instantiation.
    # NOTE(review): presumably the checkpoint already carries these weights — confirm.
    params["condition_model_path"] = None
    params["vqgan_model_path"] = None

    model = instantiate_from_config(config).to(device)

    _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
        text=sequence,
        condition=nucleus_image,
        timesteps=1,
        temperature=1,
        progress=True,
    )

    # Keep only the [0, 0] slice of each output for plotting, on the CPU.
    predicted_threshold = predicted_threshold.cpu()[0, 0]
    predicted_heatmap = predicted_heatmap.cpu()[0, 0]

    # One extra panel when a protein image is supplied for comparison.
    show_protein = protein_image_path is not None
    _, axs = plt.subplots(1, 4 if show_protein else 3)

    axs[0].imshow(nucleus_image)
    axs[0].set_title("Nucleus Input")
    axs[1].imshow(predicted_threshold)
    axs[1].set_title("Predicted Threshold")
    if show_protein:
        protein_image = process_image(protein_image_path)
        axs[2].imshow(protein_image)
        axs[2].set_title("Protein Image")
    # The heatmap always occupies the last panel, whichever layout is used.
    axs[-1].imshow(predicted_heatmap)
    axs[-1].set_title("Predicted Heatmap")
    plt.show()
|
|
|
|
|
def main(argv=None):
    """
    Parse command-line arguments and run the requested Celle prediction.

    :param argv: Argument list to parse; None makes argparse read sys.argv[1:]
    """
    parser = argparse.ArgumentParser(
        description="Run Celle model with provided inputs."
    )
    # Lower-casing happens before the choices check, so "Image" and
    # "Sequence" are accepted; any other value is rejected by argparse
    # instead of silently doing nothing.
    parser.add_argument(
        "--mode",
        type=str.lower,
        choices=("image", "sequence"),
        required=True,
        help="Prediction mode: 'image' or 'sequence'",
    )
    parser.add_argument(
        "--sequence",
        type=str,
        default="",
        help="Amino acid sequence (mark positions to predict with <mask> in sequence mode)",
    )
    parser.add_argument(
        "--nucleus_image_path",
        type=str,
        default="images/nucleus.jpg",
        help="Path to nucleus image",
    )
    parser.add_argument(
        "--protein_image_path",
        type=str,
        default=None,
        help="Path to protein image (optional)",
    )
    parser.add_argument(
        "--model_ckpt_path", type=str, required=True, help="Path to model checkpoint"
    )
    parser.add_argument(
        "--model_config_path", type=str, required=True, help="Path to model config"
    )
    # Not required: otherwise the "cpu" default could never take effect.
    parser.add_argument(
        "--device", type=str, default="cpu", help="device"
    )
    args = parser.parse_args(argv)

    run_model(
        args.mode,
        args.sequence,
        args.nucleus_image_path,
        args.protein_image_path,
        args.model_ckpt_path,
        args.model_config_path,
        args.device,
    )


if __name__ == "__main__":
    main()
|
|