File size: 10,381 Bytes
e073b0b
9e56473
e073b0b
 
9e56473
 
e073b0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645dad2
e073b0b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import gradio as gr
import torch
import torchvision
import pandas as pd
import os
from PIL import Image
from utils.experiment_utils import get_model

# File to store the visitor count
visitor_count_file = "visitor_count.txt"

# Function to update visitor count
def update_visitor_count():
    """Increment the persisted visitor counter by one and return the new total.

    The count is stored as plain text in ``visitor_count_file``. A missing,
    empty, or corrupt file restarts the count at zero instead of raising
    ``ValueError`` and killing interface creation.

    Returns:
        int: the updated visitor count.
    """
    count = 0  # default when no valid counter file exists yet
    if os.path.exists(visitor_count_file):
        try:
            with open(visitor_count_file, "r") as file:
                count = int(file.read().strip())
        except (ValueError, OSError):
            # Corrupt or unreadable counter file: restart from zero
            # rather than crash the app at startup.
            count = 0

    # Increment visitor count
    count += 1

    # Save the updated count back to the file
    with open(visitor_count_file, "w") as file:
        file.write(str(count))

    return count

# Custom flagging logic to save flagged data to a CSV file
class CustomFlagging(gr.FlaggingCallback):
    """Flagging callback that persists flagged samples locally.

    Each flagged example saves the uploaded image as a PNG under
    ``<dir>/uploaded_images`` and appends one metadata row to
    ``<dir>/flagged_data.csv``.
    """

    def __init__(self, dir_name="flagged_data"):
        self.dir = dir_name
        self.image_dir = os.path.join(self.dir, "uploaded_images")
        # exist_ok avoids the check-then-create race of os.path.exists();
        # creating the nested image dir also creates self.dir.
        os.makedirs(self.image_dir, exist_ok=True)

    # Define setup as a no-op to fulfill abstract class requirement
    def setup(self, *args, **kwargs):
        """No-op; required by the abstract FlaggingCallback interface."""
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        """Save the flagged image and append its metadata to the CSV log.

        Args:
            flag_data: list of [classification_mode, image (PIL),
                sensing_modality, predicted_class, correct_class] as
                wired up by the interface.
        """
        # Extract data
        classification_mode, image, sensing_modality, predicted_class, correct_class = flag_data

        # Microsecond-resolution timestamp so two flags within the same
        # second do not overwrite each other's saved image.
        timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S_%f')
        image_filename = os.path.join(self.image_dir, f"flagged_image_{timestamp}.png")
        image.save(image_filename)  # Save image in PNG format

        # Columns: Classification, Image Path, Sensing Modality, Predicted Class, Correct Class
        data = {
            "Classification Mode": classification_mode,
            "Image Path": image_filename,  # Save path to image in CSV
            "Sensing Modality": sensing_modality,
            "Predicted Class": predicted_class,
            "Correct Class": correct_class,
        }

        df = pd.DataFrame([data])
        csv_file = os.path.join(self.dir, "flagged_data.csv")

        # Append to CSV; write the header only when the file is first created.
        df.to_csv(csv_file, mode='a', header=not os.path.exists(csv_file), index=False)


# Function to load the appropriate model based on the user's selection
def load_model(modality, mode):
    """Load and return the evaluation-ready model for the user's selections.

    Few-Shot classification always uses DINOv2; Fully-Supervised picks
    a backbone per sensing modality (DINOv2 for Texture, ResNet152 for
    Heightmap).

    Args:
        modality: "Texture" or "Heightmap" (ignored when mode is "Few-Shot").
        mode: "Few-Shot", or any other value for Fully-Supervised.

    Returns:
        The model produced by utils.experiment_utils.get_model, in eval mode.

    Raises:
        ValueError: Fully-Supervised mode with an unknown modality.
    """
    from types import SimpleNamespace

    if mode == "Few-Shot":
        model_name = 'DINOv2'  # Few-Shot always uses DINOv2
    elif modality == "Texture":
        model_name = 'DINOv2'
    elif modality == "Heightmap":
        model_name = 'ResNet152'
    else:
        raise ValueError("Invalid modality selected!")

    # get_model only reads .model/.pretrained/.frozen attributes, so a
    # SimpleNamespace replaces the three duplicated throwaway Args classes.
    args = SimpleNamespace(model=model_name, pretrained='pretrained', frozen='unfrozen')
    model = get_model(args)
    model.eval()  # Set the model to evaluation mode
    return model


# Prediction function that processes the image and returns the prediction results
def predict(image, modality, mode):
    """Classify an uploaded image.

    Args:
        image: PIL image to classify.
        modality: sensing modality selected by the user ("Texture"/"Heightmap").
        mode: classification mode ("Fully Supervised" or "Few-Shot").

    Returns:
        Tuple of (predicted class name, dict mapping class name -> probability).
    """
    # Pick the model matching the user's selections
    model = load_model(modality, mode)

    # Debug trace of the user's selections
    print(f"User selected Mode: {mode}, Modality: {modality}")

    # Standard ImageNet-style preprocessing pipeline
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    batch = preprocess(image).unsqueeze(0)  # single-image batch
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()

    # Class names for the predictions
    class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]

    # Map each class to its probability, then take the arg-max class
    results = dict(zip(class_names, probabilities))
    predicted_class = max(results, key=results.get)

    return predicted_class, results


