EmaadKhwaja commited on
Commit
fbe05c1
1 Parent(s): 243afec

enable queue

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. prediction.py +0 -130
app.py CHANGED
@@ -124,4 +124,4 @@ with gr.Blocks() as demo:
124
 
125
  button.click(gradio_demo, inputs, outputs)
126
 
127
- demo.launch()
 
124
 
125
  button.click(gradio_demo, inputs, outputs)
126
 
127
+ demo.launch(enable_queue=True)
prediction.py CHANGED
@@ -1,138 +1,8 @@
1
- import argparse
2
- import torch
3
  import os
4
  os.chdir('..')
5
  from dataloader import CellLoader
6
- from matplotlib import pyplot as plt
7
  from celle_main import instantiate_from_config
8
  from omegaconf import OmegaConf
9
- from celle.utils import process_image
10
-
11
- def run_model(mode, sequence,
12
- nucleus_image_path,
13
- protein_image_path,
14
- model_ckpt_path,
15
- model_config_path,
16
- device):
17
- if mode == "image":
18
- run_image_prediction(
19
- sequence,
20
- nucleus_image_path,
21
- protein_image_path,
22
- model_ckpt_path,
23
- model_config_path,
24
- device
25
- )
26
- elif mode == "sequence":
27
- run_sequence_prediction(
28
- sequence,
29
- nucleus_image_path,
30
- protein_image_path,
31
- model_ckpt_path,
32
- model_config_path,
33
- device
34
- )
35
-
36
- def run_sequence_prediction(
37
- sequence_input,
38
- nucleus_image_path,
39
- protein_image_path,
40
- model_ckpt_path,
41
- model_config_path,
42
- device
43
- ):
44
- """
45
- Run Celle model with provided inputs and display results.
46
-
47
- :param sequence: Path to sequence file
48
- :param nucleus_image_path: Path to nucleus image
49
- :param protein_image_path: Path to protein image (optional)
50
- :param model_ckpt_path: Path to model checkpoint
51
- :param model_config_path: Path to model config
52
- """
53
-
54
- # Instantiate dataset object
55
- dataset = CellLoader(
56
- sequence_mode="embedding",
57
- vocab="esm2",
58
- split_key="val",
59
- crop_method="center",
60
- resize=600,
61
- crop_size=256,
62
- text_seq_len=1000,
63
- pad_mode="end",
64
- threshold="median",
65
- )
66
-
67
- # Check if sequence is provided and valid
68
- if len(sequence_input) == 0:
69
- raise ValueError("Sequence must be provided.")
70
-
71
- if "<mask>" not in sequence_input:
72
- print("Warning: Sequence does not contain any masked positions to predict.")
73
-
74
- # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
75
- sequence = dataset.tokenize_sequence(sequence_input)
76
-
77
- # Check if nucleus image path is provided and valid
78
- if not os.path.exists(nucleus_image_path):
79
- # Use default nucleus image from dataset and print warning
80
- nucleus_image_path = 'images/nucleus.jpg'
81
- print(
82
- "Warning: No nucleus image provided. Using default nucleus image from dataset."
83
- )
84
- else:
85
- # Load nucleus image from provided path
86
- nucleus_image = process_image(nucleus_image_path)
87
-
88
- # Check if protein image path is provided and valid
89
- if not os.path.exists(protein_image_path):
90
- # Use default nucleus image from dataset and print warning
91
- protein_image_path = 'images/protein.jpg'
92
- print(
93
- "Warning: No nucleus image provided. Using default protein image from dataset."
94
- )
95
- else:
96
- # Load protein image from provided path
97
- protein_image = process_image(protein_image_path)
98
- protein_image = (protein_image > torch.median(protein_image,dim=0))*1.0
99
-
100
- # Load model config and set ckpt_path if not provided in config
101
- config = OmegaConf.load(model_config_path)
102
- if config["model"]["params"]["ckpt_path"] is None:
103
- config["model"]["params"]["ckpt_path"] = model_ckpt_path
104
-
105
- # Set condition_model_path and vqgan_model_path to None
106
- config["model"]["params"]["condition_model_path"] = None
107
- config["model"]["params"]["vqgan_model_path"] = None
108
-
109
- # Instantiate model from config and move to device
110
- model = instantiate_from_config(config).to(device)
111
-
112
- # Sample from model using provided sequence and nucleus image
113
- _, predicted_sequence, _ = model.celle.sample_text(
114
- text=sequence,
115
- condition=nucleus_image,
116
- image=protein_image,
117
- force_aas=True,
118
- timesteps=1,
119
- temperature=1,
120
- progress=True,
121
- )
122
-
123
- formatted_predicted_sequence = ""
124
-
125
- for i in range(min(len(predicted_sequence), len(sequence))):
126
- if predicted_sequence[i] != sequence[i]:
127
- formatted_predicted_sequence += f"**{predicted_sequence[i]}**"
128
- else:
129
- formatted_predicted_sequence += predicted_sequence[i]
130
-
131
- if len(predicted_sequence) > len(sequence):
132
- formatted_predicted_sequence += f"**{predicted_sequence[len(sequence):]}**"
133
-
134
- print("predicted_sequence:", formatted_predicted_sequence)
135
-
136
 
137
  def run_image_prediction(
138
  sequence_input,
 
 
 
1
  import os
2
  os.chdir('..')
3
  from dataloader import CellLoader
 
4
  from celle_main import instantiate_from_config
5
  from omegaconf import OmegaConf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def run_image_prediction(
8
  sequence_input,