# Spaces status: Sleeping
import functools
import os
import random
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import joblib
from transformers import AutoTokenizer, AutoModel, EsmModel
import torch
import numpy as np
import tensorflow as tf
from keras.layers import TFSMLayer
import pandas as pd
base_dir = "." | |
# Set random seed | |
SEED = 42 | |
np.random.seed(SEED) | |
random.seed(SEED) | |
torch.manual_seed(SEED) | |
if torch.cuda.is_available(): | |
torch.cuda.manual_seed(SEED) | |
torch.cuda.manual_seed_all(SEED) | |
# Ensure deterministic behavior | |
torch.backends.cudnn.deterministic = True | |
torch.backends.cudnn.benchmark = False | |
def load_model(model_path):
    """Load a TensorFlow SavedModel and return the restored object.

    Args:
        model_path: Path handed unchanged to ``tf.saved_model.load``.

    Returns:
        Whatever ``tf.saved_model.load`` returns for ``model_path``.
    """
    # Announce the path first so a slow or failing load is easy to attribute.
    print(f"Loading model from {model_path}...")
    restored = tf.saved_model.load(model_path)
    return restored
print("Loading models...")

# One row per target: (target name, ESM checkpoint, hidden-state index).
_PLANT_SPECS = (
    ("Specificity", "facebook/esm1b_t33_650M_UR50S", 6),
    ("kcatC", "facebook/esm2_t36_3B_UR50D", 11),
    ("KC", "facebook/esm1b_t33_650M_UR50S", 4),
)
_GENERAL_SPECS = (
    ("Specificity", "facebook/esm2_t33_650M_UR50D", 33),
    ("kcatC", "facebook/esm2_t12_35M_UR50D", 7),
    ("KC", "facebook/esm2_t30_150M_UR50D", 26),
)

# Registry used for "Plant-Specific" predictions: each estimator is read from
# "<target>.pkl" with joblib.
# NOTE: joblib.load unpickles its input; only ship trusted .pkl files.
plant_models = {
    target: {
        "model": joblib.load(f"{target}.pkl"),
        "esm_model": checkpoint,
        "layer": layer,
    }
    for target, checkpoint, layer in _PLANT_SPECS
}

# Registry used for every other prediction type: each entry is a SavedModel
# directory named after its target.
general_models = {
    target: {
        "model": load_model(target),
        "esm_model": checkpoint,
        "layer": layer,
    }
    for target, checkpoint, layer in _GENERAL_SPECS
}
# At most three checkpoints are needed per request (one per target), so the
# cache is bounded to three entries to cap resident model memory.
@functools.lru_cache(maxsize=3)
def _load_esm(esm_model_name):
    """Load the tokenizer and ESM encoder for one checkpoint, memoized by name."""
    tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
    model = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)
    return tokenizer, model


# Function to generate embeddings
def get_embedding(sequence, esm_model_name, layer):
    """Embed a protein sequence with one hidden state of an ESM checkpoint.

    The tokenizer and encoder are fetched through ``_load_esm`` so a
    checkpoint is not re-read from disk on every call.

    Args:
        sequence: Protein sequence string passed to the tokenizer. Inputs
            longer than 1024 tokens are truncated.
        esm_model_name: Hugging Face checkpoint name to load.
        layer: Index into ``outputs.hidden_states`` selecting which hidden
            state to pool.

    Returns:
        A ``(values, frame)`` tuple: ``frame`` is a one-row DataFrame whose
        columns are named ``D1`` .. ``Dn``, and ``values`` is ``frame.values``.
    """
    print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")
    tokenizer, model = _load_esm(esm_model_name)

    # Tokenize the sequence
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        outputs = model(**inputs)

    hidden_states = outputs.hidden_states  # Retrieve all hidden states
    embedding = hidden_states[layer].mean(dim=1).numpy()  # Average pooling

    # Convert to DataFrame with named columns (D1, D2, ...)
    feature_columns = {
        f"D{position}": value
        for position, value in enumerate(embedding[0], start=1)
    }
    embedding_df = pd.DataFrame([feature_columns])
    print(embedding_df)
    return embedding_df.values, embedding_df
def predict_with_gpflow(model, X):
    """Run a loaded SavedModel's ``serving_default`` signature on ``X``.

    Args:
        model: Object exposing a ``signatures`` mapping that contains
            ``"serving_default"``.
        X: Array-like input; converted to a float64 tensor and passed to the
            signature as ``Xnew``.

    Returns:
        ``(mean, variance)`` as flattened NumPy arrays, read from the
        signature outputs ``"output_0"`` and ``"output_1"`` respectively.
        Adjust the keys if the exported model names its outputs differently.
    """
    print(model.signatures)

    inputs = tf.convert_to_tensor(X, dtype=tf.float64)
    print(inputs.shape)

    serving_fn = model.signatures["serving_default"]
    outputs = serving_fn(Xnew=inputs)  # Xnew must be passed by keyword

    mean = outputs["output_0"].numpy()
    variance = outputs["output_1"].numpy()
    return mean.flatten(), variance.flatten()
