# theunknowntechie1's picture
# Update app.py
# b2e58bb verified
# Set environment variables **before** importing libraries: TensorFlow reads
# its TF_* variables when it is first imported, so these assignments must run
# before the `import tensorflow` below.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Hide TensorFlow C++ INFO and WARNING log messages (errors still shown)
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # Disable oneDNN custom ops, which can cause small run-to-run numeric differences
os.environ['PYTHONHASHSEED'] = '0' # NOTE(review): Python reads this at interpreter startup, so setting it here presumably affects only child processes, not this one — confirm
os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1' # NOTE(review): this controls PyTorch's CUDA caching allocator, not RNG; likely a no-op when no GPU is used — confirm it is needed
# Set deterministic behavior for PyTorch
import torch
import numpy as np
import tensorflow as tf
# Seed every RNG before any model is constructed: GNN() draws random initial
# weights (later overwritten by load_state_dict) and dropout draws from torch's
# RNG whenever a model is in training mode.
torch.manual_seed(42)
torch.use_deterministic_algorithms(True) # ops without a deterministic implementation raise instead of running nondeterministically
np.random.seed(42)
tf.random.set_seed(42)
# Import remaining libraries
import json
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F
# Define GNN model architecture
class GNN(torch.nn.Module):
    """Graph classifier: two GCN layers, mean pooling, then a linear head.

    Node features go from 4 channels to 16, then 32, through the two
    ``GCNConv`` layers. They are averaged per graph with ``global_mean_pool``
    and mapped to 2 output scores by ``fc``.

    The attribute names ``conv1``, ``conv2`` and ``fc`` determine the
    ``state_dict`` keys, so they must match the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        # Self-loops are not added inside the convolutions; presumably this
        # matches how the checkpoint was trained — TODO confirm before changing.
        self.conv1 = GCNConv(4, 16, add_self_loops=False)
        self.conv2 = GCNConv(16, 32, add_self_loops=False)
        self.fc = torch.nn.Linear(32, 2)

    def forward(self, data):
        """Return 2 unnormalised scores per graph identified by ``data.batch``."""
        feats, edges, graph_ids = data.x, data.edge_index, data.batch
        for layer in (self.conv1, self.conv2):
            feats = F.relu(layer(feats, edges))
            # Dropout only fires in training mode (self.training is False after eval()).
            feats = F.dropout(feats, p=0.5, training=self.training)
        pooled = global_mean_pool(feats, graph_ids)
        return self.fc(pooled)
# Locations of the serialized model artifacts. They are relative paths, so they
# are resolved against the process's current working directory.
cnn_model_path, improved_model_path, gnn_model_path = (
    "final_cnn_model.h5",
    "improved_model.h5",
    "higgs_gnn_model_cpu.pth",
)
# Placeholders: each model stays None unless its load succeeds; the predict_*
# functions check for None before using a model.
cnn_model = improved_model = gnn_model = None
# Load each model independently so that one missing or corrupt artifact does not
# prevent the others from being served. A failure is printed and leaves the
# corresponding global at its previous value (None).
try:
    cnn_model = tf.keras.models.load_model(cnn_model_path)
    print(f"CNN model loaded successfully from {cnn_model_path}")
except Exception as e:
    print(f"Error loading CNN model: {e}")
try:
    improved_model = tf.keras.models.load_model(improved_model_path)
    print(f"Improved model loaded successfully from {improved_model_path}")
except Exception as e:
    print(f"Error loading Improved model: {e}")


def _load_gnn(path):
    """Build a GNN, load its weights from *path* onto the CPU, and return it in eval mode."""
    model = GNN()
    model.load_state_dict(torch.load(path, map_location='cpu'))
    model.eval()  # disable dropout for inference
    return model


try:
    # Assign only after the weights have loaded. Assigning GNN() to gnn_model
    # first meant a failed load left a randomly initialised network in
    # gnn_model, which predict_graph() would then silently use.
    gnn_model = _load_gnn(gnn_model_path)
    print(f"GNN model loaded successfully from {gnn_model_path}")
except Exception as e:
    print(f"Error loading GNN model: {e}")
# Prediction functions: each returns a result string, or an error string, for display in the Gradio UI.
def predict_image(image):
    """Run the CNN on a PIL image and report "Signal" or "Background".

    The image is converted to grayscale ("L"), resized to 25x25, and scaled by
    1/255. It is passed to ``cnn_model`` as a batch of one with a trailing
    channel axis, and the first output value is compared against 0.5.

    Returns a result string, or a string starting with "Error" on failure.
    """
    if image is None:
        return "Error: No image provided"
    if cnn_model is None:
        return "Error: CNN model failed to load. Please check the model file."
    try:
        gray = image.convert("L").resize((25, 25))
        pixels = np.array(gray) / 255.0
        batch = pixels[None, :, :, None]  # (H, W) -> (1, H, W, 1)
        score = cnn_model.predict(batch)[0][0]
        verdict = "Background"
        if score > 0.5:
            verdict = "Signal"
        # NOTE(review): the percentage is score * 100 even when the verdict is
        # "Background", i.e. it is not the confidence in the shown label — confirm intent.
        return f"CNN Prediction: {verdict} ({score * 100:.2f}%)"
    except Exception as err:
        return f"Error processing image: {str(err)}"
def predict_numerical(data):
    """Classify one comma-separated feature vector with ``improved_model``.

    Args:
        data: text such as ``"0.1, 2.5, 3"``; every comma-separated field is
            parsed with ``float``.

    Returns:
        A string with the arg-max class and the model's output at that index
        (labelled "Confidence"). If the input is empty, the model is missing,
        the feature count differs from ``improved_model.input_shape[1]``, or
        parsing/prediction fails, an "Error..." string is returned instead.
    """
    if not data:
        return "Error: No numerical data provided"
    if improved_model is None:
        return "Error: Improved model failed to load. Please check the model file."
    try:
        values = list(map(float, data.split(",")))
        features = np.array(values, dtype=np.float32)
        n_expected = improved_model.input_shape[1]
        if features.size != n_expected:
            return f"Error: Expected {n_expected} features, got {features.size}."
        scores = improved_model.predict(features.reshape(1, n_expected))
        best = np.argmax(scores, axis=1)[0]
        best_score = scores[0][best]
        return f"Improved Model Prediction: Class {best} (Confidence: {best_score:.4f})"
    except Exception as err:
        return f"Error processing numerical data: {str(err)}"
def predict_graph(graph_json):
    """Score a single graph, given as JSON text, with the GNN model.

    Args:
        graph_json: JSON object with two keys:
            ``"x"``: node features, one row per node.
            ``"edge_index"``: two rows, source node indices then target node
            indices. ``"edge_index": []`` is accepted for a graph without edges.

    Returns:
        ``"GNN Prediction: Signal Probability = <p>"``, where ``<p>`` is the
        softmax probability of output class 1. Otherwise, a string starting
        with "Error" that explains why no prediction was made.
    """
    if not graph_json:
        return "Error: No graph JSON provided"
    if gnn_model is None:
        return "Error: GNN model failed to load"
    try:
        graph_data = json.loads(graph_json)
        x = torch.tensor(graph_data['x'], dtype=torch.float32)
        edge_index = torch.tensor(graph_data['edge_index'], dtype=torch.long)

        if x.dim() != 2:
            return "Error processing graph: 'x' must be a list of per-node feature lists."
        if edge_index.numel() == 0:
            # JSON "[]" becomes a tensor of shape (0,), but message passing
            # expects edge_index shaped (2, num_edges); normalise to (2, 0).
            edge_index = torch.empty((2, 0), dtype=torch.long)
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            return ("Error processing graph: 'edge_index' must have two rows "
                    "(source indices, target indices).")
        num_nodes = x.size(0)
        if edge_index.numel() > 0 and (
            int(edge_index.min()) < 0 or int(edge_index.max()) >= num_nodes
        ):
            return (f"Error processing graph: 'edge_index' references nodes "
                    f"outside 0..{num_nodes - 1}.")

        data = Data(x=x, edge_index=edge_index)
        # A single graph: every node belongs to batch entry 0.
        data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
        with torch.no_grad():
            out = gnn_model(data)
        prob = F.softmax(out, dim=1)[0][1].item()
        return f"GNN Prediction: Signal Probability = {prob:.4f}"
    except KeyError as e:
        return f"Error processing graph: missing key {e}"
    except Exception as e:
        return f"Error processing graph: {str(e)}"
# Gradio Interface
import gradio as gr
# Gradio UI: one column per model, each holding an input component, a result
# textbox and a button. Components appear in the order they are created inside
# these `with` blocks, so statement order here defines the layout.
with gr.Blocks(title="Multi-Model Gradio Interface") as demo:
    gr.Markdown("## Multi-Model Prediction Interface")
    with gr.Row():
        with gr.Column():
            # Image -> CNN; type="pil" makes Gradio pass a PIL image to predict_image.
            image_input = gr.Image(type="pil", label="Upload Image for CNN")
            cnn_output = gr.Textbox(label="CNN Result")
            image_button = gr.Button("Predict Image")
        with gr.Column():
            # Comma-separated numbers -> improved model.
            num_input = gr.Textbox(label="Numerical Data (comma-separated) for Improved Model")
            improved_output = gr.Textbox(label="Improved Model Result")
            num_button = gr.Button("Predict Numerical")
        with gr.Column():
            # Graph JSON text -> GNN.
            graph_input = gr.Textbox(label="Graph JSON for GNN")
            gnn_output = gr.Textbox(label="GNN Result")
            graph_button = gr.Button("Predict Graph")
    # Event wiring: each button sends its column's input to the matching
    # predict_* function and shows the returned string in that column's
    # result textbox.
    # NOTE(review): indentation was reconstructed; these calls are placed inside
    # the Blocks context, where Gradio expects event listeners to be registered.
    image_button.click(predict_image, inputs=image_input, outputs=cnn_output)
    num_button.click(predict_numerical, inputs=num_input, outputs=improved_output)
    graph_button.click(predict_graph, inputs=graph_input, outputs=gnn_output)

if __name__ == "__main__":
    # NOTE(review): per the Gradio docs, debug=True keeps launch() blocking and
    # prints errors to the console — confirm this is wanted in deployment.
    demo.launch(debug=True)