EmaadKhwaja commited on
Commit
64212e0
·
1 Parent(s): 86d2765

update app.py

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. app.py +114 -4
  3. celle/celle.py +1061 -0
  4. celle/utils.py +230 -0
  5. dataloader.py +308 -0
  6. requirements.txt +13 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ env
app.py CHANGED
@@ -1,7 +1,117 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+ from prediction import run_image_prediction
4
+ import torch
5
+ import torchvision.transforms as T
6
+ from celle.utils import process_image
7
+ from PIL import Image
8
+ from matplotlib import pyplot as plt
9
 
 
 
10
 
11
+ def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
12
+ model = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
13
+ config = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
14
+ hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
15
+ hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ if 'Finetuned' in model_name:
19
+ dataset = 'OpenCell'
20
+
21
+ else:
22
+ dataset = 'HPA'
23
+
24
+ nucleus_image = process_image(nucleus_image,dataset,'nucleus')
25
+ if protein_image:
26
+ protein_image = process_image(protein_image,dataset,'protein')
27
+ protein_image = protein_image > torch.median(protein_image)
28
+ protein_image = protein_image[0,0]
29
+ protein_image = protein_image*1.0
30
+ else:
31
+ protein_image = torch.ones((256,256))
32
+
33
+
34
+ threshold, heatmap = run_image_prediction(sequence_input = sequence_input,
35
+ nucleus_image = nucleus_image,
36
+ model_ckpt_path=model,
37
+ model_config_path=config,
38
+ device=device)
39
+
40
+ # Plot the heatmap
41
+ plt.imshow(heatmap.cpu(), cmap='rainbow', interpolation = 'bicubic')
42
+ plt.axis('off')
43
+
44
+ # Save the plot to a temporary file
45
+ plt.savefig('temp.png', bbox_inches='tight', dpi = 256)
46
+
47
+ # Open the temporary file as a PIL image
48
+ heatmap = Image.open('temp.png')
49
+
50
+ return T.ToPILImage()(nucleus_image[0,0]), T.ToPILImage()(protein_image), T.ToPILImage()(threshold), heatmap
51
+
52
+
53
+ with gr.Blocks() as demo:
54
+ gr.Markdown("Select the prediction model.")
55
+ gr.Markdown("CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF.")
56
+ gr.Markdown("CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is good more live-cell predictions on HEK cells.")
57
+ with gr.Row():
58
+ model_name = gr.Dropdown(['CELL-E_2_HPA_480','CELL-E_2_HPA_Finetuned_480'],
59
+ value='CELL-E_2_HPA_480', label = 'Model Name')
60
+ with gr.Row():
61
+ gr.Markdown("Input the desired amino acid sequence. GFP is shown below by default.")
62
+
63
+ with gr.Row():
64
+ sequence_input = gr.Textbox(value='MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
65
+ label = 'Sequence')
66
+ with gr.Row():
67
+ gr.Markdown("Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger.")
68
+ gr.Markdown("The protein image is optional and is just used for display.")
69
+
70
+ with gr.Row().style(equal_height=True):
71
+ nucleus_image = gr.Image(value = 'images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
72
+ type='pil',
73
+ label = 'Nucleus Image')
74
+
75
+ protein_image = gr.Image(type='pil', label = 'Protein Image (Optional)')
76
+
77
+ with gr.Row():
78
+ gr.Markdown("Image predictions are show below.")
79
+
80
+ with gr.Row().style(equal_height=True):
81
+ nucleus_image_crop = gr.Image(type='pil',
82
+ label = 'Nucleus Image')
83
+
84
+ protein_threshold_image = gr.Image(type='pil',
85
+ label = 'Protein Threshold Image')
86
+
87
+ predicted_threshold_image = gr.Image(type='pil',
88
+ label = 'Predicted Threshold image')
89
+
90
+ predicted_heatmap = gr.Image(type='pil',
91
+ label = 'Predicted Heatmap')
92
+ with gr.Row():
93
+ button = gr.Button("Run Model")
94
+
95
+ inputs = [model_name,
96
+ sequence_input,
97
+ nucleus_image,
98
+ protein_image]
99
+
100
+ outputs = [nucleus_image_crop,
101
+ protein_threshold_image,
102
+ predicted_threshold_image,
103
+ predicted_heatmap]
104
+
105
+ button.click(gradio_demo, inputs, outputs)
106
+
107
+ examples = [['CELL-E_2_HPA_Finetuned_480',
108
+ 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
109
+ 'images/Proteasome activator complex subunit 3 nucleus.png',
110
+ 'images/Proteasome activator complex subunit 3 protein.png'],
111
+ ['CELL-E_2_HPA_480',
112
+ 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
113
+ 'images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
114
+ 'images/Armadillo repeat-containing X-linked protein 5 protein.jpg']]
115
+
116
+ # demo = gr.Interface(gradio_demo, inputs, outputs, examples, cache_examples=True, layout = layout)
117
+ demo.launch(share=True)
celle/celle.py ADDED
@@ -0,0 +1,1061 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary packages and modules
2
+ from math import floor, ceil
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from axial_positional_embedding import AxialPositionalEmbedding
7
+ from einops import rearrange
8
+ from celle.utils import (
9
+ exists,
10
+ always,
11
+ eval_decorator,
12
+ gumbel_sample,
13
+ top_k,
14
+ gamma_func,
15
+ DivideMax,
16
+ )
17
+ from tqdm import tqdm
18
+
19
+ # Import additional modules from within the codebase
20
+ from celle.transformer import Transformer
21
+
22
+
23
+ def generate_mask(gamma_func, batch_size, length, device):
24
+ # Get the number of `True` values in the mask for each batch element
25
+ num_true_values = floor(gamma_func(torch.rand(1)) * length)
26
+
27
+ # Generate a random sample of indices to set to `True` in the mask
28
+ # The number of indices in the sample is determined by `num_true_values`
29
+ indices = (
30
+ torch.rand((batch_size, length), device=device)
31
+ .topk(num_true_values, dim=1)
32
+ .indices
33
+ )
34
+
35
+ # Create a binary mask tensor with `True` values at the sampled indices
36
+ mask = torch.zeros((batch_size, length), dtype=torch.bool, device=device)
37
+ mask.scatter_(dim=1, index=indices, value=True)
38
+
39
+ return mask
40
+
41
+
42
+ def match_batch_size(text, condition, image, batch_size):
43
+ """
44
+ This function ensures all inputs to the sample function have the same batch size.
45
+ """
46
+ if text.shape[0] != batch_size:
47
+ text = text.repeat(batch_size, 1)
48
+
49
+ if condition.shape[0] != batch_size:
50
+ condition = condition.repeat(batch_size, 1)
51
+
52
+ if image.shape[0] != batch_size:
53
+ image = image.repeat(batch_size, 1)
54
+
55
+ return text, condition, image
56
+
57
+
58
+ def calc_unmask_probs(timestep, timesteps, gamma_func):
59
+ if timestep == 1 or timesteps == 1:
60
+ unmask_prob = 1
61
+ else:
62
+ unmask_prob = 1 - gamma_func(timestep)
63
+ return unmask_prob
64
+
65
+
66
+ def calculate_logits(
67
+ input_tokens, input_mask, logits_function, filter_thres, temperature
68
+ ):
69
+ logits, _, _ = logits_function(input_tokens, input_mask, return_encoding=False)
70
+ filtered_logits = top_k(logits, thres=filter_thres)
71
+ sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)
72
+
73
+ return logits, sample
74
+
75
+
76
+ def unmask_tokens(
77
+ input_tokens,
78
+ input_mask,
79
+ num_masked_tokens,
80
+ logits,
81
+ sample,
82
+ timestep,
83
+ timesteps,
84
+ gamma,
85
+ filter_func=None,
86
+ pad_token=None,
87
+ mask_token=None,
88
+ force_aas=True,
89
+ ):
90
+ sample = sample.masked_fill(~input_mask.unsqueeze(-1), -torch.inf)
91
+ if filter_func:
92
+ sample = filter_func(
93
+ input_tokens, sample, force_aas, pad_token=pad_token, mask_token=mask_token
94
+ )
95
+ selected_token_probs, selected_tokens = torch.max(sample, dim=-1)
96
+
97
+ unmask_prob = calc_unmask_probs(timestep, timesteps, gamma)
98
+ num_tokens_to_unmask = max(1, ceil(unmask_prob * num_masked_tokens))
99
+
100
+ _, top_k_indices = torch.topk(selected_token_probs, num_tokens_to_unmask, dim=-1)
101
+
102
+ sample_mask = torch.zeros(
103
+ input_tokens.shape, dtype=torch.bool, device=input_tokens.device
104
+ )
105
+ sample_mask.scatter_(dim=1, index=top_k_indices, value=True)
106
+
107
+ unmasked_tokens = torch.where(sample_mask, selected_tokens, input_tokens)
108
+ full_logits = torch.where(
109
+ sample_mask.unsqueeze(-1), logits, torch.zeros_like(logits)
110
+ )
111
+ return unmasked_tokens, full_logits
112
+
113
+
114
+ def suppress_invalid_text_tokens(
115
+ text,
116
+ logits,
117
+ start_token=None,
118
+ end_token=None,
119
+ pad_token=None,
120
+ mask_token=None,
121
+ force_aas=False,
122
+ ):
123
+ # Find the indices of start_token and end_token in tensor text along axis=1
124
+ idx_start = (text == start_token).nonzero(as_tuple=True)[1]
125
+ idx_end = (text == end_token).nonzero(as_tuple=True)[1]
126
+
127
+ # For every position other than the index corresponding to the start index, set the values on the start index of dimension=2 to -torch.inf
128
+ if idx_start.nelement() != start_token:
129
+ try:
130
+ mask = idx_start.unsqueeze(1) != torch.arange(
131
+ logits.size(1), device=text.device
132
+ )
133
+ indices = torch.where(mask)
134
+ logits[indices[0], indices[1], start_token] = -torch.inf
135
+ except:
136
+ pass
137
+
138
+ # else:
139
+ # idx_start = torch.zeros(text.size(0), dtype=torch.long)
140
+
141
+ # Similarly, for every position other than the index corresponding to the end index, set the values on the end index of dimension=2 to -torch.inf
142
+ if idx_end.nelement() != 0:
143
+ try:
144
+ mask = idx_end.unsqueeze(1) != torch.arange(
145
+ logits.size(1), device=text.device
146
+ )
147
+ indices = torch.where(mask)
148
+ logits[indices[0], indices[1], end_token] = -torch.inf
149
+ except:
150
+ pass
151
+
152
+ # else:
153
+ # idx_end = torch.full((text.size(0),), text.size(1) - 1, dtype=torch.long)
154
+
155
+ if pad_token:
156
+ if idx_start.nelement() != 0 and idx_end.nelement() != 0:
157
+ try:
158
+ # For every position between the indices of start_token and end_token, set the values for 1st index of dimension=2 equal to -torch.inf. Any value outside of that range should be set to torch.inf.
159
+ mask = (
160
+ torch.arange(logits.size(1), device=text.device)
161
+ >= idx_start.unsqueeze(1)
162
+ ) & (
163
+ torch.arange(logits.size(1), device=text.device)
164
+ <= idx_end.unsqueeze(1)
165
+ )
166
+
167
+ indices = torch.where(mask)
168
+ logits[indices[0], indices[1], pad_token] = -torch.inf
169
+
170
+ indices = torch.where(~mask)
171
+ logits[indices[0], indices[1], pad_token] = torch.inf
172
+
173
+ except:
174
+ pass
175
+
176
+ elif idx_start.nelement() != 0:
177
+ try:
178
+ mask = torch.arange(
179
+ logits.size(1), device=text.device
180
+ ) < idx_start.unsqueeze(1)
181
+ logits[indices[0], indices[1], pad_token] = torch.inf
182
+ except:
183
+ pass
184
+
185
+ elif idx_end.nelement() != 0:
186
+ try:
187
+ mask = torch.arange(
188
+ logits.size(1), device=text.device
189
+ ) > idx_end.unsqueeze(1)
190
+ logits[indices[0], indices[1], pad_token] = torch.inf
191
+ except:
192
+ pass
193
+
194
+ if force_aas:
195
+ if pad_token:
196
+ logits[:, :, pad_token] = -torch.inf
197
+ logits[:, :, 3] = -torch.inf
198
+ logits[:, :, 29:] = -torch.inf
199
+
200
+ if mask_token:
201
+ logits[:, :, mask_token] = -torch.inf
202
+
203
+ return logits
204
+
205
+
206
+ def detokenize_text(text_embedding, sequence):
207
+ if text_embedding == "esm1b" or text_embedding == "esm2":
208
+ from esm import Alphabet
209
+
210
+ alphabet = (
211
+ Alphabet.from_architecture("ESM-1b").get_batch_converter().alphabet.all_toks
212
+ )
213
+ else:
214
+ assert NameError("Detokenization only available for ESM mdodels")
215
+
216
+ output_seqs = []
217
+
218
+ for batch in sequence:
219
+ converted_seq = [alphabet[idx] for idx in batch]
220
+ converted_seq = "".join(converted_seq)
221
+ output_seqs.append(converted_seq)
222
+
223
+ return output_seqs
224
+
225
+ class ImageEmbedding(nn.Module):
226
+ def __init__(self, num_tokens, dim):
227
+ super(ImageEmbedding, self).__init__()
228
+ self.image_embedding = nn.Embedding(num_tokens, dim)
229
+
230
+ def forward(self, image):
231
+ return self.image_embedding(image)
232
+
233
+
234
+ class ModelExtender(nn.Module):
235
+ def __init__(self, vocab, out_features, fixed_embedding=False):
236
+ super(ModelExtender, self).__init__()
237
+
238
+ # Initialize the model according to the given vocabulary
239
+ self.vocab = vocab
240
+
241
+ if vocab == "esm1b":
242
+ from esm import pretrained
243
+
244
+ self.model, _ = pretrained.esm1b_t33_650M_UR50S()
245
+ self.in_features = 1280
246
+ elif vocab == "esm2":
247
+ from esm import pretrained
248
+
249
+ if out_features == 320:
250
+ self.model, _ = pretrained.esm2_t6_8M_UR50D()
251
+ elif out_features == 480:
252
+ self.model, _ = pretrained.esm2_t12_35M_UR50D()
253
+ elif out_features == 640:
254
+ self.model, _ = pretrained.esm2_t30_150M_UR50D()
255
+ elif out_features == 1280:
256
+ self.model, _ = pretrained.esm2_t33_650M_UR50D()
257
+ elif out_features == 2560:
258
+ self.model, _ = pretrained.esm2_t36_3B_UR50D()
259
+ else:
260
+ self.model, _ = pretrained.esm2_t33_650M_UR50D()
261
+ self.in_features = self.model.embed_dim
262
+
263
+ # Set the number of output features and initialize the scaling layer
264
+ self.out_features = out_features
265
+ self.scale_layer = nn.Linear(self.in_features, self.out_features)
266
+
267
+ # Determine whether to freeze the model's parameters
268
+ self.fixed_embedding = fixed_embedding
269
+ if self.fixed_embedding:
270
+ self.model = self.model.eval()
271
+
272
+ def forward(self, x, **kwargs):
273
+ # If the model's parameters are fixed, use torch.no_grad()
274
+ if self.fixed_embedding:
275
+ with torch.no_grad():
276
+ if self.vocab == "esm1b" or self.vocab == "esm2":
277
+ # Reduce sequence length dimension, get top layer representation tensor
278
+ x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
279
+ "representations"
280
+ ][self.model.num_layers]
281
+ # Tensor shape: (batch_size, hidden_size)
282
+ else:
283
+ # Get top layer representation tensor
284
+ x = self.model(x, **kwargs)[0]
285
+ # Tensor shape: (batch_size, sequence_length, hidden_size)
286
+ else:
287
+ if self.vocab == "esm1b" or self.vocab == "esm2":
288
+ # Reduce sequence length dimension, get top layer representation tensor
289
+ x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
290
+ "representations"
291
+ ][self.model.num_layers]
292
+ # Tensor shape: (batch_size, hidden_size)
293
+ else:
294
+ # Get top layer representation tensor
295
+ x = self.model(x, **kwargs)[0]
296
+ # Tensor shape: (batch_size, sequence_length, hidden_size)
297
+
298
+ # Scale the representation tensor if necessary
299
+ if self.out_features != self.in_features:
300
+ x = self.scale_layer(x)
301
+ # Tensor shape: (batch_size, out_features)
302
+
303
+ return x
304
+
305
+ class CELLE(nn.Module):
306
+ def __init__(
307
+ self,
308
+ *,
309
+ dim,
310
+ vae, # The VAE model used to encode/decode images
311
+ condition_vae=None, # An optional VAE model used to condition the image generation
312
+ num_images=2, # Number of images to generate
313
+ num_text_tokens=30, # Number of tokens in the text vocabulary
314
+ text_seq_len=1000, # Maximum length of input text sequence
315
+ depth=16, # Number of layers in the transformer model
316
+ heads=16, # Number of attention heads
317
+ dim_head=64, # Dimensionality of each attention head
318
+ attn_dropout=0.1, # Dropout rate for attention weights
319
+ ff_dropout=0.1, # Dropout rate for feedforward layers
320
+ attn_types=None, # Types of attention to use in the transformer
321
+ causal=False, # Whether to use causal attention
322
+ loss_cond_weight=1, # Weight of conditioning loss
323
+ loss_img_weight=1, # Weight of image generation loss
324
+ stable=False, # Whether to use divide-by-max normalization in the transformer
325
+ rotary_emb=True, # Whether to use rotary positional embeddings
326
+ text_embedding="esm2", # Text embedding to use (esm1b, esm2)
327
+ fixed_embedding=True, # Whether to fix the text embedding or learn it
328
+ sampling_mode="cosine", # Sampling mode for the VAE
329
+ linear_project=False, # Whether to project embeddings linearly
330
+ **kwargs,
331
+ ):
332
+ super().__init__()
333
+
334
+ # Set the stable flag
335
+ self.stable = stable
336
+
337
+ # If the stable flag is set, initialize the DivideMax layer for normalization
338
+ if stable:
339
+ self.norm_by_max = DivideMax(dim=-1)
340
+
341
+ ### Initializing text parameters ###
342
+
343
+ # Initialize the text and fixed embeddings
344
+ self.text_embedding = text_embedding
345
+ self.fixed_embedding = fixed_embedding
346
+
347
+ # Offset logits index and calculate cross entropy loss
348
+ self.num_text_tokens = num_text_tokens
349
+ self.linear_project = linear_project
350
+
351
+ # Add <BOS> and <EOS> tokens to the beginning and end of text sequences
352
+ if text_embedding.lower() in ("esm1b", "esm2"):
353
+ self.text_seq_len = text_seq_len + 2
354
+ else:
355
+ self.text_seq_len = text_seq_len
356
+
357
+ # Initialize embeddings for <SEP> token
358
+ self.sep_emb = nn.Embedding(1, dim)
359
+
360
+ # Initialize positional embeddings for text sequences and <SEP> token
361
+ self.text_pos_emb = (
362
+ nn.Embedding(self.text_seq_len + 1, dim) if not rotary_emb else always(0)
363
+ ) # +1 for <SEP>
364
+
365
+ ### ###
366
+
367
+ self.num_images = num_images
368
+
369
+ ### Initializing condition parameters ###
370
+
371
+ # Initialize the number of condition tokens, condition sequence length, and condition embedding
372
+ if exists(condition_vae):
373
+ condition_size = condition_vae.image_size
374
+ num_condition_tokens = condition_vae.num_tokens
375
+ self.num_condition_tokens = num_condition_tokens
376
+ condition_fmap_size = condition_vae.image_size // (
377
+ 2**condition_vae.num_layers
378
+ )
379
+ condition_seq_len = condition_fmap_size**2
380
+
381
+ # Initialize ImageEmbedding for condition embedding
382
+ self.condition_emb = ImageEmbedding(num_condition_tokens + 1, dim)
383
+
384
+ # Initialize positional embeddings for condition embedding
385
+ self.condition_pos_emb = (
386
+ AxialPositionalEmbedding(
387
+ dim, axial_shape=(condition_fmap_size, condition_fmap_size)
388
+ )
389
+ if not rotary_emb
390
+ else always(0)
391
+ )
392
+
393
+ else:
394
+ condition_fmap_size = 0
395
+ condition_seq_len = 0
396
+ num_condition_tokens = 0
397
+
398
+ ### ####
399
+
400
+ ### Initializing image parameters ###
401
+
402
+ # Initialize the image size, image token size, and sequence length
403
+ self.image_size = vae.image_size
404
+ num_image_tokens = vae.num_tokens
405
+ image_fmap_size = vae.image_size // (2**vae.num_layers)
406
+ image_seq_len = image_fmap_size**2
407
+ self.image_seq_len = image_seq_len
408
+ self.num_image_tokens = num_image_tokens
409
+
410
+ # Initialize ImageEmbedding and positional embeddings for image embedding
411
+ self.image_emb = ImageEmbedding(num_image_tokens + 1, dim) # +1 for <IM_MASK>
412
+
413
+ self.image_pos_emb = (
414
+ AxialPositionalEmbedding(
415
+ dim, axial_shape=(image_fmap_size, image_fmap_size)
416
+ )
417
+ if not rotary_emb
418
+ else always(0)
419
+ )
420
+
421
+ # Set total sequence length and total tokens
422
+ self.num_condition_tokens = num_condition_tokens
423
+ self.condition_seq_len = condition_seq_len
424
+ # Text Length + <SEP> + Condition Tokens + Image Tokens
425
+ seq_len = self.text_seq_len + 1 + self.condition_seq_len + self.image_seq_len
426
+ total_tokens = (
427
+ num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens + 1
428
+ )
429
+ self.total_tokens = total_tokens
430
+ self.total_seq_len = seq_len
431
+
432
+ # Set the VAE and condition VAE for the model
433
+ self.vae = vae.eval()
434
+ self.condition_vae = condition_vae.eval()
435
+
436
+ ### ###
437
+
438
+ ### Setting discrete ids ###
439
+ # Initialize text embedding based on the given text_embedding parameter
440
+ if text_embedding == "esm1b" or text_embedding == "esm2":
441
+ self.text_mask_token = 32
442
+ self.pad_token = 1
443
+ self.text_emb = ModelExtender(text_embedding, dim, fixed_embedding)
444
+ else:
445
+ raise ValueError("Only ESM models are supported.")
446
+
447
+ # Set token indices for text, condition, and image sequences
448
+ self.sep_token = num_text_tokens
449
+ self.cond_mask_token = num_condition_tokens
450
+ self.image_mask_token = num_image_tokens
451
+
452
+ # Create indices for sequence and logits dimensions
453
+ self.seq_range = torch.arange(seq_len)
454
+ self.logits_range = torch.arange(total_tokens)
455
+
456
+ # Reshape sequence and logits indices
457
+ self.seq_range = rearrange(self.seq_range, "n -> () n ()")
458
+ self.logits_range = rearrange(self.logits_range, "d -> () () d")
459
+
460
+ # Create a mask to exclude invalid token positions from the model output
461
+ # e.g. no image tokens where sequence tokens should be
462
+ logits_mask = (
463
+ # Mask text tokens beyond text_seq_len and invalid logits_range
464
+ (
465
+ (self.seq_range < self.text_seq_len)
466
+ & (self.logits_range < num_text_tokens)
467
+ & (self.logits_range != self.text_mask_token)
468
+ )
469
+ |
470
+ # Mask [SEP] token after text
471
+ (
472
+ (self.seq_range == self.text_seq_len)
473
+ & (self.logits_range == num_text_tokens)
474
+ )
475
+ |
476
+ # Mask condition tokens beyond text_seq_len+1 ([SEP]) and invalid logits_range
477
+ (
478
+ (self.seq_range >= self.text_seq_len + 1)
479
+ & (self.seq_range < self.text_seq_len + 1 + condition_seq_len)
480
+ & (self.logits_range >= num_text_tokens + 1)
481
+ & (self.logits_range < num_text_tokens + 1 + num_condition_tokens)
482
+ )
483
+ |
484
+ # Mask image tokens beyond num_text_tokens+num_condition_tokens+1
485
+ (
486
+ (self.seq_range >= self.text_seq_len + 1 + condition_seq_len)
487
+ & (self.logits_range >= num_text_tokens + 1 + num_condition_tokens + 1)
488
+ & (
489
+ self.logits_range
490
+ < num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens
491
+ )
492
+ )
493
+ )
494
+
495
+ # Invert the mask
496
+ logits_mask = ~logits_mask
497
+
498
+ # Register the buffer with the logits_mask
499
+ self.register_buffer("logits_mask", logits_mask, persistent=False)
500
+
501
+ ### ###
502
+
503
+ # Initialize the Transformer model with given parameters
504
+ self.transformer = Transformer(
505
+ dim=dim,
506
+ causal=causal,
507
+ seq_len=seq_len,
508
+ depth=depth,
509
+ heads=heads,
510
+ dim_head=dim_head,
511
+ attn_dropout=attn_dropout,
512
+ ff_dropout=ff_dropout,
513
+ image_fmap_size=image_fmap_size + condition_fmap_size,
514
+ num_images=num_images,
515
+ stable=stable,
516
+ rotary_emb=rotary_emb,
517
+ )
518
+
519
+ # Initialize the linear layers for converting transformer output to logits
520
+ self.to_logits = nn.Sequential(
521
+ nn.LayerNorm(dim),
522
+ nn.Linear(dim, self.total_tokens),
523
+ )
524
+
525
+ # Set instance variables for weights and critic
526
+ self.loss_img_weight = loss_img_weight
527
+ self.loss_cond_weight = loss_cond_weight
528
+ self.gamma = gamma_func(sampling_mode)
529
+
530
+ def embed_and_transform(self, inputs, masks, return_encoding=False):
531
+ text, condition, image = inputs
532
+ device = text.device
533
+ text_mask, _, image_mask = masks
534
+
535
+ text_labels = text.clone()
536
+ text = torch.where(
537
+ text_mask, self.text_mask_token * torch.ones_like(text, device=device), text
538
+ )
539
+
540
+ tokens = self.text_emb(text)
541
+
542
+ # Add SEP token
543
+
544
+ sep_token_emb = self.sep_emb(
545
+ torch.zeros((tokens.shape[0], 1), dtype=torch.long, device=device)
546
+ )
547
+ tokens = torch.cat((tokens, sep_token_emb), dim=1)
548
+ tokens += self.text_pos_emb(torch.arange(text.shape[1] + 1, device=device))
549
+
550
+ with torch.no_grad():
551
+ if self.linear_project:
552
+ b = condition.shape[0]
553
+ condition, _, [_, _, condition_labels] = self.condition_vae.encode(
554
+ condition
555
+ )
556
+ condition_labels = rearrange(condition_labels, "(b n) -> b n", b=b)
557
+
558
+ else:
559
+ condition_labels = condition
560
+ if condition.dtype == torch.float:
561
+ condition_labels = self.condition_vae.get_codebook_indices(
562
+ condition
563
+ )
564
+ condition = condition_labels.clone()
565
+
566
+ condition_emb = self.condition_emb(condition)
567
+ condition_emb += self.condition_pos_emb(condition_emb)
568
+ tokens = torch.cat((tokens, condition_emb), dim=1)
569
+
570
+ with torch.no_grad():
571
+ if self.linear_project:
572
+ b = image.shape[0]
573
+ image, _, [_, _, image_labels] = self.vae.encode(image)
574
+ image_labels = rearrange(image_labels, "(b n) -> b n", b=b)
575
+
576
+ else:
577
+ image_labels = image
578
+ if image.dtype == torch.float:
579
+ image_labels = self.vae.get_codebook_indices(image)
580
+ image = torch.where(
581
+ image_mask,
582
+ self.image_mask_token
583
+ * torch.ones_like(image_labels, device=device),
584
+ image_labels,
585
+ )
586
+
587
+ image_emb = self.image_emb(image)
588
+
589
+ image_emb += self.image_pos_emb(image_emb)
590
+ tokens = torch.cat((tokens, image_emb), dim=1)
591
+
592
+ if self.stable:
593
+ alpha = 0.1
594
+ tokens = tokens * alpha + tokens.detach() * (1 - alpha)
595
+
596
+ out = self.transformer(tokens)
597
+
598
+ if self.stable:
599
+ out = self.norm_by_max(out)
600
+
601
+ logits = self.to_logits(out)
602
+
603
+ max_neg_value = -torch.finfo(logits.dtype).max
604
+ logits.masked_fill_(self.logits_mask, max_neg_value)
605
+
606
+ if return_encoding:
607
+ return logits, out, [text_labels, condition_labels, image_labels]
608
+ else:
609
+ return logits, None, [text_labels, condition_labels, image_labels]
610
+
611
+ def forward(
612
+ self,
613
+ text,
614
+ condition=None,
615
+ image=None,
616
+ return_loss=False,
617
+ return_encoding=False,
618
+ ):
619
+ batch_size, device = text.shape[0], text.device
620
+
621
+ # Check that image is supplied when training
622
+ assert exists(image), "when training, image must be supplied"
623
+
624
+ # Check that image dimensions match the expected dimensions
625
+ assert tuple(image.shape[1:]) == (
626
+ self.vae.channels,
627
+ self.image_size,
628
+ self.image_size,
629
+ ), f"invalid image of dimensions {image.shape} passed in during training"
630
+
631
+ # Generate masks for text, condition, and image
632
+
633
+ # text_mask = generate_mask(self.gamma, batch_size, self.text_seq_len, device)
634
+
635
+ text_mask = generate_mask(
636
+ gamma_func("scaled-cosine"), batch_size, self.text_seq_len, device
637
+ )
638
+
639
+ image_mask = generate_mask(self.gamma, batch_size, self.image_seq_len, device)
640
+
641
+ # Embed and transform inputs
642
+ logits, _, labels = self.embed_and_transform(
643
+ [text, condition, image],
644
+ [text_mask, None, image_mask],
645
+ return_encoding,
646
+ device,
647
+ )
648
+
649
+ # If not returning loss, return the logits
650
+ if not return_loss:
651
+ return logits
652
+
653
+ # Separate labels
654
+ text, condition, image = labels
655
+
656
+ # Add SEP token to end of text label
657
+ sep_token = torch.tensor(self.sep_token, device=device).repeat(
658
+ labels.shape[0], 1
659
+ )
660
+ labels = torch.cat([labels, sep_token], dim=1)
661
+
662
+ # If condition exists and condition vae is defined, add the condition to the labels
663
+ if exists(condition) and exists(self.condition_vae):
664
+ offsetted_condition = condition + self.num_text_tokens + 1
665
+ labels = torch.cat((labels, offsetted_condition), dim=1)
666
+
667
+ # Add image to the labels
668
+ offsetted_image = (
669
+ image + self.num_text_tokens + 1 + self.num_condition_tokens + 1
670
+ )
671
+ labels = torch.cat((labels, offsetted_image), dim=1)
672
+
673
+ # Rearrange logits for cross-entropy loss calculation
674
+ # Logits size: (batch_size, vocab_size, total_seq_len)
675
+ # Labels size: (batch_size, total_seq_len)
676
+ logits = rearrange(logits, "b n c -> b c n")
677
+
678
+ # Calculate cross-entropy loss for text and image
679
+ loss_text = F.cross_entropy(
680
+ logits[:, :, : self.text_seq_len],
681
+ labels[:, : self.text_seq_len],
682
+ reduction="none",
683
+ )[text_mask].mean()
684
+
685
+ loss_img = F.cross_entropy(
686
+ logits[:, :, self.text_seq_len + 1 + self.condition_seq_len :],
687
+ labels[:, self.text_seq_len + 1 + self.condition_seq_len :],
688
+ reduction="none",
689
+ )[image_mask].mean()
690
+
691
+ # Calculate total loss
692
+ loss = (loss_text + self.loss_img_weight * loss_img) / (
693
+ self.loss_img_weight + 1
694
+ )
695
+
696
+ loss_dict = {
697
+ "loss_text": loss_text,
698
+ # "loss_cond": loss_cond,
699
+ "loss_img": loss_img,
700
+ "loss": torch.nan_to_num(loss, 0.0, 0.0, 0.0),
701
+ }
702
+
703
+ return loss, loss_dict, None
704
+
705
+ def create_tensors(self, text, condition, image):
706
+ """
707
+ This function creates tensors for text, condition, and image when they are not provided as inputs to the sample function.
708
+ """
709
+ device = next(
710
+ filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
711
+ None,
712
+ ).device
713
+
714
+ if not isinstance(text, torch.Tensor):
715
+ text = (
716
+ torch.ones(1, self.text_seq_len, device=device, dtype=torch.long)
717
+ * self.text_mask_token
718
+ )
719
+
720
+ if not isinstance(condition, torch.Tensor):
721
+ condition = (
722
+ torch.ones(1, self.condition_seq_len, device=device, dtype=torch.long)
723
+ * self.cond_mask_token
724
+ )
725
+ else:
726
+ with torch.no_grad():
727
+ condition = self.condition_vae.get_codebook_indices(condition)
728
+
729
+ if not isinstance(image, torch.Tensor):
730
+ image = (
731
+ torch.ones(1, self.image_seq_len, device=device, dtype=torch.long)
732
+ * self.image_mask_token
733
+ )
734
+ else:
735
+ with torch.no_grad():
736
+ image = self.vae.get_codebook_indices(image)
737
+
738
+ return text, condition, image
739
+
740
+ @torch.no_grad()
741
+ @eval_decorator
742
+ def sample(
743
+ self,
744
+ text=None,
745
+ condition=None,
746
+ image=None,
747
+ temperature=1.0,
748
+ filter_thres=0.9,
749
+ progress=False,
750
+ timesteps=1,
751
+ force_aas=True,
752
+ ):
753
+ # ensure timesteps is a positive integer
754
+ assert int(timesteps) > 0
755
+ # set model and VAEs to evaluation mode
756
+ self.eval()
757
+ vae = self.vae.eval()
758
+ if progress == True:
759
+ progress = tqdm
760
+ else:
761
+ progress = lambda x: x
762
+
763
+
764
+ # ensure that at least one of text, condition, or image is supplied
765
+ assert (
766
+ isinstance(text, torch.Tensor)
767
+ or isinstance(condition, torch.Tensor)
768
+ or isinstance(image, torch.Tensor)
769
+ ), "some data must be supplied"
770
+
771
+ # convert text, condition, and image to tensors if they aren't already
772
+ text, condition, image = self.create_tensors(text, condition, image)
773
+
774
+ # determine the maximum batch size of the input tensors
775
+ batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
776
+
777
+ # match the batch sizes of text, condition, and image
778
+ text, condition, image = match_batch_size(text, condition, image, batch_size)
779
+
780
+ # determine the device of the tensors
781
+ device = next(
782
+ filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
783
+ None,
784
+ ).device
785
+
786
+ assert text.shape[0] == condition.shape[0] == image.shape[0]
787
+
788
+ # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device
789
+
790
+ # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
791
+ full_text_logits = torch.zeros(
792
+ batch_size, self.text_seq_len, self.num_text_tokens
793
+ ).to(device)
794
+
795
+ # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
796
+ full_text_logits = full_text_logits.scatter_(
797
+ dim=-1, index=text.unsqueeze(-1), value=1
798
+ )
799
+ # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
800
+ full_image_logits = torch.zeros(
801
+ batch_size, self.image_seq_len, self.num_image_tokens + 1
802
+ ).to(device)
803
+
804
+ # Remove the last token from each image sequence by setting full_image_logits to its first num_image_tokens elements
805
+ full_image_logits = full_image_logits.scatter_(
806
+ dim=-1, index=image.unsqueeze(-1), value=1
807
+ )
808
+
809
+ # cut off mask token
810
+ full_image_logits = full_image_logits[:, :, : self.num_image_tokens]
811
+
812
+ count = 0
813
+
814
+ for timestep in progress(torch.linspace(0, 1, timesteps)):
815
+ # Create masks for the text, condition, and image tensors
816
+ text_mask = text == self.text_mask_token
817
+ cond_mask = condition == self.cond_mask_token
818
+ image_mask = image == self.image_mask_token
819
+
820
+ # Calculate logits and samples using the calculate_logits function
821
+ logits, sample = calculate_logits(
822
+ [text, condition, image],
823
+ [text_mask, cond_mask, image_mask],
824
+ self.embed_and_transform,
825
+ filter_thres,
826
+ temperature,
827
+ )
828
+
829
+ # Calculate the number of masked tokens in the text and image tensors
830
+ num_masked_text_tokens = torch.sum(text_mask, dim=1)[0]
831
+ num_masked_image_tokens = torch.sum(image_mask, dim=1)[0]
832
+
833
+ # If there are masked text tokens, unmask them using unmask_tokens and fill the full text logits tensor with -inf for unmasked tokens
834
+ if num_masked_text_tokens.any() > 0:
835
+ text, full_text_logits = unmask_tokens(
836
+ text,
837
+ text_mask,
838
+ num_masked_text_tokens,
839
+ logits[:, : self.text_seq_len, : self.num_text_tokens],
840
+ sample[:, : self.text_seq_len, : self.num_text_tokens],
841
+ timestep,
842
+ timesteps,
843
+ self.gamma,
844
+ suppress_invalid_text_tokens,
845
+ self.pad_token,
846
+ self.text_mask_token,
847
+ force_aas=force_aas,
848
+ )
849
+ full_text_logits = full_text_logits.masked_fill(
850
+ ~text_mask.unsqueeze(-1), -torch.inf
851
+ )
852
+
853
+ # If there are masked image tokens, unmask them using unmask_tokens and fill the full image logits tensor with -inf for unmasked tokens
854
+ if num_masked_image_tokens > 0:
855
+ image, full_image_logits = unmask_tokens(
856
+ image,
857
+ image_mask,
858
+ num_masked_image_tokens,
859
+ logits[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
860
+ sample[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
861
+ timestep,
862
+ timesteps,
863
+ self.gamma,
864
+ )
865
+ full_text_logits = full_text_logits.masked_fill(
866
+ ~text_mask.unsqueeze(-1), -torch.inf
867
+ )
868
+
869
+ # Generate heatmap
870
+ with torch.no_grad():
871
+ # Normalize full image logits tensor
872
+ full_image_logits /= torch.max(
873
+ torch.abs(full_image_logits), dim=-1, keepdim=True
874
+ ).values
875
+
876
+ # Apply quantize embedding to full image logits tensor
877
+ full_image_logits = torch.matmul(
878
+ full_image_logits, self.vae.model.quantize.embedding.weight
879
+ )
880
+
881
+ # Rearrange full image logits tensor
882
+ h = int(self.image_seq_len**0.5)
883
+ full_image_logits = rearrange(
884
+ full_image_logits, "b (h w) c -> b c h w", h=h
885
+ )
886
+
887
+ # Decode full image logits tensor
888
+ full_image_logits = self.vae.model.decode(full_image_logits)
889
+
890
+ # Add clipping to full image logits tensor
891
+ max_val = torch.max(full_image_logits.view(batch_size, -1), dim=-1)[0]
892
+ min_val = torch.min(full_image_logits.view(batch_size, -1), dim=-1)[0]
893
+ full_image_logits += torch.clip(1 - max_val, 0, float("inf")).view(
894
+ batch_size, 1, 1, 1
895
+ )
896
+ full_image_logits += torch.clip(0 - min_val, float("-inf"), 0).view(
897
+ batch_size, 1, 1, 1
898
+ )
899
+
900
+ # Clip full image logits tensor values to the range [0, 1]
901
+ full_image_logits = torch.clip(full_image_logits, 0, 1)
902
+
903
+ # Return text tensor, detokenized text tensor, full text logits tensor,
904
+ # binary image tensor, and full image logits tensor
905
+ return (
906
+ text,
907
+ detokenize_text(self.text_embedding, text),
908
+ full_text_logits,
909
+ 1.0 * (vae.decode(image) > 0.5),
910
+ full_image_logits,
911
+ )
912
+
913
+ @torch.no_grad()
914
+ @eval_decorator
915
+ def sample_text(
916
+ self,
917
+ text=False,
918
+ condition=False,
919
+ image=False,
920
+ temperature=1.0,
921
+ filter_thres=0.9,
922
+ progress=False,
923
+ n_unmask=1,
924
+ place_amino=True,
925
+ force_aas=False,
926
+ ):
927
+ # set model and VAEs to evaluation mode
928
+ self.eval()
929
+
930
+ # ensure that at least one of text, condition, or image is supplied
931
+ assert (
932
+ isinstance(text, torch.Tensor)
933
+ or isinstance(condition, torch.Tensor)
934
+ or isinstance(image, torch.Tensor)
935
+ ), "some data must be supplied"
936
+
937
+ # convert text, condition, and image to tensors if they aren't already
938
+ text, condition, image = self.create_tensors(text, condition, image)
939
+
940
+ # determine the maximum batch size of the input tensors
941
+ batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
942
+
943
+ # match the batch sizes of text, condition, and image
944
+ text, condition, image = match_batch_size(text, condition, image, batch_size)
945
+
946
+ # determine the device of the tensors
947
+ device = next(
948
+ filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
949
+ None,
950
+ ).device
951
+
952
+ assert text.shape[0] == condition.shape[0] == image.shape[0]
953
+
954
+ # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device
955
+
956
+ # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
957
+ full_text_logits = torch.zeros(
958
+ batch_size, self.text_seq_len, self.num_text_tokens
959
+ ).to(device)
960
+
961
+ # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
962
+ full_text_logits = full_text_logits.scatter_(
963
+ dim=-1, index=text.unsqueeze(-1), value=1
964
+ )
965
+
966
+ text_mask = text == self.text_mask_token
967
+ cond_mask = condition == self.cond_mask_token
968
+ image_mask = image == self.image_mask_token
969
+
970
+ mask_indices = text_mask.nonzero()
971
+ non_mask_indices = (~text_mask).nonzero()
972
+
973
+ # figure out the center of the amino acids to determine generation direction
974
+ central_protein_index = torch.tensor(
975
+ [
976
+ torch.median(
977
+ non_mask_indices[torch.where(non_mask_indices[:, 0] == idx)][:, -1]
978
+ )
979
+ for idx in range(batch_size)
980
+ ]
981
+ )
982
+
983
+ count = 1
984
+
985
+ run_mask = text_mask
986
+ if progress:
987
+ pbar = progress(total=torch.sum(run_mask).item())
988
+ while torch.sum(run_mask) > 0:
989
+ logits, sample = calculate_logits(
990
+ [text, condition, image],
991
+ [text_mask, cond_mask, image_mask],
992
+ self.embed_and_transform,
993
+ filter_thres,
994
+ temperature,
995
+ )
996
+
997
+ # sub_sample: [batch_size ,text_seq_len ,num_text_tokens]
998
+ sub_sample = sample[:, : self.text_seq_len, : self.num_text_tokens]
999
+ sub_sample = sub_sample.masked_fill(~text_mask.unsqueeze(-1), -torch.inf)
1000
+ sub_sample = suppress_invalid_text_tokens(
1001
+ text, sub_sample, 0, 2, self.pad_token, self.text_mask_token, force_aas
1002
+ )
1003
+ # calculate % to unmasked
1004
+ # get most likely token and probability for each position
1005
+
1006
+ for idx in range(batch_size):
1007
+ selected_mask_indices = mask_indices[
1008
+ torch.where(mask_indices[:, 0] == idx)
1009
+ ][:, -1]
1010
+
1011
+ # Generate to the left
1012
+ if selected_mask_indices[-count] < central_protein_index[idx]:
1013
+ unmask_index = selected_mask_indices[-count]
1014
+ left_sample = max(0, (unmask_index + 1) - n_unmask)
1015
+ right_sample = min(unmask_index + 1, self.text_seq_len - 1)
1016
+ central_protein_index[idx] = max(
1017
+ 0, central_protein_index[idx] - 0.5 * n_unmask
1018
+ )
1019
+
1020
+ # Generate to the right
1021
+ elif selected_mask_indices[count - 1] > central_protein_index[idx]:
1022
+ unmask_index = selected_mask_indices[count - 1]
1023
+ left_sample = max(0, unmask_index)
1024
+ right_sample = min(unmask_index + n_unmask, self.text_seq_len - 1)
1025
+ central_protein_index[idx] = min(
1026
+ central_protein_index[idx] + 0.5 * n_unmask,
1027
+ self.text_seq_len - 1,
1028
+ )
1029
+
1030
+ # save logits for relevant position
1031
+ full_text_logits[
1032
+ idx, left_sample:right_sample, : self.text_seq_len - 1
1033
+ ] = logits[idx, left_sample:right_sample, : self.num_text_tokens]
1034
+
1035
+ run_mask[idx, left_sample:right_sample] = False
1036
+
1037
+ # you may want to resample the amion acids or calculate marginal probs
1038
+ # if so, set place_amino to false
1039
+ if place_amino:
1040
+ text[idx, left_sample:right_sample] = torch.where(
1041
+ text[idx, left_sample:right_sample] == self.text_mask_token,
1042
+ sub_sample[
1043
+ idx, left_sample:right_sample, : self.num_text_tokens
1044
+ ].argmax(dim=-1),
1045
+ text[idx, left_sample:right_sample],
1046
+ )
1047
+
1048
+ text_mask = run_mask
1049
+
1050
+ count += n_unmask
1051
+
1052
+ if progress:
1053
+ pbar.update(n_unmask)
1054
+ if progress:
1055
+ pbar.close()
1056
+
1057
+ return (
1058
+ text,
1059
+ detokenize_text(self.text_embedding, text),
1060
+ full_text_logits,
1061
+ )
celle/utils.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image, ImageSequence
4
+ from math import pi
5
+ import torchvision.transforms.functional as TF
6
+
7
+
8
+ # Define helper functions
9
+ def exists(val):
10
+ """Check if a variable exists"""
11
+ return val is not None
12
+
13
+
14
+ def uniq(arr):
15
+ return {el: True for el in arr}.keys()
16
+
17
+
18
+ def default(val, d):
19
+ """If a value exists, return it; otherwise, return a default value"""
20
+ return val if exists(val) else d
21
+
22
+
23
+ def max_neg_value(t):
24
+ return -torch.finfo(t.dtype).max
25
+
26
+
27
+ def cast_tuple(val, depth=1):
28
+ if isinstance(val, list):
29
+ val = tuple(val)
30
+ return val if isinstance(val, tuple) else (val,) * depth
31
+
32
+
33
+ def is_empty(t):
34
+ """Check if a tensor is empty"""
35
+ # Return True if the number of elements in the tensor is zero, else False
36
+ return t.nelement() == 0
37
+
38
+
39
+ def masked_mean(t, mask, dim=1):
40
+ """
41
+ Compute the mean of a tensor, masked by a given mask
42
+
43
+ Args:
44
+ t (torch.Tensor): input tensor of shape (batch_size, seq_len, hidden_dim)
45
+ mask (torch.Tensor): mask tensor of shape (batch_size, seq_len)
46
+ dim (int): dimension along which to compute the mean (default=1)
47
+
48
+ Returns:
49
+ torch.Tensor: masked mean tensor of shape (batch_size, hidden_dim)
50
+ """
51
+ t = t.masked_fill(~mask[:, :, None], 0.0)
52
+ return t.sum(dim=1) / mask.sum(dim=1)[..., None]
53
+
54
+
55
+ def set_requires_grad(model, value):
56
+ """
57
+ Set whether or not the model's parameters require gradients
58
+
59
+ Args:
60
+ model (torch.nn.Module): the PyTorch model to modify
61
+ value (bool): whether or not to require gradients
62
+ """
63
+ for param in model.parameters():
64
+ param.requires_grad = value
65
+
66
+
67
+ def eval_decorator(fn):
68
+ """
69
+ Decorator function to evaluate a given function
70
+
71
+ Args:
72
+ fn (callable): function to evaluate
73
+
74
+ Returns:
75
+ callable: the decorated function
76
+ """
77
+
78
+ def inner(model, *args, **kwargs):
79
+ was_training = model.training
80
+ model.eval()
81
+ out = fn(model, *args, **kwargs)
82
+ model.train(was_training)
83
+ return out
84
+
85
+ return inner
86
+
87
+
88
+ def log(t, eps=1e-20):
89
+ """
90
+ Compute the natural logarithm of a tensor
91
+
92
+ Args:
93
+ t (torch.Tensor): input tensor
94
+ eps (float): small value to add to prevent taking the log of 0 (default=1e-20)
95
+
96
+ Returns:
97
+ torch.Tensor: the natural logarithm of the input tensor
98
+ """
99
+ return torch.log(t + eps)
100
+
101
+
102
+ def gumbel_noise(t):
103
+ """
104
+ Generate Gumbel noise
105
+
106
+ Args:
107
+ t (torch.Tensor): input tensor
108
+
109
+ Returns:
110
+ torch.Tensor: a tensor of Gumbel noise with the same shape as the input tensor
111
+ """
112
+ noise = torch.zeros_like(t).uniform_(0, 1)
113
+ return -log(-log(noise))
114
+
115
+
116
+ def gumbel_sample(t, temperature=0.9, dim=-1):
117
+ """
118
+ Sample from a Gumbel-softmax distribution
119
+
120
+ Args:
121
+ t (torch.Tensor): input tensor of shape (batch_size, num_classes)
122
+ temperature (float): temperature for the Gumbel-softmax distribution (default=0.9)
123
+ dim (int): dimension along which to sample (default=-1)
124
+
125
+ Returns:
126
+ torch.Tensor: a tensor of samples from the Gumbel-softmax distribution with the same shape as the input tensor
127
+ """
128
+ return (t / max(temperature, 1e-10)) + gumbel_noise(t)
129
+
130
+
131
+ def top_k(logits, thres=0.5):
132
+ """
133
+ Return a tensor where all but the top k values are set to negative infinity
134
+
135
+ Args:
136
+ logits (torch.Tensor): input tensor of shape (batch_size, num_classes)
137
+ thres (float): threshold for the top k values (default=0.5)
138
+
139
+ Returns:
140
+ torch.Tensor: a tensor with the same shape as the input tensor, where all but the top k values are set to negative infinity
141
+ """
142
+ num_logits = logits.shape[-1]
143
+ k = max(int((1 - thres) * num_logits), 1)
144
+ val, ind = torch.topk(logits, k)
145
+ probs = torch.full_like(logits, float("-inf"))
146
+ probs.scatter_(-1, ind, val)
147
+ return probs
148
+
149
+
150
+ def gamma_func(mode="cosine", scale=0.15):
151
+ """Return a function that takes a single input r and returns a value based on the selected mode"""
152
+
153
+ # Define a different function based on the selected mode
154
+ if mode == "linear":
155
+ return lambda r: 1 - r
156
+ elif mode == "cosine":
157
+ return lambda r: torch.cos(r * pi / 2)
158
+ elif mode == "square":
159
+ return lambda r: 1 - r**2
160
+ elif mode == "cubic":
161
+ return lambda r: 1 - r**3
162
+ elif mode == "scaled-cosine":
163
+ return lambda r: scale * (torch.cos(r * pi / 2))
164
+ else:
165
+ # Raise an error if the selected mode is not implemented
166
+ raise NotImplementedError
167
+
168
+
169
+ class always:
170
+ """Helper class to always return a given value"""
171
+
172
+ def __init__(self, val):
173
+ self.val = val
174
+
175
+ def __call__(self, x, *args, **kwargs):
176
+ return self.val
177
+
178
+
179
+ class DivideMax(torch.nn.Module):
180
+ def __init__(self, dim):
181
+ super().__init__()
182
+ self.dim = dim
183
+
184
+ def forward(self, x):
185
+ maxes = x.amax(dim=self.dim, keepdim=True).detach()
186
+ return x / maxes
187
+
188
+ def replace_outliers(image, percentile=0.0001):
189
+
190
+ lower_bound, upper_bound = torch.quantile(image, percentile), torch.quantile(
191
+ image, 1 - percentile
192
+ )
193
+ mask = (image <= upper_bound) & (image >= lower_bound)
194
+
195
+ valid_pixels = image[mask]
196
+
197
+ image[~mask] = torch.clip(image[~mask], min(valid_pixels), max(valid_pixels))
198
+
199
+ return image
200
+
201
+
202
+ def process_image(image, dataset, image_type=None):
203
+ image = TF.to_tensor(image)[0].unsqueeze(0).unsqueeze(0)
204
+ image /= image.max()
205
+
206
+ if dataset == "HPA":
207
+ if image_type == 'nucleus':
208
+ normalize = (0.0655, 0.0650)
209
+
210
+ elif image_type == 'protein':
211
+ normalize = (0.1732, 0.1208)
212
+
213
+ elif dataset == "OpenCell":
214
+
215
+ if image_type == 'nucleus':
216
+ normalize = (0.0272, 0.0244)
217
+
218
+ elif image_type == 'protein':
219
+ normalize = (0.0486, 0.0671)
220
+
221
+ t_forms = []
222
+
223
+ t_forms.append(transforms.RandomCrop(256))
224
+
225
+ # t_forms.append(transforms.Normalize(normalize[0],normalize[1]))
226
+
227
+
228
+ image = transforms.Compose(t_forms)(image)
229
+
230
+ return image
dataloader.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image, ImageSequence
4
+ import json
5
+ import pandas as pd
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from torchvision import transforms
10
+ import torchvision.transforms.functional as TF
11
+
12
+ from celle.utils import replace_outliers
13
+
14
+ def simple_conversion(seq):
15
+ """Create 26-dim embedding"""
16
+ chars = [
17
+ "-",
18
+ "M",
19
+ "R",
20
+ "H",
21
+ "K",
22
+ "D",
23
+ "E",
24
+ "S",
25
+ "T",
26
+ "N",
27
+ "Q",
28
+ "C",
29
+ "U",
30
+ "G",
31
+ "P",
32
+ "A",
33
+ "V",
34
+ "I",
35
+ "F",
36
+ "Y",
37
+ "W",
38
+ "L",
39
+ "O",
40
+ "X",
41
+ "Z",
42
+ "B",
43
+ "J",
44
+ ]
45
+
46
+ nums = range(len(chars))
47
+
48
+ seqs_x = np.zeros(len(seq))
49
+
50
+ for idx, char in enumerate(seq):
51
+
52
+ lui = chars.index(char)
53
+
54
+ seqs_x[idx] = nums[lui]
55
+
56
+ return torch.tensor([seqs_x]).long()
57
+
58
+
59
+ class CellLoader(Dataset):
60
+ """imports mined opencell images with protein sequence"""
61
+
62
+ def __init__(
63
+ self,
64
+ data_csv=None,
65
+ dataset=None,
66
+ split_key=None,
67
+ resize=600,
68
+ crop_size=600,
69
+ crop_method="random",
70
+ sequence_mode="simple",
71
+ vocab="bert",
72
+ threshold="median",
73
+ text_seq_len=0,
74
+ pad_mode="random",
75
+ ):
76
+ self.data_csv = data_csv
77
+ self.dataset = dataset
78
+ self.image_folders = []
79
+ self.crop_method = crop_method
80
+ self.resize = resize
81
+ self.crop_size = crop_size
82
+ self.sequence_mode = sequence_mode
83
+ self.threshold = threshold
84
+ self.text_seq_len = int(text_seq_len)
85
+ self.vocab = vocab
86
+ self.pad_mode = pad_mode
87
+
88
+ if self.sequence_mode == "embedding" or self.sequence_mode == "onehot":
89
+
90
+
91
+ if self.vocab == "esm1b" or self.vocab == "esm2":
92
+ from esm import Alphabet
93
+
94
+ self.tokenizer = Alphabet.from_architecture(
95
+ "ESM-1b"
96
+ ).get_batch_converter()
97
+ self.text_seq_len += 2
98
+
99
+ if data_csv:
100
+
101
+ data = pd.read_csv(data_csv)
102
+
103
+ self.parent_path = os.path.dirname(data_csv).split(data_csv)[0]
104
+
105
+ if split_key == "train":
106
+ self.data = data[data["split"] == "train"]
107
+ elif split_key == "val":
108
+ self.data = data[data["split"] == "val"]
109
+ else:
110
+ self.data = data
111
+
112
+ self.data = self.data.reset_index(drop=True)
113
+
114
+
115
+
116
+ def __len__(self):
117
+ return len(self.data)
118
+
119
+ def __getitem__(
120
+ self,
121
+ idx,
122
+ get_sequence=True,
123
+ get_images=True,
124
+ ):
125
+ if get_sequence and self.text_seq_len > 0:
126
+
127
+ protein_vector = self.get_protein_vector(idx)
128
+
129
+ else:
130
+ protein_vector = torch.zeros((1, 1))
131
+
132
+ if get_images:
133
+
134
+ nucleus, target, threshold = self.get_images(idx, self.dataset)
135
+ else:
136
+ nucleus, target, threshold = torch.zeros((3, 1))
137
+
138
+ data_dict = {
139
+ "nucleus": nucleus.float(),
140
+ "target": target.float(),
141
+ "threshold": threshold.float(),
142
+ "sequence": protein_vector.long(),
143
+ }
144
+
145
+ return data_dict
146
+
147
+ def get_protein_vector(self, idx):
148
+
149
+ if "protein_sequence" not in self.data.columns:
150
+
151
+ metadata = self.retrieve_metadata(idx)
152
+ protein_sequence = metadata["sequence"]
153
+ else:
154
+ protein_sequence = self.data.iloc[idx]["protein_sequence"]
155
+
156
+ protein_vector = self.tokenize_sequence(protein_sequence)
157
+
158
+ return protein_vector
159
+
160
+ def get_images(self, idx, dataset):
161
+
162
+ if dataset == "HPA":
163
+
164
+ nucleus = Image.open(
165
+ os.path.join(
166
+ self.parent_path, self.data.iloc[idx]["nucleus_image_path"]
167
+ )
168
+ )
169
+
170
+ target = Image.open(
171
+ os.path.join(self.parent_path, self.data.iloc[idx]["target_image_path"])
172
+ )
173
+
174
+ nucleus = TF.to_tensor(nucleus)[0]
175
+ target = TF.to_tensor(target)[0]
176
+
177
+ image = torch.stack([nucleus, target], axis=0)
178
+
179
+ normalize = (0.0655, 0.0650), (0.1732, 0.1208)
180
+
181
+ elif dataset == "OpenCell":
182
+ image = Image.open(
183
+ os.path.join(self.parent_path, self.data.iloc[idx]["image_path"])
184
+ )
185
+ nucleus, target = [page.copy() for page in ImageSequence.Iterator(image)]
186
+
187
+ nucleus = replace_outliers(torch.divide(TF.to_tensor(nucleus), 65536))[0]
188
+ target = replace_outliers(torch.divide(TF.to_tensor(target), 65536))[0]
189
+
190
+ image = torch.stack([nucleus, target], axis=0)
191
+
192
+ normalize = (
193
+ (0.0272, 0.0244),
194
+ (0.0486, 0.0671),
195
+ )
196
+
197
+ # # from https://discuss.pytorch.org/t/how-to-apply-same-transform-on-a-pair-of-picture/14914
198
+
199
+ t_forms = [transforms.Resize(self.resize, antialias=None)]
200
+
201
+ if self.crop_method == "random":
202
+
203
+ t_forms.append(transforms.RandomCrop(self.crop_size))
204
+ t_forms.append(transforms.RandomHorizontalFlip(p=0.5))
205
+ t_forms.append(transforms.RandomVerticalFlip(p=0.5))
206
+
207
+ elif self.crop_method == "center":
208
+
209
+ t_forms.append(transforms.CenterCrop(self.crop_size))
210
+
211
+ t_forms.append(transforms.Normalize(normalize[0], normalize[1]))
212
+
213
+ image = transforms.Compose(t_forms)(image)
214
+
215
+ nucleus, target = image
216
+
217
+ nucleus /= torch.abs(nucleus).max()
218
+ target -= target.min()
219
+ target /= target.max()
220
+
221
+ nucleus = nucleus.unsqueeze(0)
222
+ target = target.unsqueeze(0)
223
+
224
+ threshold = target
225
+
226
+ if self.threshold == "mean":
227
+
228
+ threshold = 1.0 * (threshold > (torch.mean(threshold)))
229
+
230
+ elif self.threshold == "median":
231
+
232
+ threshold = 1.0 * (threshold > (torch.median(threshold)))
233
+
234
+ elif self.threshold == "1090_IQR":
235
+
236
+ p10 = torch.quantile(threshold, 0.1, None)
237
+ p90 = torch.quantile(threshold, 0.9, None)
238
+ threshold = torch.clip(threshold, p10, p90)
239
+
240
+ nucleus = torch.nan_to_num(nucleus, 0.0, 1.0, 0.0)
241
+ target = torch.nan_to_num(target, 0.0, 1.0, 0.0)
242
+ threshold = torch.nan_to_num(threshold, 0.0, 1.0, 0.0)
243
+
244
+ return nucleus, target, threshold
245
+
246
+ def retrieve_metadata(self, idx):
247
+ with open(
248
+ os.path.join(self.parent_path, self.data.iloc[idx]["metadata_path"])
249
+ ) as f:
250
+ metadata = json.load(f)
251
+
252
+ return metadata
253
+
254
+ def tokenize_sequence(self, protein_sequence):
255
+
256
+ pad_token = 0
257
+
258
+ if self.sequence_mode == "simple":
259
+ protein_vector = simple_conversion(protein_sequence)
260
+
261
+ elif self.sequence_mode == "center":
262
+ protein_sequence = protein_sequence.center(self.text_seq_length, "-")
263
+ protein_vector = simple_conversion(protein_sequence)
264
+
265
+ elif self.sequence_mode == "alternating":
266
+ protein_sequence = protein_sequence.center(self.text_seq_length, "-")
267
+ protein_sequence = protein_sequence[::18]
268
+ protein_sequence = protein_sequence.center(
269
+ int(self.text_seq_length / 18) + 1, "-"
270
+ )
271
+ protein_vector = simple_conversion(protein_sequence)
272
+
273
+
274
+ elif self.sequence_mode == "embedding":
275
+
276
+ if self.vocab == "esm1b" or self.vocab == "esm2":
277
+ pad_token = 1
278
+ protein_vector = self.tokenizer([("", protein_sequence)])[-1]
279
+
280
+ if protein_vector.shape[-1] < self.text_seq_len:
281
+
282
+ diff = self.text_seq_len - protein_vector.shape[-1]
283
+
284
+ if self.pad_mode == "end":
285
+ protein_vector = torch.nn.functional.pad(
286
+ protein_vector, (0, diff), "constant", pad_token
287
+ )
288
+ elif self.pad_mode == "random":
289
+ split = diff - np.random.randint(0, diff + 1)
290
+
291
+ protein_vector = torch.cat(
292
+ [torch.ones(1, split) * 0, protein_vector], dim=1
293
+ )
294
+
295
+ protein_vector = torch.nn.functional.pad(
296
+ protein_vector, (0, diff - split), "constant", pad_token
297
+ )
298
+
299
+ elif protein_vector.shape[-1] > self.text_seq_len:
300
+ start_int = np.random.randint(
301
+ 0, protein_vector.shape[-1] - self.text_seq_len
302
+ )
303
+
304
+ protein_vector = protein_vector[
305
+ :, start_int : start_int + self.text_seq_len
306
+ ]
307
+
308
+ return protein_vector.long()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ os
2
+ torch
3
+ torchvision
4
+ huggingface_hub
5
+ gradio
6
+ OmegaConf
7
+ axial-positional-embedding
8
+ einops
9
+ rotary_embedding_torch
10
+ fair-esm
11
+ tqdm
12
+ importlib
13
+ pytorch-lightning==1.9.0