File size: 11,008 Bytes
6dfcb0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import cv2
import numpy as np
import gradio as gr
import cwm.utils as utils

# Points color and arrow properties
# The color tuples below are passed as PIL ``fill`` colors by the click/undo
# handlers further down, i.e. they are interpreted as (R, G, B).
arrow_color = (0, 255, 0)  # Green color for all arrows
dot_color = (0, 255, 0)  # Green color for the dots at start and end
dot_color_fixed = (255, 0, 0)  # Red color for zero-length vectors
thickness = 3  # Thickness of the arrow
# NOTE(review): tip_length and dot_thickness are not referenced anywhere in the
# visible code -- presumably left over from an earlier OpenCV-based drawing path.
tip_length = 0.3  # The length of the arrow tip relative to the arrow length
dot_radius = 7  # Radius for the dots
dot_thickness = -1  # Thickness for solid circle (-1 fills the circle)
from PIL import Image
import torch
#load model
from cwm.model.model_factory import model_factory

from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Load CWM 3-frame model (automatically download pre-trained checkpoint)
model = model_factory.load_model('vitb_8x8patch_3frames').to(device)

# The demo only runs inference: turn off gradient tracking and switch to eval mode.
model.requires_grad_(False)
model.eval()

# Run the weights in half precision; the model input is cast to float16 to match.
# NOTE(review): this also applies on the CPU fallback -- confirm float16 is intended there.
model = model.to(torch.float16)


import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from PIL import Image
import numpy as np

