|
|
|
import os |
|
# Runtime knobs for the ML libraries. They are set here, before torch and
# tensorflow are imported further down, because the libraries read them when
# they initialise.
os.environ.update(
    {
        # Hide TensorFlow C++ INFO and WARNING log lines (ERRORs still shown).
        'TF_CPP_MIN_LOG_LEVEL': '2',
        # Turn off oneDNN custom ops, whose different computation order can
        # cause small floating-point differences between runs.
        'TF_ENABLE_ONEDNN_OPTS': '0',
        # NOTE(review): the running interpreter reads PYTHONHASHSEED only at
        # startup, so this affects child processes, not this process's str hashing.
        'PYTHONHASHSEED': '0',
        # Disable PyTorch's CUDA caching allocator (no effect on CPU-only runs).
        'PYTORCH_NO_CUDA_MEMORY_CACHING': '1',
    }
)
|
|
|
|
|
import torch |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
# Seed every framework's global RNG with the same value so repeated runs of
# this script behave the same way.
np.random.seed(42)
tf.random.set_seed(42)
torch.manual_seed(42)

# Per the PyTorch docs, with this enabled any op lacking a deterministic
# implementation raises a RuntimeError instead of running non-deterministically.
torch.use_deterministic_algorithms(True)
|
|
|
|
|
import json |
|
from PIL import Image |
|
from torch_geometric.data import Data |
|
from torch_geometric.nn import GCNConv, global_mean_pool |
|
import torch.nn.functional as F |
|
|
|
|
|
class GNN(torch.nn.Module):
    """Graph-level classifier built from two GCN layers.

    Pipeline: GCNConv -> ReLU -> dropout -> GCNConv -> ReLU -> dropout ->
    per-graph mean pooling -> linear layer producing ``num_classes`` logits.

    All constructor arguments default to the original hard-coded architecture
    (4 -> 16 -> 32 -> 2 with dropout 0.5). The checkpoint loader in this file
    builds the model as ``GNN()``, so the defaults must keep matching the
    saved ``state_dict``.

    Args:
        in_channels: Number of features per node.
        hidden_channels: Output width of the first GCNConv layer.
        embedding_channels: Output width of the second GCNConv layer, which
            is also the input width of the linear head.
        num_classes: Number of logits produced per graph.
        dropout: Dropout probability applied after each GCNConv layer. It is
            only active while the module is in training mode.
    """

    def __init__(self, in_channels=4, hidden_channels=16, embedding_channels=32,
                 num_classes=2, dropout=0.5):
        super().__init__()
        # add_self_loops=False has no parameters, so it is not stored in the
        # state_dict. Presumably it matches how the checkpoint was trained,
        # so keep it unchanged.
        self.conv1 = GCNConv(in_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GCNConv(hidden_channels, embedding_channels, add_self_loops=False)
        self.fc = torch.nn.Linear(embedding_channels, num_classes)
        # Plain float, not a parameter/buffer: it does not alter the state_dict keys.
        self.dropout_p = dropout

    def forward(self, data):
        """Return per-graph logits of shape ``(num_graphs, num_classes)``.

        Args:
            data: A torch_geometric ``Data``/``Batch`` object providing ``x``
                (node features), ``edge_index`` and ``batch`` (node -> graph
                assignment used for pooling).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        # training=self.training makes dropout a no-op after .eval().
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        # Collapse node embeddings into one vector per graph.
        x = global_mean_pool(x, batch)
        return self.fc(x)
|
|
|
|
|
# Model artifacts; bare filenames, so they resolve against the current
# working directory.
cnn_model_path = "final_cnn_model.h5"        # Keras model used for image input
improved_model_path = "improved_model.h5"    # Keras model used for numerical input
gnn_model_path = "higgs_gnn_model_cpu.pth"   # state_dict loaded into GNN()

# Start with no models; the predict_* functions treat None as "model unavailable".
cnn_model = improved_model = gnn_model = None
|
|
|
|
|
def _load_keras_model(path, display_name):
    """Load a Keras model from *path*, or return None if loading fails.

    A success or failure message naming *display_name* is printed either way.
    Failures are not raised so the app can still start; the predict_*
    functions report a missing model when they see None.
    """
    try:
        model = tf.keras.models.load_model(path)
    except Exception as e:
        print(f"Error loading {display_name} model: {e}")
        return None
    print(f"{display_name} model loaded successfully from {path}")
    return model


def _load_gnn_model(path):
    """Build a GNN, load its state_dict from *path* onto the CPU, and set eval mode.

    Returns the model only once every step has succeeded, and None otherwise.
    A failed load therefore cannot leave a randomly initialised (and
    still-in-training-mode) network in place, which would silently produce
    meaningless predictions.
    """
    try:
        model = GNN()
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On PyTorch >= 1.13, weights_only=True restricts loading
        # to tensor data, which suffices for a state_dict.
        model.load_state_dict(torch.load(path, map_location='cpu'))
        model.eval()
    except Exception as e:
        print(f"Error loading GNN model: {e}")
        return None
    print(f"GNN model loaded successfully from {path}")
    return model


cnn_model = _load_keras_model(cnn_model_path, "CNN")
improved_model = _load_keras_model(improved_model_path, "Improved")
gnn_model = _load_gnn_model(gnn_model_path)
|
|
|
|
|
def predict_image(image):
    """Classify an uploaded image as Signal or Background with the CNN.

    The image is converted to 8-bit grayscale ("L" mode), resized to 25x25,
    scaled to [0, 1] and given batch and channel axes before being passed to
    ``cnn_model``. The model's first output is treated as P(Signal): values
    above 0.5 are labelled Signal.

    Args:
        image: A PIL image (as supplied by ``gr.Image(type="pil")``), or None.

    Returns:
        A result string reporting the label and the confidence in that label,
        or an ``"Error: ..."`` string. Errors are returned, not raised, so
        they show up in the Gradio textbox.
    """
    if image is None:
        return "Error: No image provided"
    if cnn_model is None:
        return "Error: CNN model failed to load. Please check the model file."
    try:
        grayscale = image.convert("L").resize((25, 25))
        # (25, 25) -> (1, 25, 25, 1): add batch and channel axes.
        batch = np.expand_dims(np.asarray(grayscale) / 255.0, axis=(0, -1))
        # verbose=0 keeps Keras from printing a progress bar on every request.
        signal_prob = float(cnn_model.predict(batch, verbose=0)[0][0])
        is_signal = signal_prob > 0.5
        label = "Signal" if is_signal else "Background"
        # Report confidence in the chosen label; for Background that is
        # 1 - P(Signal), not P(Signal) itself.
        confidence = signal_prob if is_signal else 1.0 - signal_prob
        return f"CNN Prediction: {label} ({confidence * 100:.2f}%)"
    except Exception as e:
        return f"Error processing image: {str(e)}"
|
|
|
def predict_numerical(data):
    """Classify one comma-separated feature vector with the "improved" model.

    Args:
        data: Text such as ``"0.1, 2.5, -3"``. Each comma-separated token must
            parse as a float. The token count must equal the model's input
            width when the model declares one.

    Returns:
        ``"Improved Model Prediction: Class <k> (Confidence: <p>)"``, or an
        ``"Error: ..."`` string. Errors are returned, not raised, so they show
        up in the Gradio textbox.
    """
    if not data:
        return "Error: No numerical data provided"
    if improved_model is None:
        return "Error: Improved model failed to load. Please check the model file."
    try:
        features = np.array([float(token) for token in data.split(",")], dtype=np.float32)
        expected_features = improved_model.input_shape[1]
        # input_shape[1] is None when the model has no fixed feature count;
        # only enforce the length when it is known.
        if expected_features is not None and len(features) != expected_features:
            return f"Error: Expected {expected_features} features, got {len(features)}."
        features = features.reshape(1, -1)
        # verbose=0 keeps Keras from printing a progress bar on every request.
        prediction = improved_model.predict(features, verbose=0)
        if prediction.shape[1] == 1:
            # Single output unit: argmax over one column is always 0. Threshold
            # instead, assuming the unit is a probability (same 0.5 rule as
            # predict_image).
            positive_prob = float(prediction[0][0])
            predicted_class = int(positive_prob > 0.5)
            confidence = positive_prob if predicted_class == 1 else 1.0 - positive_prob
        else:
            predicted_class = int(np.argmax(prediction, axis=1)[0])
            confidence = prediction[0][predicted_class]
        return f"Improved Model Prediction: Class {predicted_class} (Confidence: {confidence:.4f})"
    except Exception as e:
        return f"Error processing numerical data: {str(e)}"
|
|
|
def predict_graph(graph_json):
    """Score a single graph with the GNN and report the Signal probability.

    Args:
        graph_json: JSON text of an object with two keys. ``"x"`` is a list of
            per-node feature lists, one row per node. ``"edge_index"`` is the
            PyTorch Geometric ``[sources, targets]`` pair of node-index lists;
            ``[]`` is accepted for a graph with no edges.

    Returns:
        ``"GNN Prediction: Signal Probability = <p>"``, where ``p`` is the
        softmax probability of class 1. On failure, an ``"Error..."`` string
        is returned instead of raising, so it shows up in the Gradio textbox.
    """
    if not graph_json:
        return "Error: No graph JSON provided"
    if gnn_model is None:
        return "Error: GNN model failed to load"
    try:
        graph_data = json.loads(graph_json)
        if not isinstance(graph_data, dict):
            return "Error processing graph: expected a JSON object with 'x' and 'edge_index' keys"
        missing = [key for key in ("x", "edge_index") if key not in graph_data]
        if missing:
            return f"Error processing graph: missing key(s) {', '.join(missing)}"

        x = torch.tensor(graph_data['x'], dtype=torch.float32)
        edge_index = torch.tensor(graph_data['edge_index'], dtype=torch.long)
        if edge_index.numel() == 0:
            # `[]` becomes a shape-(0,) tensor, but GCNConv needs (2, num_edges).
            edge_index = edge_index.view(2, 0)

        data = Data(x=x, edge_index=edge_index)
        # All nodes belong to graph 0, so pooling yields a single graph vector.
        data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
        with torch.no_grad():
            out = gnn_model(data)
        prob = F.softmax(out, dim=1)[0][1].item()
        return f"GNN Prediction: Signal Probability = {prob:.4f}"
    except Exception as e:
        return f"Error processing graph: {str(e)}"
|
|
|
|
|
import gradio as gr |
|
|
|
def _build_panel(make_input, result_label, button_label):
    """Create one model's column: input widget, result textbox, then trigger button.

    Must be called inside an active Blocks/Row context so the column attaches there.
    Returns ``(input_component, result_textbox, button)``.
    """
    with gr.Column():
        source = make_input()
        sink = gr.Textbox(label=result_label)
        trigger = gr.Button(button_label)
    return source, sink, trigger


with gr.Blocks(title="Multi-Model Gradio Interface") as demo:
    gr.Markdown("## Multi-Model Prediction Interface")

    # Three side-by-side panels, one per model.
    with gr.Row():
        image_input, cnn_output, image_button = _build_panel(
            lambda: gr.Image(type="pil", label="Upload Image for CNN"),
            "CNN Result",
            "Predict Image",
        )
        num_input, improved_output, num_button = _build_panel(
            lambda: gr.Textbox(label="Numerical Data (comma-separated) for Improved Model"),
            "Improved Model Result",
            "Predict Numerical",
        )
        graph_input, gnn_output, graph_button = _build_panel(
            lambda: gr.Textbox(label="Graph JSON for GNN"),
            "GNN Result",
            "Predict Graph",
        )

    # Each button sends its panel's input to the matching predictor and
    # writes the returned string into that panel's result box.
    image_button.click(fn=predict_image, inputs=image_input, outputs=cnn_output)
    num_button.click(fn=predict_numerical, inputs=num_input, outputs=improved_output)
    graph_button.click(fn=predict_graph, inputs=graph_input, outputs=gnn_output)


if __name__ == "__main__":
    demo.launch(debug=True)