Update app.py
Browse files
app.py
CHANGED
@@ -2,23 +2,51 @@ import gradio as gr
|
|
2 |
import sys
|
3 |
import random
|
4 |
import os
|
|
|
|
|
|
|
|
|
5 |
sys.path.append("scripts/")
|
6 |
from foldseek_util import get_struc_seq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# Assuming 'predict_stability' is your function that predicts protein stability
|
9 |
-
def predict_stability(model_choice, organism_choice, pdb_file=None, sequence=None):
|
10 |
# Check if pdb_file is provided
|
11 |
if pdb_file:
|
12 |
pdb_path = pdb_file.name # Get the path of the uploaded PDB file
|
13 |
os.system("chmod 777 bin/foldseek")
|
14 |
-
|
15 |
-
if not
|
16 |
return "Failed to extract sequence from the PDB file."
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
# If sequence is provided directly
|
19 |
if sequence:
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
22 |
else:
|
23 |
return "No valid input provided."
|
24 |
|
@@ -33,6 +61,56 @@ def get_foldseek_seq(pdb_path):
|
|
33 |
return parsed_seqs
|
34 |
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
# Gradio Interface
|
37 |
with gr.Blocks() as demo:
|
38 |
gr.Markdown(
|
@@ -47,7 +125,7 @@ with gr.Blocks() as demo:
|
|
47 |
# Model and Organism selection in the same row to avoid layout issues
|
48 |
with gr.Row():
|
49 |
model_choice = gr.Radio(
|
50 |
-
choices=["SaProt", "
|
51 |
label="Select PLTNUM's base model.",
|
52 |
value="SaProt"
|
53 |
)
|
@@ -82,7 +160,7 @@ with gr.Blocks() as demo:
|
|
82 |
gr.Markdown(
|
83 |
"""
|
84 |
### How to Use:
|
85 |
-
- **Select Model**: Choose between 'SaProt' or '
|
86 |
- **Select Organism**: Choose between 'Mouse' or 'Human'.
|
87 |
- **Upload PDB File**: Choose the 'Upload PDB File' tab and upload your file.
|
88 |
- **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.
|
|
|
2 |
import sys
|
3 |
import random
|
4 |
import os
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from transformers import AutoTokenizer
|
9 |
sys.path.append("scripts/")
|
10 |
from foldseek_util import get_struc_seq
|
11 |
+
from utils import seed_everything
|
12 |
+
from models import PLTNUM_PreTrainedModel
|
13 |
+
from datasets import PLTNUMDataset
|
14 |
+
|
15 |
+
class Config:
    """Default inference settings consumed by ``predict``.

    The attributes are read (and some are overwritten) by ``predict`` and
    ``predict_stability``; additional attributes such as ``model_path``,
    ``architecture``, ``device`` and ``tokenizer`` are attached at runtime.
    """

    batch_size = 2  # forwarded to the DataLoader as batch_size
    use_amp = False  # enables autocast during inference when True
    num_workers = 1  # forwarded to the DataLoader as num_workers
    max_length = 512  # predict adds 1 to this when used_sequence == "both"
    used_sequence = "left"  # predict only special-cases the value "both"
    padding_side = "right"  # forwarded to AutoTokenizer.from_pretrained
    task = "classification"  # "classification" applies a sigmoid to the model output
    sequence_col = "sequence"  # DataFrame column name that holds the input sequence
    # predict calls seed_everything(cfg.seed); without this attribute that call
    # raised AttributeError. NOTE(review): 42 is an assumed default — confirm
    # against the seed used by the training/prediction scripts.
    seed = 42
|
24 |
|
25 |
# Assuming 'predict_stability' is your function that predicts protein stability
|
26 |
+
def predict_stability(cfg, model_choice, organism_choice, pdb_file=None, sequence=None):
    """Predict protein stability from an uploaded PDB file or a raw sequence.

    Args:
        cfg: Settings object. ``model``, ``architecture`` and ``model_path``
            are set on it in place before it is handed to ``predict``.
        model_choice: Name of the base model. "SaProt" selects the third
            sequence returned by ``get_foldseek_seq``; any other value
            selects the first one.
        organism_choice: "Human" selects the HeLa checkpoint; any other
            value selects the NIH3T3 checkpoint.
        pdb_file: Optional uploaded file object exposing its path as
            ``.name``. When given, it takes precedence over ``sequence``.
        sequence: Optional sequence string used when no PDB file is given.

    Returns:
        A human-readable message: the prediction result on success, or an
        explanation of why no prediction could be made.
    """
    if pdb_file:
        pdb_path = pdb_file.name  # path of the uploaded PDB file
        try:
            # Best effort: make the bundled binary executable without
            # spawning a shell and without making it world-writable.
            os.chmod("bin/foldseek", 0o755)
        except OSError:
            # Same outcome as the previous unchecked `chmod` shell call:
            # carry on and let the extraction step report the failure.
            pass
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            return "Failed to extract sequence from the PDB file."
        # NOTE(review): assumes get_foldseek_seq returns at least three
        # entries, with index 2 being the SaProt-style sequence and index 0
        # the plain one — confirm against get_foldseek_seq.
        sequence = sequences[2] if model_choice == "SaProt" else sequences[0]

    if not sequence:
        return "No valid input provided."

    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"

    checkpoint = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.model = checkpoint
    cfg.architecture = model_choice
    cfg.model_path = checkpoint

    # Report the actual prediction; previously the result was computed and
    # then dropped in favour of a placeholder message.
    output = predict(cfg, sequence)
    return f"Predicted Stability using {model_choice} for {organism_choice}: {output}"
|
52 |
|
|
|
61 |
return parsed_seqs
|
62 |
|
63 |
|
64 |
+
def predict(cfg, sequence):
    """Run the PLTNUM model selected by ``cfg`` on a single sequence.

    Args:
        cfg: Settings object. Reads ``architecture``, ``used_sequence``,
            ``max_length``, ``sequence_col``, ``model_path``,
            ``padding_side``, ``batch_size``, ``num_workers``, ``use_amp``,
            ``task`` and (optionally) ``seed``. Sets ``token_length``,
            ``device`` and ``tokenizer`` on it in place.
        sequence: The input sequence to score.

    Returns:
        A dict with two keys: "raw prediction values" (list of floats) and
        "binary prediction values" (list of 0/1, thresholded at 0.5).
    """
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): this mutates cfg in place, so when the same cfg object is
    # reused with used_sequence == "both", max_length grows by one per call.
    if cfg.used_sequence == "both":
        cfg.max_length += 1

    # Fall back to a fixed seed when cfg does not define one, instead of
    # failing with AttributeError.
    seed_everything(getattr(cfg, "seed", 42))

    df = pd.DataFrame({cfg.sequence_col: [sequence]})

    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    cfg.tokenizer = tokenizer

    dataset = PLTNUMDataset(cfg, df, train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        # Pinned memory only helps host-to-GPU copies; on CPU it just warns.
        pin_memory=cfg.device == "cuda",
        drop_last=False,
    )

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
    model.to(cfg.device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(cfg.device)
            # autocast requires the device type as its first argument.
            with torch.amp.autocast(device_type=cfg.device, enabled=cfg.use_amp):
                preds = (
                    torch.sigmoid(model(inputs))
                    if cfg.task == "classification"
                    else model(inputs)
                )
            # Flatten so that (batch,) and (batch, 1) outputs both yield plain
            # floats, which the threshold comparison below relies on.
            predictions += preds.cpu().reshape(-1).tolist()

    outputs = {}
    outputs["raw prediction values"] = predictions
    outputs["binary prediction values"] = [1 if x > 0.5 else 0 for x in predictions]
    return outputs
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
# Gradio Interface
|
115 |
with gr.Blocks() as demo:
|
116 |
gr.Markdown(
|
|
|
125 |
# Model and Organism selection in the same row to avoid layout issues
|
126 |
with gr.Row():
|
127 |
model_choice = gr.Radio(
|
128 |
+
choices=["SaProt", "ESM2"],
|
129 |
label="Select PLTNUM's base model.",
|
130 |
value="SaProt"
|
131 |
)
|
|
|
160 |
gr.Markdown(
|
161 |
"""
|
162 |
### How to Use:
|
163 |
+
- **Select Model**: Choose between 'SaProt' or 'ESM2' for your prediction.
|
164 |
- **Select Organism**: Choose between 'Mouse' or 'Human'.
|
165 |
- **Upload PDB File**: Choose the 'Upload PDB File' tab and upload your file.
|
166 |
- **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.
|