# Create the Gradio interface using gr.Blocks
def create_interface():
    """Build and return the full Gradio Blocks UI for LUWA classification.

    Wires together the mode/modality selectors, image upload, prediction
    outputs, and a two-step flagging workflow (confirm, then save via
    CustomFlagging). NOTE(review): the visitor counter is bumped once per
    interface construction, i.e. per app launch, not per page visit —
    confirm that is the intended metric.
    """
    with gr.Blocks() as interface:
        # Title at the top of the interface (centered and larger)
        gr.Markdown("<h1 style='text-align: center; font-size: 36px;'>LUWA Dataset Image Classification</h1>")

        # Add description for the interface
        description = """
        ### Image Classification Options
        - **Fully-Supervised Classification**: Choose this for common or well-known materials with plenty of data (e.g., bone, wood).
        - **Few-Shot Classification**: Choose this for rare or newly discovered materials where only a few examples exist.
        ### **Don't forget to choose the Sensing Modality based on your uploaded images.**
        ### **Please help us to flag the correct class for your uploaded image if you know it, it will help us to further develop our dataset. If you cannot find the correct class in the option, please click on the option 'Other' and type the correct class for us!**
        """
        gr.Markdown(description)

        # Top-level selector for Fully-Supervised vs. Few-Shot classification
        mode_selector = gr.Radio(choices=["Fully Supervised", "Few-Shot"], label="Classification Mode",
                                 value="Fully Supervised")

        # Sensing modality selector (drives backbone choice in load_model)
        modality_selector = gr.Radio(choices=["Texture", "Heightmap"], label="Sensing Modality", value="Texture")

        # Image upload input (PIL image handed straight to predict())
        image_input = gr.Image(type="pil", label="Image")

        # Predicted classification output and class probabilities
        with gr.Row():
            predicted_output = gr.Label(num_top_classes=1, label="Predicted Classification")
            probabilities_output = gr.Label(label="Prediction Probabilities")

        # Add the "Run Prediction" button under the Prediction Probabilities
        predict_button = gr.Button("Run Prediction")

        # Dropdown for user to select the correct class if the model prediction is wrong
        correct_class_selector = gr.Radio(
            choices=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD", "Other"],
            label="Select Correct Class"
        )

        # Text box for user to type the correct class if "Other" is selected
        other_class_input = gr.Textbox(label="If Other, enter the correct class", visible=False)

        # Logic to dynamically update visibility of the "Other" class text box
        def update_visibility(selected_class):
            # Show the free-text box only when "Other" is selected
            return gr.update(visible=selected_class == "Other")

        correct_class_selector.change(fn=update_visibility, inputs=correct_class_selector, outputs=other_class_input)


        # Create a flagging instance (persists flagged samples to disk)
        flagging_instance = CustomFlagging(dir_name="flagged_data")

        # Define function for the confirmation pop-up
        def confirm_flag_selection(correct_class, other_class):
            # Generate confirmation message echoing the user's chosen class
            if correct_class == "Other":
                message = f"Are you sure the class you selected is '{other_class}' for this picture?"
            else:
                message = f"Are you sure the class you selected is '{correct_class}' for this picture?"

            # Reveal the Yes/No radio and the Confirm button alongside the message
            return message, gr.update(visible=True), gr.update(visible=True)

        # Final flag submission function
        def flag_data_save(correct_class, other_class, mode, image, modality, predicted_class, confirmed):
            # Only persist when the user explicitly confirmed with "Yes"
            if confirmed == "Yes":
                # Save the flagged data ("Other" falls back to the typed-in class)
                correct_class_final = correct_class if correct_class != "Other" else other_class
                flagging_instance.flag([mode, image, modality, predicted_class, correct_class_final])
                return "Flagged successfully!"
            else:
                return "No flag submitted, please select again."

        # Flagging button (step 1: ask for confirmation)
        flag_button = gr.Button("Flag")

        # Confirmation box for user input and confirmation flag
        # (all hidden until the Flag button is pressed)
        confirmation_text = gr.Textbox(visible=False)
        yes_no_choice = gr.Radio(choices=["Yes", "No"], label="Are you sure?", visible=False)
        confirmation_button = gr.Button("Confirm Flag", visible=False)

        # Prediction action
        predict_button.click(
            fn=predict,
            inputs=[image_input, modality_selector, mode_selector],
            outputs=[predicted_output, probabilities_output]
        )

        # Flagging action with confirmation (step 1: show confirmation widgets)
        flag_button.click(
            fn=confirm_flag_selection,
            inputs=[correct_class_selector, other_class_input],
            outputs=[confirmation_text, yes_no_choice, confirmation_button]
        )

        # Final flag submission after confirmation (step 2: persist the flag)
        confirmation_button.click(
            fn=flag_data_save,
            inputs=[correct_class_selector, other_class_input, mode_selector, image_input, modality_selector,
                    predicted_output, yes_no_choice],
            outputs=gr.Textbox(label="Flagging Status")
        )

        # Visitor count displayed at the bottom
        visitor_count = update_visitor_count()  # Update the visitor count
        gr.Markdown(f"### The Number of Visitors since October 2024: {visitor_count}")  # Display visitor count

    return interface


if __name__ == "__main__":
    # Build the Blocks app and expose it through a public Gradio share link.
    demo = create_interface()
    demo.launch(share=True)