"""Gradio front-end serving three pre-trained classifiers side by side.

* a Keras CNN (``final_cnn_model.h5``) applied to 25x25 grayscale images,
* a Keras model (``improved_model.h5``) fed comma-separated numbers,
* a PyTorch Geometric GCN (``higgs_gnn_model_cpu.pth``) fed a graph as JSON.

Each model is loaded once at import time. A model that fails to load is left
as ``None`` and its prediction function returns an error string instead of
crashing the app.
"""

# Environment variables must be set *before* TensorFlow / PyTorch are
# imported, because those libraries read them at import time.
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # hide TensorFlow C++ INFO/WARNING logs
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # disable oneDNN custom ops (avoids tiny float round-off diffs)
# NOTE: PYTHONHASHSEED only affects *child* Python processes. The hash seed of
# the running interpreter is fixed at startup, so this does not make hashing
# in this process reproducible.
os.environ["PYTHONHASHSEED"] = "0"
# Disables PyTorch's CUDA caching allocator. Unrelated to RNG; no effect on
# CPU-only inference.
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"

import torch
import numpy as np
import tensorflow as tf

# Seed every RNG in play before any model is constructed.
torch.manual_seed(42)
# Make PyTorch raise on ops that have no deterministic implementation.
torch.use_deterministic_algorithms(True)
np.random.seed(42)
tf.random.set_seed(42)

# Remaining imports (after seeding, as in the original layout).
import json

import gradio as gr
import torch.nn.functional as F
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool


class GNN(torch.nn.Module):
    """Two-layer GCN graph classifier: 4 node features -> 2 graph-level logits.

    The attribute names ``conv1``, ``conv2`` and ``fc`` determine the keys of
    the state_dict, so they must match the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        # Self-loops are disabled; presumably this mirrors how the checkpoint
        # was trained — TODO confirm before changing.
        self.conv1 = GCNConv(4, 16, add_self_loops=False)
        self.conv2 = GCNConv(16, 32, add_self_loops=False)
        self.fc = torch.nn.Linear(32, 2)

    def forward(self, data):
        """Return unnormalised logits with one row per graph in ``data.batch``."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = global_mean_pool(x, batch)  # average node embeddings per graph
        return self.fc(x)


# Model file locations (relative to the working directory).
cnn_model_path = "final_cnn_model.h5"
improved_model_path = "improved_model.h5"
gnn_model_path = "higgs_gnn_model_cpu.pth"


def _load_keras_model(path, label):
    """Load a Keras model from *path*; print and return None on failure."""
    try:
        model = tf.keras.models.load_model(path)
    except Exception as e:  # best-effort: the UI reports the missing model
        print(f"Error loading {label} model: {e}")
        return None
    print(f"{label} model loaded successfully from {path}")
    return model


def _load_gnn_model(path):
    """Build a GNN, load its weights on CPU and return it in eval mode.

    Returns None (after printing the error) if the checkpoint cannot be
    loaded. The model is only returned once fully loaded, so callers never
    see a randomly initialised network.
    """
    try:
        model = GNN()
        # torch.load unpickles the file: only point it at trusted checkpoints.
        model.load_state_dict(torch.load(path, map_location="cpu"))
    except Exception as e:  # best-effort: the UI reports the missing model
        print(f"Error loading GNN model: {e}")
        return None
    model.eval()  # disable dropout for inference
    print(f"GNN model loaded successfully from {path}")
    return model


cnn_model = _load_keras_model(cnn_model_path, "CNN")
improved_model = _load_keras_model(improved_model_path, "Improved")
gnn_model = _load_gnn_model(gnn_model_path)


def predict_image(image):
    """Classify a PIL image with the CNN as "Signal" or "Background".

    The image is converted to grayscale ("L"), resized to 25x25, scaled to
    [0, 1] and given batch and channel axes -> shape (1, 25, 25, 1).
    The percentage shown is the model's single raw output, thresholded at
    0.5. Presumably this is P(signal); it is *not* the confidence in
    "Background" when that label is shown — TODO confirm.

    Returns a human-readable result or error string (never raises).
    """
    if image is None:
        return "Error: No image provided"
    if cnn_model is None:
        return "Error: CNN model failed to load. Please check the model file."
    try:
        image = image.convert("L").resize((25, 25))
        image_array = np.array(image) / 255.0
        image_array = np.expand_dims(image_array, axis=(0, -1))
        prediction = cnn_model.predict(image_array, verbose=0)[0][0]
        label = "Signal" if prediction > 0.5 else "Background"
        return f"CNN Prediction: {label} ({prediction * 100:.2f}%)"
    except Exception as e:  # UI boundary: report instead of crashing
        return f"Error processing image: {str(e)}"


