|
import cv2 |
|
import numpy as np |
|
import gradio as gr |
|
import cwm.utils as utils |
|
|
|
|
|
# Styling for the annotations drawn on the uploaded image.  The colour tuples
# are passed as PIL ``fill`` values by the drawing callbacks in this file.
arrow_color = (0, 255, 0)      # line of a movement arrow
dot_color = (0, 255, 0)        # dots at the start/end of an arrow
dot_color_fixed = (255, 0, 0)  # dot marking a patch that is kept fixed
thickness = 3                  # width of the arrow line
dot_radius = 7                 # half-size of the dot's bounding box

# NOTE(review): the two values below are not referenced anywhere in the visible
# code -- presumably leftovers from an earlier cv2-based renderer.  Kept because
# other modules may still read them; confirm before removing.
tip_length = 0.3
dot_thickness = -1
|
from PIL import Image |
|
import torch |
|
|
|
from cwm.model.model_factory import model_factory |
|
|
|
from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) |
|
|
|
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The model is loaded once, at import time, and shared by every request.
model = model_factory.load_model('vitb_8x8patch_3frames').to(device)

# Inference only: no gradients are tracked and the model is put in eval mode.
model.requires_grad_(False)
model.eval()

# Half precision; run_model_on_points casts its input to float16 to match.
# NOTE(review): float16 is also applied when falling back to CPU -- half
# precision support on CPU depends on the PyTorch build; confirm this path.
model = model.to(torch.float16)
|
|
|
|
|
import matplotlib.pyplot as plt |
|
from matplotlib.patches import FancyArrowPatch |
|
from PIL import Image |
|
import numpy as np |
|
|
|
def draw_arrows_matplotlib(img, selected_points, zero_length):
    """
    Draw arrows on the image using matplotlib for better quality arrows and dots.

    Args:
        img: Image accepted by ``Axes.imshow``.
        selected_points: Flat sequence of ``(x, y)`` points, read as
            consecutive ``(start, end)`` pairs.  A trailing unpaired point
            (an arrow whose end has not been chosen yet) is drawn as a single
            dot instead of raising ``IndexError``.
        zero_length: When true, every pair is drawn as a single red dot at its
            start point, regardless of its end point.

    Returns:
        A ``(height, width, 3)`` uint8 RGB array holding the rendered figure
        canvas (the whole figure, not just the image area).
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)

        # Only complete (start, end) pairs are walked here.
        n_paired = len(selected_points) - len(selected_points) % 2
        for i in range(0, n_paired, 2):
            start_point = selected_points[i]
            end_point = selected_points[i + 1]

            if start_point == end_point or zero_length:
                # Zero-length vector: a single red dot.
                ax.scatter(start_point[0], start_point[1], color='red', s=100)
            else:
                arrow = FancyArrowPatch(
                    (start_point[0], start_point[1]),
                    (end_point[0], end_point[1]),
                    color='green', linewidth=2, arrowstyle='->', mutation_scale=15,
                )
                ax.add_patch(arrow)

                # Dots at both ends of the arrow.
                ax.scatter(start_point[0], start_point[1], color='green', s=100)
                ax.scatter(end_point[0], end_point[1], color='green', s=100)

        if n_paired != len(selected_points):
            # Pending start point without an end yet.
            pending = selected_points[-1]
            ax.scatter(pending[0], pending[1],
                       color='red' if zero_length else 'green', s=100)

        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which is deprecated and
        # removed in recent Matplotlib releases.  It already has the canvas
        # shape, so no manual reshape is needed.  Copy before the figure is
        # closed because the buffer is a view onto canvas memory.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        img_array = rgba[..., :3].copy()
    finally:
        # Always release the pyplot figure, even if drawing failed.
        plt.close(fig)
    return img_array
|
|
|
from PIL import ImageDraw |
|
with gr.Blocks() as demo:
    # NOTE(review): the source indentation was lost, so the layout nesting
    # below is a reconstruction -- confirm against the intended UI.
    with gr.Row():
        gr.Markdown('''# Generate interventions!π
        Upload an image and click to select the start and end points for arrows. Dots will be shown at the beginning and end of each arrow. You can also create zero-length vectors (just a dot) by enabling the toggle below.
        ''')

    with gr.Tab(label='Image'):
        with gr.Row():
            with gr.Column():
                # Clean (un-annotated) copy of the uploaded image; used to
                # redraw after undo/clear and as the model input.
                original_image = gr.State(value=None)
                input_image = gr.Image(type="numpy", label="Upload Image")

                # Flat list of clicked points, read as (start, end) pairs.
                # The callbacks below mutate this list in place.
                selected_points = gr.State([])
                zero_length_toggle = gr.Checkbox(label="Select patches to be kept fixed", value=False)
                with gr.Row():
                    gr.Markdown('Click on the image to select the start and end points for each arrow. If zero-length vectors are enabled, clicking once will draw a dot.')
                    undo_button = gr.Button('Undo last action')
                    clear_button = gr.Button('Clear All')

                run_model_button = gr.Button('Run Model')

            with gr.Tab(label='Intervention'):
                output_image = gr.Image(type='numpy')

    def _draw_dot(draw, point, color):
        """Draw a filled dot of radius ``dot_radius`` centred on ``point``."""
        draw.ellipse([point[0] - dot_radius, point[1] - dot_radius,
                      point[0] + dot_radius, point[1] + dot_radius],
                     fill=color)

    def resize_to_square(img, size=512):
        """Return ``img`` (a numpy image) resized to ``size`` x ``size``."""
        img = Image.fromarray(img)
        img = img.resize((size, size))
        return np.array(img)

    def store_img(img):
        """Handle an upload: resize the image and reset the selected points.

        Returns the resized image twice (displayed image and clean copy) and
        an empty point list.
        """
        resized_img = resize_to_square(img)
        print(f"Image uploaded with shape: {resized_img.shape}")
        return resized_img, resized_img, []

    input_image.upload(store_img, [input_image], [input_image, original_image, selected_points])

    def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
        """Record a click and draw it on the currently displayed image.

        Args:
            img: Displayed image (numpy), including earlier annotations.
            sel_pix: Flat list of selected points; mutated in place.
            zero_length: True when the click marks a patch to keep fixed.
            evt: Gradio select event; ``evt.index`` is used as the clicked
                ``(x, y)`` position.

        Returns:
            The displayed image with the new annotation drawn on it.
        """
        click = evt.index
        pil_img = Image.fromarray(img)
        draw = ImageDraw.Draw(pil_img)

        if zero_length:
            # A fixed patch is stored as a (point, point) pair.  If an arrow
            # start is still waiting for its end, insert the pair *before* it
            # so the list keeps its [start, end, start, end, ...] alignment;
            # appending after it would pair the pending start with this click.
            pending = sel_pix.pop() if len(sel_pix) % 2 == 1 else None
            sel_pix.extend([click, click])
            if pending is not None:
                sel_pix.append(pending)
            _draw_dot(draw, click, dot_color_fixed)
        else:
            sel_pix.append(click)
            if len(sel_pix) % 2 == 1:
                # First click of an arrow: only the start dot.
                _draw_dot(draw, click, dot_color)
            else:
                # Second click: line from the stored start, plus the end dot.
                start_point = tuple(sel_pix[-2])
                end_point = tuple(click)
                draw.line([start_point, end_point], fill=arrow_color, width=thickness)
                _draw_dot(draw, end_point, dot_color)

        return np.array(pil_img)

    input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])

    def undo_arrows(orig_img, sel_pix, zero_length):
        """Remove the most recent selection and redraw the rest.

        A pending arrow start (odd list length) is removed on its own;
        otherwise the last complete pair is removed.  ``zero_length`` is
        accepted because the toggle is wired as an input, but is not used.

        Returns:
            The clean image with the remaining annotations redrawn, or
            ``None`` when no image has been uploaded.
        """
        if orig_img is None:
            return None

        if len(sel_pix) % 2 == 1:
            sel_pix.pop()
        else:
            del sel_pix[-2:]

        pil_img = Image.fromarray(orig_img.copy())
        draw = ImageDraw.Draw(pil_img)

        # The list length is even here, so every remaining point is paired.
        for start_point, end_point in zip(sel_pix[0::2], sel_pix[1::2]):
            if start_point == end_point:
                # Identical endpoints mark a fixed patch.
                color = dot_color_fixed
            else:
                draw.line([tuple(start_point), tuple(end_point)], fill=arrow_color, width=thickness)
                color = arrow_color
            _draw_dot(draw, start_point, color)
            _draw_dot(draw, end_point, color)

        return np.array(pil_img)

    undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image])

    def clear_all_points(orig_img, sel_pix):
        """Forget every selected point and show the clean image again."""
        sel_pix.clear()
        return orig_img

    clear_button.click(clear_all_points, [original_image, selected_points], [input_image])

    def run_model_on_points(points, input_image, original_image):
        """Run the model on the clean image with the selected point pairs.

        Args:
            points: Flat list of selected ``(x, y)`` points in the coordinates
                of the displayed image.  A trailing unpaired point is ignored.
            input_image: Displayed image; only its height and width are used.
            original_image: Clean image that is fed to the model.

        Returns:
            The clamped model output as an ``(H, W, C)`` numpy array.

        Raises:
            gr.Error: If no image has been uploaded yet.
        """
        if input_image is None or original_image is None:
            raise gr.Error('Upload an image before running the model.')

        H, W = input_image.shape[:2]

        # Keep complete (start, end) pairs only; an arrow whose end has not
        # been clicked yet would otherwise break the reshape below.
        n_paired = len(points) - len(points) % 2
        pairs = np.array(points[:n_paired]).reshape(-1, 4)

        # Rescale from displayed-image pixels to the 224x224 model input,
        # per axis: each row is [x1, y1, x2, y2].
        scale = torch.tensor([224 / W, 224 / H, 224 / W, 224 / H])
        points = torch.from_numpy(pairs) * scale

        # Swap each (x, y) to (y, x) -- presumably the (row, col) order the
        # model expects; confirm against get_counterfactual.
        points = points[:, [1, 0, 3, 2]]
        print(points)

        img = Image.fromarray(original_image)
        img = img.resize((224, 224))
        img = np.array(img)

        # NOTE(review): looks like a debug dump; it is overwritten in the
        # working directory on every run -- confirm it is still wanted.
        np.save("img.npy", original_image)

        # HWC uint8 -> 1 x C x H x W float in [0, 1].
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        img = img[None]

        # Repeat the single frame 3 times along a new axis after the channel
        # axis, in half precision to match the model weights.
        x = img[:, :, None].expand(-1, -1, 3, -1, -1).to(torch.float16)
        x = utils.imagenet_normalize(x).to(device)

        with torch.no_grad():
            counterfactual = model.get_counterfactual(x, points)

        counterfactual = counterfactual.squeeze()
        counterfactual = counterfactual.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()
        return counterfactual

    run_model_button.click(run_model_on_points, [selected_points, input_image, original_image], [output_image])

# share=True asks Gradio for a public share link in addition to the local server.
demo.queue().launch(inbrowser=True, share=True)
|
|