Update app.py
app.py CHANGED
@@ -28,9 +28,62 @@ def generate_caption(protein, prompt):
     # f.write('>{}\n'.format("protein_name"))
     # f.write('{}\n'.format(protein.strip()))
     # os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
-    esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
-                       model=model_esm, alphabet=alphabet,
-                       include='per_tok', repr_layers=[36], truncation_seq_length=1024)
+    # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
+    #                    model=model_esm, alphabet=alphabet,
+    #                    include='per_tok', repr_layers=[36], truncation_seq_length=1024)
+    protein_name='protein_name'
+    protein_seq=protein
+    include='per_tok'
+    repr_layers=[36]
+    truncation_seq_length=1024
+    toks_per_batch=4096
+
+    dataset = FastaBatchedDataset([protein_name], [protein_seq])
+    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
+    data_loader = torch.utils.data.DataLoader(
+        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
+    )
+    print(f"Read sequences")
+    return_contacts = "contacts" in include
+    assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
+    repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
+    with torch.no_grad():
+        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
+            print(
+                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
+            )
+            if torch.cuda.is_available():
+                toks = toks.to(device="cuda", non_blocking=True)
+            out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
+            logits = out["logits"].to(device="cpu")
+            representations = {
+                layer: t.to(device="cpu") for layer, t in out["representations"].items()
+            }
+            if return_contacts:
+                contacts = out["contacts"].to(device="cpu")
+            for i, label in enumerate(labels):
+                result = {"label": label}
+                truncate_len = min(truncation_seq_length, len(strs[i]))
+                # Call clone on tensors to ensure tensors are not views into a larger representation
+                # See https://github.com/pytorch/pytorch/issues/1995
+                if "per_tok" in include:
+                    result["representations"] = {
+                        layer: t[i, 1 : truncate_len + 1].clone()
+                        for layer, t in representations.items()
+                    }
+                if "mean" in include:
+                    result["mean_representations"] = {
+                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
+                        for layer, t in representations.items()
+                    }
+                if "bos" in include:
+                    result["bos_representations"] = {
+                        layer: t[i, 0].clone() for layer, t in representations.items()
+                    }
+                if return_contacts:
+                    result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
+                esm_emb = result['representations'][36]
+
     print("esm embedding generated")
     esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
     print("esm embedding processed")
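
The added block inlines the per-token branch of ESM's esm_scripts/extract.py in place of the previous run_demo call. As a point of reference only, here is a minimal standalone sketch of the same extraction loop; it assumes model_esm and alphabet come from fair-esm's esm.pretrained.esm2_t36_3B_UR50D() and that FastaBatchedDataset is imported from the esm package, and the example sequence is purely illustrative, not taken from app.py.

# Hedged sketch of the setup the inlined extraction above assumes (not part of the commit).
import torch
import esm
from esm import FastaBatchedDataset

# esm2_t36_3B_UR50D is the checkpoint named in the commented-out extract.py command.
model_esm, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
model_esm.eval()
if torch.cuda.is_available():
    model_esm = model_esm.cuda()

# Single-sequence dataset, batched the same way as in the diff.
toy_seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # illustrative sequence only
dataset = FastaBatchedDataset(["protein_name"], [toy_seq])
batches = dataset.get_batch_indices(4096, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=alphabet.get_batch_converter(1024),
    batch_sampler=batches,
)

with torch.no_grad():
    for labels, strs, toks in data_loader:
        if torch.cuda.is_available():
            toks = toks.to(device="cuda", non_blocking=True)
        out = model_esm(toks, repr_layers=[36])
        # Per-token embedding from layer 36, dropping the BOS token and padding,
        # mirroring the `t[i, 1 : truncate_len + 1]` slice in the diff.
        esm_emb = out["representations"][36][0, 1 : len(strs[0]) + 1].cpu()
        print(esm_emb.shape)  # (len(toy_seq), 2560) for the 3B model

In app.py itself, the resulting esm_emb is then zero-padded along the sequence dimension to a fixed length of 1024 with F.pad and moved to the GPU, as in the unchanged lines above.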