def predict_numerical(data):
    """Classify a comma-separated feature vector with the improved model.

    The feature count is checked against ``improved_model.input_shape[1]``
    when that dimension is fixed. For a multi-unit output the class is the
    argmax. For a single-unit output (argmax would always be 0) the value is
    treated as P(class 1) and thresholded at 0.5, like the CNN.

    Returns a human-readable result or error string (never raises).
    """
    if not data or not data.strip():
        return "Error: No numerical data provided"
    if improved_model is None:
        return "Error: Improved model failed to load. Please check the model file."
    try:
        input_data = np.array([float(x) for x in data.split(",")], dtype=np.float32)
        expected_features = improved_model.input_shape[1]
        if expected_features is not None and len(input_data) != expected_features:
            return f"Error: Expected {expected_features} features, got {len(input_data)}."
        input_data = input_data.reshape(1, -1)  # single-sample batch
        probs = improved_model.predict(input_data, verbose=0)[0]
        if probs.shape[-1] == 1:
            # Assumes a sigmoid head — TODO confirm against the trained model.
            p = float(probs[0])
            predicted_class = 1 if p > 0.5 else 0
            confidence = p if predicted_class == 1 else 1.0 - p
        else:
            predicted_class = int(np.argmax(probs))
            confidence = probs[predicted_class]
        return f"Improved Model Prediction: Class {predicted_class} (Confidence: {confidence:.4f})"
    except Exception as e:  # UI boundary: report instead of crashing
        return f"Error processing numerical data: {str(e)}"


def predict_graph(graph_json):
    """Score one graph with the GNN and report the softmax probability of class 1.

    Expects JSON of the form ``{"x": [[f0, f1, f2, f3], ...],
    "edge_index": [[src, ...], [dst, ...]]}``:

    * ``x`` holds 4 features per node (``conv1`` takes 4 inputs);
    * ``edge_index`` is PyG's 2 x E edge list. An empty list means no edges.

    Returns a human-readable result or error string (never raises).
    """
    if not graph_json:
        return "Error: No graph JSON provided"
    if gnn_model is None:
        return "Error: GNN model failed to load"
    try:
        graph_data = json.loads(graph_json)
        x = torch.tensor(graph_data["x"], dtype=torch.float32)
        edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long)
        if edge_index.numel() == 0:
            # "[]" parses to shape (0,); PyG needs an explicit (2, 0) edge list.
            edge_index = edge_index.reshape(2, 0)
        data = Data(x=x, edge_index=edge_index)
        # A single graph: every node belongs to batch index 0.
        data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
        with torch.no_grad():
            out = gnn_model(data)
        prob = F.softmax(out, dim=1)[0][1].item()
        return f"GNN Prediction: Signal Probability = {prob:.4f}"
    except Exception as e:  # UI boundary: report instead of crashing
        return f"Error processing graph: {str(e)}"


# Gradio interface: one column per model, each with input, output and button.
with gr.Blocks(title="Multi-Model Gradio Interface") as demo:
    gr.Markdown("## Multi-Model Prediction Interface")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image for CNN")
            cnn_output = gr.Textbox(label="CNN Result")
            image_button = gr.Button("Predict Image")
        with gr.Column():
            num_input = gr.Textbox(label="Numerical Data (comma-separated) for Improved Model")
            improved_output = gr.Textbox(label="Improved Model Result")
            num_button = gr.Button("Predict Numerical")
        with gr.Column():
            graph_input = gr.Textbox(label="Graph JSON for GNN")
            gnn_output = gr.Textbox(label="GNN Result")
            graph_button = gr.Button("Predict Graph")

    # Event listeners must be registered inside the Blocks context.
    image_button.click(predict_image, inputs=image_input, outputs=cnn_output)
    num_button.click(predict_numerical, inputs=num_input, outputs=improved_output)
    graph_button.click(predict_graph, inputs=graph_input, outputs=gnn_output)

if __name__ == "__main__":
    demo.launch(debug=True)