"""Gradio app that predicts Rubisco kinetic properties from a protein sequence.

A sequence is embedded with an ESM protein language model (mean-pooled hidden
state of a configured layer) and the embedding is passed to one regressor per
target (Specificity, kcatC, KC).  Two model families are available:

* "Plant-Specific": joblib-pickled Random Forest models.
* "General": TensorFlow SavedModels exposing ``predict_f_compiled`` (GPflow),
  which also return a predictive variance.
"""

import os
import random
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache

import gradio as gr
import joblib
from transformers import AutoTokenizer, AutoModel, EsmModel
import torch
import numpy as np
import tensorflow as tf
from keras.layers import TFSMLayer

print(f"TensorFlow Version: {tf.__version__}")

base_dir = "."

# Set random seed
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def load_model(model_path):
    """Load a TensorFlow SavedModel directory and return the restored object.

    Args:
        model_path: Path of the SavedModel directory.

    Returns:
        The object returned by ``tf.saved_model.load``.
    """
    print(f"Loading model from {model_path}...")
    return tf.saved_model.load(model_path)


# Load Random Forest models and configurations.
# NOTE: joblib.load unpickles the files; only ship trusted .pkl artifacts.
print("Loading models...")
plant_models = {
    "Specificity": {
        "model": joblib.load("Specificity.pkl"),
        "esm_model": "facebook/esm1b_t33_650M_UR50S",
        "layer": 6,
    },
    "kcatC": {
        "model": joblib.load("kcatC.pkl"),
        "esm_model": "facebook/esm2_t36_3B_UR50D",
        "layer": 11,
    },
    "KC": {
        "model": joblib.load("KC.pkl"),
        "esm_model": "facebook/esm1b_t33_650M_UR50S",
        "layer": 4,
    },
}

general_models = {
    "Specificity": {
        "model": load_model("Specificity"),
        "esm_model": "facebook/esm2_t33_650M_UR50D",
        "layer": 33,
    },
    "kcatC": {
        "model": load_model("kcatC"),
        "esm_model": "facebook/esm2_t12_35M_UR50D",
        "layer": 7,
    },
    "KC": {
        "model": load_model("KC"),
        "esm_model": "facebook/esm2_t30_150M_UR50D",
        "layer": 26,
    },
}

# Display label for each target key used in the model registries above.
_TARGET_LABELS = {
    "Specificity": "Specificity",
    "kcatC": "kcat\u1d9c",
    "KC": "K\u1d9c",
}


@lru_cache(maxsize=None)
def _load_esm(esm_model_name):
    """Load (once per checkpoint name) the tokenizer and ESM encoder.

    Loading a checkpoint is by far the most expensive step of a request, so
    the pair is memoized instead of being rebuilt on every call.  The key set
    is limited to the checkpoint names in the registries above, so the cache
    cannot grow without bound; the trade-off is that every checkpoint used so
    far stays resident in memory.
    """
    tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
    model = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)
    model.eval()  # inference only; from_pretrained already does this, kept explicit
    return tokenizer, model


# Function to generate embeddings
def get_embedding(sequence, esm_model_name, layer):
    """Return the mean-pooled hidden state of one ESM layer for a sequence.

    Args:
        sequence: Protein sequence string; tokenized with truncation to
            1024 tokens.
        esm_model_name: Hugging Face checkpoint name of the ESM model.
        layer: Index into the model's ``hidden_states`` tuple.

    Returns:
        A NumPy array holding the selected hidden state averaged over the
        token dimension (``dim=1``), i.e. one row per input sequence.
    """
    print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")
    tokenizer, model = _load_esm(esm_model_name)

    # Tokenize the sequence
    inputs = tokenizer(
        sequence, return_tensors="pt", truncation=True, max_length=1024
    )

    # Generate embeddings without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)

    hidden_states = outputs.hidden_states  # all hidden states
    # Average pooling over every token position the tokenizer produced.
    return hidden_states[layer].mean(dim=1).numpy()


def predict_with_gpflow(model, X):
    """Run a saved GPflow model and return flattened mean and variance.

    Args:
        model: Loaded SavedModel exposing ``predict_f_compiled``.
        X: Array-like input features; converted to a float64 tensor.

    Returns:
        Tuple ``(mean, variance)`` of 1-D NumPy arrays.
    """
    # Convert input to TensorFlow tensor
    X_tensor = tf.convert_to_tensor(X, dtype=tf.float64)
    mean, variance = model.predict_f_compiled(X_tensor)
    return mean.numpy().flatten(), variance.numpy().flatten()


# Function to predict based on user choice
def predict(sequence, prediction_type):
    """Predict all targets for a sequence with the chosen model family.

    Args:
        sequence: Protein sequence entered by the user.
        prediction_type: ``"Plant-Specific"`` selects the Random Forest
            models; any other value selects the GPflow models.

    Returns:
        A list of rows, one per target: ``[label, prediction]`` for
        Plant-Specific, ``[label, mean, variance]`` otherwise.

    Raises:
        gr.Error: If the sequence is empty or only whitespace.
    """
    if not sequence or not sequence.strip():
        # Without this guard the models would "predict" from special tokens only.
        raise gr.Error("Please enter a protein sequence.")

    plant_specific = prediction_type == "Plant-Specific"
    # Select the appropriate model set
    selected_models = plant_models if plant_specific else general_models

    def process_target(target):
        config = selected_models[target]
        embedding = get_embedding(sequence, config["esm_model"], config["layer"])
        model = config["model"]

        if plant_specific:
            # Random Forest prediction
            prediction = model.predict(embedding)[0]
            return target, round(prediction, 2)

        # GPflow prediction
        mean, variance = predict_with_gpflow(model, embedding)
        return target, round(mean[0], 2), round(variance[0], 2)

    # Predict for all targets in parallel; map() preserves registry order.
    with ThreadPoolExecutor() as executor:
        results = list(executor.map(process_target, selected_models))

    # Format results: swap the target key for its display label and keep
    # the numeric columns as returned.
    return [
        [_TARGET_LABELS.get(target, target), *values]
        for target, *values in results
    ]


# Define Gradio interface
print("Creating Gradio interface...")
interface = gr.Interface(
    fn=predict,
    inputs=[
        # Input: text box for the sequence
        gr.Textbox(label="Input Protein Sequence"),
        # Radio buttons to choose the model family
        gr.Radio(
            choices=["Plant-Specific", "General"],
            label="Prediction Type",
            value="Plant-Specific",
        ),
    ],
    # Output: table
    outputs=gr.Dataframe(
        headers=["Target", "Prediction", "Uncertainty (for General)"],
        type="array",
    ),
    title="Rubisco Kinetics Prediction",
    description=(
        "Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). "
        "Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions."
    ),
)

if __name__ == "__main__":
    interface.launch()