File size: 5,827 Bytes
b2e58bb
994ebcb
b2e58bb
 
 
 
994ebcb
b2e58bb
 
217c562
7f54610
b2e58bb
 
 
 
 
 
 
 
34489c7
b2e58bb
 
 
 
217c562
b2e58bb
 
 
 
 
 
 
217c562
b2e58bb
 
 
 
 
 
 
 
217c562
b2e58bb
34489c7
 
 
7f54610
b2e58bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217c562
 
 
 
b2e58bb
217c562
 
 
 
 
 
 
 
 
 
 
 
 
 
b2e58bb
217c562
 
 
 
b2e58bb
217c562
 
 
 
 
 
 
 
 
 
 
3f9854c
b2e58bb
3f9854c
 
 
 
 
 
 
 
 
 
 
 
 
 
b2e58bb
 
 
 
3f9854c
 
 
b2e58bb
 
 
3f9854c
 
b2e58bb
 
 
3f9854c
 
b2e58bb
 
 
3f9854c
 
 
 
 
 
b2e58bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Set environment variables **before** importing libraries
import os
os.environ.update({
    # Hide TensorFlow's C++ INFO and WARNING log lines; errors still print.
    'TF_CPP_MIN_LOG_LEVEL': '2',
    # Turn off oneDNN optimisations, which TensorFlow warns can change
    # floating-point results slightly between runs/machines.
    'TF_ENABLE_ONEDNN_OPTS': '0',
    # Python reads PYTHONHASHSEED only at interpreter start-up, so setting it
    # here fixes str hashing for child processes, not for this process.
    'PYTHONHASHSEED': '0',
    # Disables PyTorch's CUDA caching allocator (a memory-debugging switch);
    # it does not seed or otherwise control random number generation.
    'PYTORCH_NO_CUDA_MEMORY_CACHING': '1',
})

# Import the ML frameworks whose random number generators are seeded below
import torch
import numpy as np
import tensorflow as tf

# Seed the NumPy, PyTorch and TensorFlow RNGs before any model object is built,
# so any random initialisation is repeatable across runs.
# NOTE: Python's built-in `random` module is not seeded here.
_SEED = 42
np.random.seed(_SEED)
torch.manual_seed(_SEED)
tf.random.set_seed(_SEED)
# Per the PyTorch docs, ops now pick deterministic implementations and raise a
# RuntimeError when only a nondeterministic one is available.
torch.use_deterministic_algorithms(True)

# Import remaining libraries
import json
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F

# Define GNN model architecture
class GNN(torch.nn.Module):
    """Two-layer GCN graph classifier: GCNConv -> GCNConv -> mean pool -> Linear.

    The defaults reproduce the original fixed architecture
    (4 -> 16 -> 32 node features, 2 output classes, dropout 0.5).
    The submodule names ``conv1``/``conv2``/``fc`` are the state-dict keys, so
    keep them stable for previously saved checkpoints to load.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Output width of the first GCN layer.
        embed_channels: Output width of the second GCN layer.
        num_classes: Number of output logits per graph.
        dropout: Dropout probability applied after each GCN layer while training.
    """

    def __init__(self, in_channels=4, hidden_channels=16, embed_channels=32,
                 num_classes=2, dropout=0.5):
        super().__init__()
        # Plain float attribute: not a parameter/buffer, so the state dict is unchanged.
        self.dropout = dropout
        # Self-loops are deliberately not added by the convolutions.
        self.conv1 = GCNConv(in_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GCNConv(hidden_channels, embed_channels, add_self_loops=False)
        self.fc = torch.nn.Linear(embed_channels, num_classes)

    def forward(self, data):
        """Return unnormalised class scores, one row per graph in ``data.batch``.

        Expects ``data`` to carry ``x`` (node features), ``edge_index`` and
        ``batch`` (graph id of every node).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Average node embeddings per graph id to get one vector per graph.
        x = global_mean_pool(x, batch)
        return self.fc(x)

# Checkpoint locations; bare relative paths resolve against the working directory.
cnn_model_path = "final_cnn_model.h5"
improved_model_path = "improved_model.h5"
gnn_model_path = "higgs_gnn_model_cpu.pth"


def _load_keras_model(path, label):
    """Load a Keras model from *path*, returning None if loading fails.

    Failures are printed rather than raised so the UI can still start and
    report a per-model error.
    """
    try:
        model = tf.keras.models.load_model(path)
    except Exception as e:
        print(f"Error loading {label} model: {e}")
        return None
    print(f"{label} model loaded successfully from {path}")
    return model


def _load_gnn_model(path):
    """Build a GNN, load its weights from *path* and put it in eval mode.

    Returns None on any failure. The network is only returned after its
    weights have loaded, so a randomly initialised model is never served.
    """
    try:
        model = GNN()
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True where the torch version supports it).
        state_dict = torch.load(path, map_location='cpu')
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading GNN model: {e}")
        return None
    model.eval()
    print(f"GNN model loaded successfully from {path}")
    return model


# Each model is None when it failed to load; the predict_* functions check for that.
cnn_model = _load_keras_model(cnn_model_path, "CNN")
improved_model = _load_keras_model(improved_model_path, "Improved")
gnn_model = _load_gnn_model(gnn_model_path)

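# Prediction callbacks wired to the Gradio buttons below. Each returns a result
# string; failures are returned as strings starting with "Error" rather than raised.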
def predict_image(image):
    """Classify an uploaded image as Signal or Background with the CNN.

    Args:
        image: A PIL image (the Gradio input uses ``type="pil"``) or None.

    Returns:
        ``"CNN Prediction: <label> (<score>%)"``, where <score> is the model's
        first output as a percentage and <label> is "Signal" when that output
        exceeds 0.5; otherwise a string starting with "Error".
    """
    # Guard clauses: nothing to do without an input or a loaded model.
    if image is None:
        return "Error: No image provided"
    if cnn_model is None:
        return "Error: CNN model failed to load. Please check the model file."
    try:
        # Grayscale, downsample to 25x25, and scale 0-255 pixels into [0, 1].
        small = image.convert("L").resize((25, 25))
        pixels = np.array(small) / 255.0
        # Add a leading batch axis and a trailing channel axis -> (1, 25, 25, 1).
        batch = pixels[np.newaxis, ..., np.newaxis]
        score = cnn_model.predict(batch)[0][0]
        verdict = "Background"
        if score > 0.5:
            verdict = "Signal"
        return f"CNN Prediction: {verdict} ({score * 100:.2f}%)"
    except Exception as exc:
        return f"Error processing image: {exc}"

def predict_numerical(data):
    """Classify one comma-separated feature vector with the improved model.

    Args:
        data: Text such as ``"0.1, 2.5, -3"``; each field is parsed with ``float``
            (surrounding whitespace is allowed).

    Returns:
        ``"Improved Model Prediction: Class <k> (Confidence: <p>)"``, or a string
        starting with "Error" when no prediction could be made (missing input,
        unloaded model, wrong feature count, unparsable value, ...).
    """
    if not data or not data.strip():
        return "Error: No numerical data provided"
    if improved_model is None:
        return "Error: Improved model failed to load. Please check the model file."
    try:
        input_data = np.array([float(x) for x in data.split(",")], dtype=np.float32)
        expected_features = improved_model.input_shape[1]
        # A None feature dimension means the model accepts any width.
        if expected_features is not None and len(input_data) != expected_features:
            return f"Error: Expected {expected_features} features, got {len(input_data)}."
        input_data = input_data.reshape(1, -1)
        scores = np.asarray(improved_model.predict(input_data))[0]
        if scores.shape == (1,):
            # Single output unit: read it as P(class 1) with the same 0.5
            # threshold predict_image uses. argmax would always report class 0.
            p_class1 = float(scores[0])
            predicted_class = 1 if p_class1 > 0.5 else 0
            confidence = p_class1 if predicted_class == 1 else 1.0 - p_class1
        else:
            predicted_class = int(np.argmax(scores))
            confidence = scores[predicted_class]
        return f"Improved Model Prediction: Class {predicted_class} (Confidence: {confidence:.4f})"
    except Exception as e:
        return f"Error processing numerical data: {e}"

def predict_graph(graph_json):
    """Score a single graph with the GNN and report the class-1 probability.

    Args:
        graph_json: JSON text with ``"x"`` (per-node feature lists) and
            ``"edge_index"`` (two equal-length lists: source ids and target ids).
            An empty ``"edge_index": []`` means a graph with no edges.

    Returns:
        ``"GNN Prediction: Signal Probability = <p>"`` where <p> is the softmax
        probability of output index 1, or a string starting with "Error".
    """
    if not graph_json:
        return "Error: No graph JSON provided"
    if gnn_model is None:
        return "Error: GNN model failed to load"
    try:
        graph_data = json.loads(graph_json)
        x = torch.tensor(graph_data['x'], dtype=torch.float32)
        edge_index = torch.tensor(graph_data['edge_index'], dtype=torch.long)
        if edge_index.numel() == 0:
            # `[]` parses to shape (0,); the convolutions need shape (2, 0).
            edge_index = edge_index.view(2, 0)
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            raise ValueError(
                "edge_index must have shape [2, num_edges], "
                f"got {list(edge_index.shape)}"
            )
        data = Data(x=x, edge_index=edge_index)
        # All nodes belong to graph 0, so pooling yields a single row.
        data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
        with torch.no_grad():
            out = gnn_model(data)
            prob = F.softmax(out, dim=1)[0][1].item()
        return f"GNN Prediction: Signal Probability = {prob:.4f}"
    except Exception as e:
        return f"Error processing graph: {e}"

# Gradio Interface
import gradio as gr

with gr.Blocks(title="Multi-Model Gradio Interface") as demo:
    gr.Markdown("## Multi-Model Prediction Interface")

    # Three side-by-side panels (CNN, Improved, GNN). Each panel stacks its
    # input, its result box and its button, and wires the button to the
    # matching predict_* callback right away.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image for CNN")
            cnn_output = gr.Textbox(label="CNN Result")
            image_button = gr.Button("Predict Image")
            image_button.click(fn=predict_image, inputs=image_input, outputs=cnn_output)

        with gr.Column():
            num_input = gr.Textbox(label="Numerical Data (comma-separated) for Improved Model")
            improved_output = gr.Textbox(label="Improved Model Result")
            num_button = gr.Button("Predict Numerical")
            num_button.click(fn=predict_numerical, inputs=num_input, outputs=improved_output)

        with gr.Column():
            graph_input = gr.Textbox(label="Graph JSON for GNN")
            gnn_output = gr.Textbox(label="GNN Result")
            graph_button = gr.Button("Predict Graph")
            graph_button.click(fn=predict_graph, inputs=graph_input, outputs=gnn_output)

if __name__ == "__main__":
    # Serve the UI only when run as a script. Per Gradio's launch() docs,
    # debug=True keeps the main thread blocked and surfaces errors.
    demo.launch(debug=True)