import gradio as gr
import torch
import torchvision
import pandas as pd
import os
from PIL import Image
from utils.experiment_utils import get_model


# File to store the visitor count
visitor_count_file = "visitor_count.txt"


# Function to update visitor count
def update_visitor_count():
    if os.path.exists(visitor_count_file):
        with open(visitor_count_file, "r") as file:
            count = int(file.read())
    else:
        count = 0  # Start from zero if no file exists

    # Increment visitor count
    count += 1

    # Save the updated count back to the file
    with open(visitor_count_file, "w") as file:
        file.write(str(count))

    return count


# Custom flagging logic to save flagged data to a CSV file
class CustomFlagging(gr.FlaggingCallback):
    def __init__(self, dir_name="flagged_data"):
        self.dir = dir_name
        self.image_dir = os.path.join(self.dir, "uploaded_images")
        if not os.path.exists(self.dir):
            os.makedirs(self.dir)
        if not os.path.exists(self.image_dir):
            os.makedirs(self.image_dir)

    # Define setup as a no-op to fulfill abstract class requirement
    def setup(self, *args, **kwargs):
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        # Extract data
        classification_mode, image, sensing_modality, predicted_class, correct_class = flag_data

        # Save the uploaded image in the "uploaded_images" folder
        image_filename = os.path.join(self.image_dir,
                                      f"flagged_image_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.png")
        image.save(image_filename)  # Save image in PNG format

        # Columns: Classification, Image Path, Sensing Modality, Predicted Class, Correct Class
        data = {
            "Classification Mode": classification_mode,
            "Image Path": image_filename,  # Save path to image in CSV
            "Sensing Modality": sensing_modality,
            "Predicted Class": predicted_class,
            "Correct Class": correct_class,
        }
        df = pd.DataFrame([data])
        csv_file = os.path.join(self.dir, "flagged_data.csv")

        # Append to CSV, or create if it doesn't exist
        if os.path.exists(csv_file):
            df.to_csv(csv_file, mode='a', header=False, index=False)
        else:
            df.to_csv(csv_file, mode='w', header=True, index=False)
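
# Illustrative usage of CustomFlagging (a sketch, assuming `img` is a PIL image
# already loaded in memory; it mirrors the call made from the interface below):
#     flagger = CustomFlagging(dir_name="flagged_data")
#     flagger.flag(["Fully Supervised", img, "Texture", "BONE", "ANTLER"])
# This saves the image under flagged_data/uploaded_images/ and appends one row
# to flagged_data/flagged_data.csv.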


# Function to load the appropriate model based on the user's selection
def load_model(modality, mode):
    # For Few-Shot classification, always use the DINOv2 model
    if mode == "Few-Shot":
        class Args:
            model = 'DINOv2'
            pretrained = 'pretrained'
            frozen = 'unfrozen'

        args = Args()
        model = get_model(args)  # Load DINOv2 model for Few-Shot classification
    else:
        # For Fully-Supervised classification, choose model based on the sensing modality
        if modality == "Texture":
            class Args:
                model = 'DINOv2'
                pretrained = 'pretrained'
                frozen = 'unfrozen'

            args = Args()
            model = get_model(args)  # Load DINOv2 model for Texture modality
        elif modality == "Heightmap":
            class Args:
                model = 'ResNet152'
                pretrained = 'pretrained'
                frozen = 'unfrozen'

            args = Args()
            model = get_model(args)  # Load ResNet152 model for Heightmap modality
        else:
            raise ValueError("Invalid modality selected!")

    model.eval()  # Set the model to evaluation mode
    return model
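
# Illustrative calls (a sketch; `get_model` comes from this Space's
# utils.experiment_utils, so which weights are actually loaded is decided there):
#     model = load_model("Texture", "Fully Supervised")    # DINOv2 backbone
#     model = load_model("Heightmap", "Fully Supervised")  # ResNet152 backbone
#     model = load_model("Texture", "Few-Shot")            # DINOv2, regardless of modality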


# Prediction function that processes the image and returns the prediction results
def predict(image, modality, mode):
    # Load the appropriate model based on the user's selections
    model = load_model(modality, mode)

    # Print the selected mode and modality for debugging purposes
    print(f"User selected Mode: {mode}, Modality: {modality}")

    # Preprocess the image (resize to 224x224 and normalize with ImageNet statistics)
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(image).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        output = model(image_tensor)  # Get model predictions
        probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist()

    # Class names for the predictions
    class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]

    # Pair class names with their corresponding probabilities
    predicted_class = class_names[probabilities.index(max(probabilities))]  # Get the predicted class
    results = {class_names[i]: probabilities[i] for i in range(len(class_names))}
    return predicted_class, results  # Return predicted class and probabilities
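
# Illustrative call (a sketch, assuming `img` is a PIL RGB image):
#     predicted, probs = predict(img, "Texture", "Fully Supervised")
#     # predicted -> e.g. "BONE"
#     # probs     -> {"ANTLER": 0.01, "BEECHWOOD": 0.05, ..., "SPRUCEWOOD": 0.02}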


# Create the Gradio interface using gr.Blocks
def create_interface():
    with gr.Blocks() as interface:
        # Title at the top of the interface (centered and larger)
        gr.Markdown("<h1 style='text-align: center; font-size: 36px;'>LUWA Dataset Image Classification</h1>")

        # Add a description for the interface
        description = """
### Image Classification Options
- **Fully-Supervised Classification**: Choose this for common or well-known materials with plenty of data (e.g., bone, wood).
- **Few-Shot Classification**: Choose this for rare or newly discovered materials where only a few examples exist.
### **Don't forget to choose the Sensing Modality that matches your uploaded image.**
### **If you know the correct class for your uploaded image, please help us flag it; this helps us further develop our dataset. If the correct class is not listed in the options, select 'Other' and type it in for us!**
"""
        gr.Markdown(description)

        # Top-level selector for Fully-Supervised vs. Few-Shot classification
        mode_selector = gr.Radio(choices=["Fully Supervised", "Few-Shot"], label="Classification Mode",
                                 value="Fully Supervised")

        # Sensing modality selector
        modality_selector = gr.Radio(choices=["Texture", "Heightmap"], label="Sensing Modality", value="Texture")

        # Image upload input
        image_input = gr.Image(type="pil", label="Image")

        # Predicted classification output and class probabilities
        with gr.Row():
            predicted_output = gr.Label(num_top_classes=1, label="Predicted Classification")
            probabilities_output = gr.Label(label="Prediction Probabilities")

        # Add the "Run Prediction" button under the Prediction Probabilities
        predict_button = gr.Button("Run Prediction")

        # Radio selector for the user to pick the correct class if the model prediction is wrong
        correct_class_selector = gr.Radio(
            choices=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD", "Other"],
            label="Select Correct Class"
        )

        # Text box for the user to type the correct class if "Other" is selected
        other_class_input = gr.Textbox(label="If Other, enter the correct class", visible=False)

        # Logic to dynamically update visibility of the "Other" class text box
        def update_visibility(selected_class):
            return gr.update(visible=selected_class == "Other")

        correct_class_selector.change(fn=update_visibility, inputs=correct_class_selector, outputs=other_class_input)

        # Create a flagging instance
        flagging_instance = CustomFlagging(dir_name="flagged_data")

        # Define function for the confirmation pop-up
        def confirm_flag_selection(correct_class, other_class):
            # Generate confirmation message
            if correct_class == "Other":
                message = f"Are you sure the class you selected is '{other_class}' for this picture?"
            else:
                message = f"Are you sure the class you selected is '{correct_class}' for this picture?"
            return message, gr.update(visible=True), gr.update(visible=True)

        # Final flag submission function
        def flag_data_save(correct_class, other_class, mode, image, modality, predicted_class, confirmed):
            if confirmed == "Yes":
                # Save the flagged data
                correct_class_final = correct_class if correct_class != "Other" else other_class
                flagging_instance.flag([mode, image, modality, predicted_class, correct_class_final])
                return "Flagged successfully!"
            else:
                return "No flag submitted, please select again."

        # Flagging button
        flag_button = gr.Button("Flag")

        # Confirmation box for user input and confirmation flag
        confirmation_text = gr.Textbox(visible=False)
        yes_no_choice = gr.Radio(choices=["Yes", "No"], label="Are you sure?", visible=False)
        confirmation_button = gr.Button("Confirm Flag", visible=False)

        # Prediction action
        predict_button.click(
            fn=predict,
            inputs=[image_input, modality_selector, mode_selector],
            outputs=[predicted_output, probabilities_output]
        )

        # Flagging action with confirmation
        flag_button.click(
            fn=confirm_flag_selection,
            inputs=[correct_class_selector, other_class_input],
            outputs=[confirmation_text, yes_no_choice, confirmation_button]
        )

        # Final flag submission after confirmation
        confirmation_button.click(
            fn=flag_data_save,
            inputs=[correct_class_selector, other_class_input, mode_selector, image_input, modality_selector,
                    predicted_output, yes_no_choice],
            outputs=gr.Textbox(label="Flagging Status")
        )

        # Visitor count displayed at the bottom (note: the count is updated once
        # each time the interface is built at app start-up)
        visitor_count = update_visitor_count()  # Update the visitor count
        gr.Markdown(f"### The Number of Visitors since October 2024: {visitor_count}")  # Display visitor count

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)
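
# To try the app locally (a sketch; assumes gradio, torch, torchvision, pandas,
# and Pillow are installed, plus the utils/ package shipped with this Space):
#     python app.py
# launch(share=True) also requests a temporary public *.gradio.live link in
# addition to the local URL.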