Spaces:
Sleeping
Sleeping
DanielXu0208
commited on
Commit
•
e073b0b
1
Parent(s):
9e56473
Update run_gradio.py
Browse files- run_gradio.py +246 -212
run_gradio.py
CHANGED
@@ -1,215 +1,249 @@
|
|
1 |
-
import
|
2 |
import torch
|
3 |
-
|
4 |
-
|
5 |
-
from torchvision.models import resnet50
|
6 |
-
from sklearn.cluster import KMeans
|
7 |
-
import numpy as np
|
8 |
import os
|
9 |
-
import logging
|
10 |
-
from torchcam.utils import overlay_mask
|
11 |
from PIL import Image
|
12 |
-
from
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
#
|
118 |
-
|
119 |
-
|
120 |
-
#
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
#
|
174 |
-
|
175 |
-
|
176 |
-
#
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
#
|
183 |
-
|
184 |
-
|
185 |
-
#
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
import torch
|
3 |
+
import torchvision
|
4 |
+
import pandas as pd
|
|
|
|
|
|
|
5 |
import os
|
|
|
|
|
6 |
from PIL import Image
|
7 |
+
from utils.experiment_utils import get_model
|
8 |
+
|
9 |
+
# File to store the visitor count
|
10 |
+
visitor_count_file = "visitor_count.txt"
|
11 |
+
|
12 |
+
# Function to update visitor count
|
13 |
+
def update_visitor_count():
|
14 |
+
if os.path.exists(visitor_count_file):
|
15 |
+
with open(visitor_count_file, "r") as file:
|
16 |
+
count = int(file.read())
|
17 |
+
else:
|
18 |
+
count = 0 # Start from zero if no file exists
|
19 |
+
|
20 |
+
# Increment visitor count
|
21 |
+
count += 1
|
22 |
+
|
23 |
+
# Save the updated count back to the file
|
24 |
+
with open(visitor_count_file, "w") as file:
|
25 |
+
file.write(str(count))
|
26 |
+
|
27 |
+
return count
|
28 |
+
|
29 |
+
# Custom flagging logic to save flagged data to a CSV file
|
30 |
+
class CustomFlagging(gr.FlaggingCallback):
|
31 |
+
def __init__(self, dir_name="flagged_data"):
|
32 |
+
self.dir = dir_name
|
33 |
+
self.image_dir = os.path.join(self.dir, "uploaded_images")
|
34 |
+
if not os.path.exists(self.dir):
|
35 |
+
os.makedirs(self.dir)
|
36 |
+
if not os.path.exists(self.image_dir):
|
37 |
+
os.makedirs(self.image_dir)
|
38 |
+
|
39 |
+
# Define setup as a no-op to fulfill abstract class requirement
|
40 |
+
def setup(self, *args, **kwargs):
|
41 |
+
pass
|
42 |
+
|
43 |
+
def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
|
44 |
+
# Extract data
|
45 |
+
classification_mode, image, sensing_modality, predicted_class, correct_class = flag_data
|
46 |
+
|
47 |
+
# Save the uploaded image in the "uploaded_images" folder
|
48 |
+
image_filename = os.path.join(self.image_dir,
|
49 |
+
f"flagged_image_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.png")
|
50 |
+
image.save(image_filename) # Save image in PNG format
|
51 |
+
|
52 |
+
# Columns: Classification, Image Path, Sensing Modality, Predicted Class, Correct Class
|
53 |
+
data = {
|
54 |
+
"Classification Mode": classification_mode,
|
55 |
+
"Image Path": image_filename, # Save path to image in CSV
|
56 |
+
"Sensing Modality": sensing_modality,
|
57 |
+
"Predicted Class": predicted_class,
|
58 |
+
"Correct Class": correct_class,
|
59 |
+
}
|
60 |
+
|
61 |
+
df = pd.DataFrame([data])
|
62 |
+
csv_file = os.path.join(self.dir, "flagged_data.csv")
|
63 |
+
|
64 |
+
# Append to CSV, or create if it doesn't exist
|
65 |
+
if os.path.exists(csv_file):
|
66 |
+
df.to_csv(csv_file, mode='a', header=False, index=False)
|
67 |
+
else:
|
68 |
+
df.to_csv(csv_file, mode='w', header=True, index=False)
|
69 |
+
|
70 |
+
|
71 |
+
# Function to load the appropriate model based on the user's selection
|
72 |
+
def load_model(modality, mode):
|
73 |
+
# For Few-Shot classification, always use the DINOv2 model
|
74 |
+
if mode == "Few-Shot":
|
75 |
+
class Args:
|
76 |
+
model = 'DINOv2'
|
77 |
+
pretrained = 'pretrained'
|
78 |
+
frozen = 'unfrozen'
|
79 |
+
|
80 |
+
args = Args()
|
81 |
+
model = get_model(args) # Load DINOv2 model for Few-Shot classification
|
82 |
+
else:
|
83 |
+
# For Fully-Supervised classification, choose model based on the sensing modality
|
84 |
+
if modality == "Texture":
|
85 |
+
class Args:
|
86 |
+
model = 'DINOv2'
|
87 |
+
pretrained = 'pretrained'
|
88 |
+
frozen = 'unfrozen'
|
89 |
+
|
90 |
+
args = Args()
|
91 |
+
model = get_model(args) # Load DINOv2 model for Texture modality
|
92 |
+
elif modality == "Heightmap":
|
93 |
+
class Args:
|
94 |
+
model = 'ResNet152'
|
95 |
+
pretrained = 'pretrained'
|
96 |
+
frozen = 'unfrozen'
|
97 |
+
|
98 |
+
args = Args()
|
99 |
+
model = get_model(args) # Load ResNet152 model for Heightmap modality
|
100 |
+
else:
|
101 |
+
raise ValueError("Invalid modality selected!")
|
102 |
+
|
103 |
+
model.eval() # Set the model to evaluation mode
|
104 |
+
return model
|
105 |
+
|
106 |
+
|
107 |
+
# Prediction function that processes the image and returns the prediction results
|
108 |
+
def predict(image, modality, mode):
|
109 |
+
# Load the appropriate model based on the user's selections
|
110 |
+
model = load_model(modality, mode)
|
111 |
+
|
112 |
+
# Print the selected mode and modality for debugging purposes
|
113 |
+
print(f"User selected Mode: {mode}, Modality: {modality}")
|
114 |
+
|
115 |
+
# Preprocess the image
|
116 |
+
transform = torchvision.transforms.Compose([
|
117 |
+
torchvision.transforms.Resize((224, 224)),
|
118 |
+
torchvision.transforms.ToTensor(),
|
119 |
+
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
120 |
+
])
|
121 |
+
|
122 |
+
image_tensor = transform(image).unsqueeze(0) # Add batch dimension
|
123 |
+
with torch.no_grad():
|
124 |
+
output = model(image_tensor) # Get model predictions
|
125 |
+
probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist()
|
126 |
+
|
127 |
+
# Class names for the predictions
|
128 |
+
class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]
|
129 |
+
|
130 |
+
# Pair class names with their corresponding probabilities
|
131 |
+
predicted_class = class_names[probabilities.index(max(probabilities))] # Get the predicted class
|
132 |
+
results = {class_names[i]: probabilities[i] for i in range(len(class_names))}
|
133 |
+
|
134 |
+
return predicted_class, results # Return predicted class and probabilities
|
135 |
+
|
136 |
+
|
137 |
+
# Create the Gradio interface using gr.Blocks
|
138 |
+
def create_interface():
|
139 |
+
with gr.Blocks() as interface:
|
140 |
+
# Title at the top of the interface (centered and larger)
|
141 |
+
gr.Markdown("<h1 style='text-align: center; font-size: 36px;'>LUWA Dataset Image Classification</h1>")
|
142 |
+
|
143 |
+
# Add description for the interface
|
144 |
+
description = """
|
145 |
+
### Image Classification Options
|
146 |
+
- **Fully-Supervised Classification**: Choose this for common or well-known materials with plenty of data (e.g., bone, wood).
|
147 |
+
- **Few-Shot Classification**: Choose this for rare or newly discovered materials where only a few examples exist.
|
148 |
+
### **Don't forget to choose the Sensing Modality based on your uploaded images.**
|
149 |
+
### **Please help us to flag the correct class for your uploaded image if you know it, it will help us to further develop our dataset. If you cannot find the correct class in the option, please click on the option 'Other' and type the correct class for us!**
|
150 |
+
"""
|
151 |
+
gr.Markdown(description)
|
152 |
+
|
153 |
+
# Top-level selector for Fully-Supervised vs. Few-Shot classification
|
154 |
+
mode_selector = gr.Radio(choices=["Fully Supervised", "Few-Shot"], label="Classification Mode",
|
155 |
+
value="Fully Supervised")
|
156 |
+
|
157 |
+
# Sensing modality selector
|
158 |
+
modality_selector = gr.Radio(choices=["Texture", "Heightmap"], label="Sensing Modality", value="Texture")
|
159 |
+
|
160 |
+
# Image upload input
|
161 |
+
image_input = gr.Image(type="pil", label="Image")
|
162 |
+
|
163 |
+
# Predicted classification output and class probabilities
|
164 |
+
with gr.Row():
|
165 |
+
predicted_output = gr.Label(num_top_classes=1, label="Predicted Classification")
|
166 |
+
probabilities_output = gr.Label(label="Prediction Probabilities")
|
167 |
+
|
168 |
+
# Add the "Run Prediction" button under the Prediction Probabilities
|
169 |
+
predict_button = gr.Button("Run Prediction")
|
170 |
+
|
171 |
+
# Dropdown for user to select the correct class if the model prediction is wrong
|
172 |
+
correct_class_selector = gr.Radio(
|
173 |
+
choices=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD", "Other"],
|
174 |
+
label="Select Correct Class"
|
175 |
+
)
|
176 |
+
|
177 |
+
# Text box for user to type the correct class if "Other" is selected
|
178 |
+
other_class_input = gr.Textbox(label="If Other, enter the correct class", visible=False)
|
179 |
+
|
180 |
+
# Logic to dynamically update visibility of the "Other" class text box
|
181 |
+
def update_visibility(selected_class):
|
182 |
+
return gr.update(visible=selected_class == "Other")
|
183 |
+
|
184 |
+
correct_class_selector.change(fn=update_visibility, inputs=correct_class_selector, outputs=other_class_input)
|
185 |
+
|
186 |
+
|
187 |
+
# Create a flagging instance
|
188 |
+
flagging_instance = CustomFlagging(dir_name="flagged_data")
|
189 |
+
|
190 |
+
# Define function for the confirmation pop-up
|
191 |
+
def confirm_flag_selection(correct_class, other_class):
|
192 |
+
# Generate confirmation message
|
193 |
+
if correct_class == "Other":
|
194 |
+
message = f"Are you sure the class you selected is '{other_class}' for this picture?"
|
195 |
+
else:
|
196 |
+
message = f"Are you sure the class you selected is '{correct_class}' for this picture?"
|
197 |
+
|
198 |
+
return message, gr.update(visible=True), gr.update(visible=True)
|
199 |
+
|
200 |
+
# Final flag submission function
|
201 |
+
def flag_data_save(correct_class, other_class, mode, image, modality, predicted_class, confirmed):
|
202 |
+
if confirmed == "Yes":
|
203 |
+
# Save the flagged data
|
204 |
+
correct_class_final = correct_class if correct_class != "Other" else other_class
|
205 |
+
flagging_instance.flag([mode, image, modality, predicted_class, correct_class_final])
|
206 |
+
return "Flagged successfully!"
|
207 |
+
else:
|
208 |
+
return "No flag submitted, please select again."
|
209 |
+
|
210 |
+
# Flagging button
|
211 |
+
flag_button = gr.Button("Flag")
|
212 |
+
|
213 |
+
# Confirmation box for user input and confirmation flag
|
214 |
+
confirmation_text = gr.Textbox(visible=False)
|
215 |
+
yes_no_choice = gr.Radio(choices=["Yes", "No"], label="Are you sure?", visible=False)
|
216 |
+
confirmation_button = gr.Button("Confirm Flag", visible=False)
|
217 |
+
|
218 |
+
# Prediction action
|
219 |
+
predict_button.click(
|
220 |
+
fn=predict,
|
221 |
+
inputs=[image_input, modality_selector, mode_selector],
|
222 |
+
outputs=[predicted_output, probabilities_output]
|
223 |
+
)
|
224 |
+
|
225 |
+
# Flagging action with confirmation
|
226 |
+
flag_button.click(
|
227 |
+
fn=confirm_flag_selection,
|
228 |
+
inputs=[correct_class_selector, other_class_input],
|
229 |
+
outputs=[confirmation_text, yes_no_choice, confirmation_button]
|
230 |
+
)
|
231 |
+
|
232 |
+
# Final flag submission after confirmation
|
233 |
+
confirmation_button.click(
|
234 |
+
fn=flag_data_save,
|
235 |
+
inputs=[correct_class_selector, other_class_input, mode_selector, image_input, modality_selector,
|
236 |
+
predicted_output, yes_no_choice],
|
237 |
+
outputs=gr.Textbox(label="Flagging Status")
|
238 |
+
)
|
239 |
+
|
240 |
+
# Visitor count displayed at the bottom
|
241 |
+
visitor_count = update_visitor_count() # Update the visitor count
|
242 |
+
gr.Markdown(f"### Number of Visitors: {visitor_count}") # Display visitor count
|
243 |
+
|
244 |
+
return interface
|
245 |
+
|
246 |
+
|
247 |
+
if __name__ == "__main__":
|
248 |
+
interface = create_interface()
|
249 |
+
interface.launch(share=True)
|