def process_target(target, selected_models, sequence, prediction_type):
    """Embed ``sequence`` and predict one kinetic target.

    Args:
        target: Key into ``selected_models`` naming the target to predict.
        selected_models: Registry mapping each target to a dict with the keys
            ``"esm_model"``, ``"layer"`` and ``"model"``.
        sequence: Protein sequence forwarded to ``get_embedding``.
        prediction_type: ``"Plant-Specific"`` selects the ``model.predict``
            path; any other value selects ``predict_with_gpflow``.

    Returns:
        ``(target, prediction)`` for ``"Plant-Specific"``, otherwise
        ``(target, prediction, uncertainty)``; numbers are rounded to two
        decimals.
    """
    entry = selected_models[target]
    esm_model_name, layer, model = entry["esm_model"], entry["layer"], entry["model"]

    # Build the feature row as float64 and keep a copy on disk per target.
    features, _ = get_embedding(sequence, esm_model_name, layer)
    features = features.astype(np.float64)
    np.save(f"hf_embedding_{target}.npy", features)

    if prediction_type != "Plant-Specific":
        # GPflow prediction
        print(esm_model_name)
        print(layer)
        print(model)
        mean, uncertainty = predict_with_gpflow(model, features)
        return target, round(mean[0], 2), round(uncertainty[0], 2)

    # Random Forest prediction
    point_estimate = model.predict(features)[0]
    return target, round(point_estimate, 2)
def predict(sequence, prediction_type):
    """Predict Specificity, kcatC and KC for a sequence.

    Args:
        sequence: Protein sequence to score.
        prediction_type: ``"Plant-Specific"`` uses ``plant_models``; any other
            value uses ``general_models``.

    Returns:
        A list of three rows in the order Specificity, kcat\u1d9c, K\u1d9c.
        Rows are ``[label, prediction]`` for ``"Plant-Specific"`` and
        ``[label, prediction, uncertainty]`` otherwise.
    """
    plant_specific = prediction_type == "Plant-Specific"
    if plant_specific:
        selected_models = plant_models
    else:
        selected_models = general_models

    def run_target(target):
        return process_target(target, selected_models, sequence, prediction_type)

    # Targets are processed in parallel; map() yields results in key order.
    with ThreadPoolExecutor() as executor:
        results = list(executor.map(run_target, selected_models.keys()))

    # Each result starts with the target key; the display label replaces it.
    labels = ("Specificity", "kcat\u1d9c", "K\u1d9c")
    row_end = 2 if plant_specific else 3
    formatted_results = [
        [label, *results[position][1:row_end]]
        for position, label in enumerate(labels)
    ]
    return formatted_results
# Define Gradio interface
print("Creating Gradio interface...")

# Input 1: multi-line text box pre-filled with an example sequence.
_sequence_input = gr.Textbox(
    label="Input Protein Sequence",
    value="MSPQTETKASVGFKAGVKEYKLTYYTPEYETKDTDILAAFRVTPQPGVPPEEAGAAVAAESSTGTWTTVWTDGLTSLDRYKGRCYHIEPVPGEETQFIAYVAYPLDLFEEGSVTNMFTSIVGNVFGFKALAALRLEDLRIPPAYTKTFQGPPHGIQVERDKLNKYGRPLLGCTIKPKLGLSAKNYGRAVYECLRGGLDFTKDDENVNSQPFMRWRDRFLFCAEAIYKSQAETGEIKGHYLNATAGTCEEMIKRAVFARELGVPIVMHDYLTGGFTANTSLSHYCRDNGLLLHIHRAMHAVIDRQKNHGMHFRVLAKALRLSGGDHIHAGTVVGKLEGDRESTLGFVDLLRDDYVEKDRSRGIFFTQDWVSLPGVLPVASGGIHVWHMPALTEIFGDDSVLQFGGGTLGHPWGNAPGAVANRVALEACVQARNEGRDLAVEGNEIIREACKWSPELAAACEVWKEITFNFPTIDKLDGQE",
    lines=10,
)

# Input 2: radio buttons choosing which model registry to use.
_prediction_type_input = gr.Radio(
    choices=["Plant-Specific", "General"],
    label="Prediction Type",
    value="Plant-Specific",
)

# Output: results table; the third column is only filled for "General".
_results_table = gr.Dataframe(
    headers=["Target", "Prediction", "Uncertainty (for General)"],
    type="array",
)

interface = gr.Interface(
    fn=predict,
    inputs=[_sequence_input, _prediction_type_input],
    outputs=_results_table,
    title="Rubisco Kinetics Prediction",
    description=(
        "Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). "
        "Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions."
    ),
)

if __name__ == "__main__":
    interface.launch()