def draw_arrows_matplotlib(img, selected_points, zero_length):
    """
    Draw arrows on the image using matplotlib for better quality arrows and dots.

    Points are consumed as consecutive (start, end) pairs. A pair whose two
    points are equal -- or every pair, when ``zero_length`` is truthy -- is
    drawn as a single red dot. Any other pair is drawn as a green arrow with a
    green dot at each end. A trailing unpaired point (a start click with no
    matching end yet) is drawn as a green dot instead of raising ``IndexError``.

    Args:
        img: Image data accepted by ``Axes.imshow``.
        selected_points: Flat sequence of ``(x, y)`` points.
        zero_length: When truthy, draw every pair as a fixed (red) dot.

    Returns:
        A ``uint8`` RGB array of shape ``(height, width, 3)`` holding the
        rendered figure canvas.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)

        # zip() stops at the shorter slice, so an unpaired trailing point is
        # simply left out of the pair loop and handled separately below.
        for start_point, end_point in zip(selected_points[0::2], selected_points[1::2]):
            if start_point == end_point or zero_length:
                # Red dot for a zero-length vector
                ax.scatter(start_point[0], start_point[1], color='red', s=100)
                continue

            arrow = FancyArrowPatch((start_point[0], start_point[1]), (end_point[0], end_point[1]),
                                    color='green', linewidth=2, arrowstyle='->', mutation_scale=15)
            ax.add_patch(arrow)

            # Green dots mark both ends of the arrow
            ax.scatter(start_point[0], start_point[1], color='green', s=100)
            ax.scatter(end_point[0], end_point[1], color='green', s=100)

        if len(selected_points) % 2 == 1:
            # Start point that has not been paired with an end point yet
            pending = selected_points[-1]
            ax.scatter(pending[0], pending[1], color='green', s=100)

        # Render, then copy the pixels out of the canvas.
        fig.canvas.draw()
        # tostring_rgb() was deprecated in Matplotlib 3.8 and removed in later
        # releases; buffer_rgba() exposes the same pixels as (height, width, 4).
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the result does not alias the canvas buffer
        # once the figure is closed.
        img_array = rgba[..., :3].copy()
    finally:
        # Always release the figure, even if drawing failed.
        plt.close(fig)
    return img_array

from PIL import ImageDraw
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown('''# Generate interventions!🚀
        Upload an image and click to select the start and end points for arrows. Dots will be shown at the beginning and end of each arrow. You can also create zero-length vectors (just a dot) by enabling the toggle below.
        ''')

    # Annotating arrows on an image
    with gr.Tab(label='Image'):
        with gr.Row():
            with gr.Column():
                # Input image
                original_image = gr.State(value=None)  # store original image without arrows
                input_image = gr.Image(type="numpy", label="Upload Image")

                # Annotate arrows
                selected_points = gr.State([])  # store points
                zero_length_toggle = gr.Checkbox(label="Select patches to be kept fixed", value=False)  # Toggle for zero-length vectors
                with gr.Row():
                    gr.Markdown('Click on the image to select the start and end points for each arrow. If zero-length vectors are enabled, clicking once will draw a dot.')
                    undo_button = gr.Button('Undo last action')
                    clear_button = gr.Button('Clear All')

                # Run model button
                run_model_button = gr.Button('Run Model')

            # Show the counterfactual image returned by the model
            with gr.Tab(label='Intervention'):
                output_image = gr.Image(type='numpy')

        def resize_to_square(img, size=512):
            """Return ``img`` resized to ``size`` x ``size`` pixels as an array (aspect ratio is not preserved)."""
            target = (size, size)
            squared = Image.fromarray(img).resize(target)
            return np.array(squared)

        # Store the original image and resize to square size once uploaded
        def store_img(img):
            """
            Handle a fresh upload.

            Resizes the upload to a square and returns it twice (once for the
            displayed image, once for the pristine copy) together with an empty
            point list, so any earlier annotations are discarded.
            """
            square = resize_to_square(img)
            print(f"Image uploaded with shape: {square.shape}")
            fresh_points = []
            return square, square, fresh_points

        input_image.upload(store_img, [input_image], [input_image, original_image, selected_points])

        # Get points and draw arrows or zero-length vectors based on the toggle
        def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
            """
            Record a click on the image and draw its marker.

            ``sel_pix`` is mutated in place and always read as consecutive
            (start, end) pairs; a fixed patch is stored as the same point twice.

            * Fixed-patch mode (``zero_length`` truthy): the click is stored as a
              zero-length pair and drawn as a red dot. If an arrow start was
              still waiting for its end point, that start is turned into a fixed
              patch as well (and repainted red). Previously the list was left
              with an odd length, which misaligned every following pair.
            * Arrow mode: the first click of a pair draws a green start dot, the
              second draws the connecting line and a green end dot.

            Args:
                img: Currently displayed image as an array (with earlier markers).
                sel_pix: List of clicked points, updated in place.
                zero_length: State of the "keep fixed" checkbox.
                evt: Gradio select event; ``evt.index`` is the clicked location.

            Returns:
                The image with the new marker drawn, as a NumPy array.
            """
            pil_img = Image.fromarray(img)
            draw = ImageDraw.Draw(pil_img)

            def paint_dot(center, color):
                # Filled circle of radius dot_radius around `center`.
                draw.ellipse([center[0] - dot_radius, center[1] - dot_radius,
                              center[0] + dot_radius, center[1] + dot_radius],
                             fill=color)

            clicked = evt.index

            if zero_length:
                if len(sel_pix) % 2 == 1:
                    # A start point is still waiting for its end: close it as a
                    # zero-length pair so the list stays pair-aligned, and
                    # repaint it so the canvas matches the stored state.
                    pending = sel_pix[-1]
                    sel_pix.append(pending)
                    paint_dot(pending, dot_color_fixed)

                # Zero-length vector: start and end are the same point
                sel_pix.append(clicked)
                sel_pix.append(clicked)
                paint_dot(clicked, dot_color_fixed)
            else:
                sel_pix.append(clicked)
                if len(sel_pix) % 2 == 1:
                    # First click of a pair: show the start point as feedback
                    paint_dot(clicked, dot_color)
                else:
                    # Second click: connect the last two points
                    start_point = tuple(sel_pix[-2])
                    end_point = tuple(clicked)
                    draw.line([start_point, end_point], fill=arrow_color, width=thickness)
                    paint_dot(end_point, dot_color)

            return np.array(pil_img)

        input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])

        # Undo the last selected action
        def undo_arrows(orig_img, sel_pix, zero_length):
            """
            Undo the most recent click action and redraw the remaining markers.

            If the last click was an arrow start still waiting for its end point
            (odd number of stored points), only that point is removed. Otherwise
            the last complete pair (arrow or fixed patch) is removed. The
            remaining pairs are then redrawn on a copy of the original image.

            Args:
                orig_img: The un-annotated image as an array.
                sel_pix: List of clicked points, updated in place.
                zero_length: State of the "keep fixed" checkbox (unused; kept so
                    the event wiring does not change).

            Returns:
                The original image with the remaining markers drawn, as a NumPy array.

            Raises:
                gr.Error: If no image has been uploaded yet.
            """
            if orig_img is None:
                raise gr.Error("Please upload an image first.")

            if len(sel_pix) % 2 == 1:
                # Dangling start point: removing it alone keeps earlier pairs intact
                sel_pix.pop()
            elif sel_pix:
                # Remove the last (start, end) pair
                del sel_pix[-2:]

            pil_img = Image.fromarray(orig_img.copy())
            draw = ImageDraw.Draw(pil_img)

            # The list has an even length at this point, so it splits cleanly into pairs.
            for start_point, end_point in zip(sel_pix[0::2], sel_pix[1::2]):
                if start_point == end_point:
                    # Zero-length vector: drawn as a red dot only
                    color = dot_color_fixed
                else:
                    draw.line([tuple(start_point), tuple(end_point)], fill=arrow_color, width=thickness)
                    color = arrow_color

                # Dots at the start and end points
                for point in (start_point, end_point):
                    draw.ellipse([point[0] - dot_radius, point[1] - dot_radius,
                                  point[0] + dot_radius, point[1] + dot_radius],
                                 fill=color)

            return np.array(pil_img)

        undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image])

        # Clear all points and reset the image
        def clear_all_points(orig_img, sel_pix):
            """Empty the stored point list in place and hand back the un-annotated image."""
            del sel_pix[:]
            return orig_img

        clear_button.click(clear_all_points, [original_image, selected_points], [input_image])

        # Run the CWM model on the annotated points to produce a counterfactual image
        def run_model_on_points(points, input_image, original_image):
            """
            Run the CWM model on the annotated points and return its counterfactual image.

            Args:
                points: Flat list of ``[x, y]`` clicks in displayed-image pixels;
                    consecutive pairs are (start, end). A trailing unpaired start
                    point is ignored instead of breaking the reshape.
                input_image: The displayed (annotated) image; only its height and
                    width are read, to rescale the clicks.
                original_image: The un-annotated image that is fed to the model.

            Returns:
                The model output, squeezed, clamped to [0, 1] and converted to a
                channel-last NumPy array.

            Raises:
                gr.Error: If no image has been uploaded yet.
            """
            if input_image is None or original_image is None:
                raise gr.Error("Please upload an image before running the model.")

            model_size = 224  # spatial size the image is resized to for the model

            # Scale each axis on its own so the mapping stays correct even if
            # the displayed image is not square.
            H, W = input_image.shape[:2]
            scale_x = model_size / W
            scale_y = model_size / H

            # Keep only complete (start, end) pairs: one row of (x0, y0, x1, y1) each.
            usable = len(points) - len(points) % 2
            coords = np.array(points[:usable]).reshape(-1, 4)
            axis_scale = torch.tensor([scale_x, scale_y, scale_x, scale_y])
            points = torch.from_numpy(coords) * axis_scale

            # Swap each coordinate pair: (x0, y0, x1, y1) -> (y0, x0, y1, x1).
            # presumably get_counterfactual expects row/column order -- TODO confirm
            points = points[:, [1, 0, 3, 2]]

            print(points)

            img = Image.fromarray(original_image).resize((model_size, model_size))

            # HWC uint8 -> CHW float in [0, 1], then add a batch dimension
            img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
            img = img[None]

            # reshape image to [B, C, T, H, W], C = 3, T = 3 (3-frame model), H = W = 224
            # (the single image is repeated along the time axis)
            x = img[:, :, None].expand(-1, -1, 3, -1, -1).to(torch.float16)

            # Imagenet-normalize the inputs (standardization)
            x = utils.imagenet_normalize(x).to(device)
            with torch.no_grad():
                counterfactual = model.get_counterfactual(x, points)

            # assumes the squeezed output is (C, H, W) -- TODO confirm against the model
            counterfactual = counterfactual.squeeze()

            return counterfactual.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()

        # Run model when the button is clicked
        run_model_button.click(run_model_on_points, [selected_points, input_image, original_image], [output_image])

    # Launch the app
demo.queue().launch(inbrowser=True, share=True)