|
import cv2 |
|
import numpy as np |
|
import gradio as gr |
|
import cwm.utils as utils |
|
|
|
|
|
# Drawing parameters for the OpenCV annotations drawn on the displayed image.
# Colour tuples are handed directly to cv2 drawing calls on the array gradio
# supplies (assumed RGB channel order -- TODO confirm).
arrow_color = dot_color = (0, 255, 0)  # arrows and their endpoint dots
dot_color_fixed = (255, 0, 0)          # dots for zero-length (start == end) vectors
thickness, tip_length = 4, 0.3         # arrow line width and relative tip size
dot_radius, dot_thickness = 10, -1     # negative thickness => filled circle in OpenCV
|
from PIL import Image |
|
import torch |
|
|
|
from cwm.model.model_factory import model_factory |
|
|
|
from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) |
|
|
|
# Run on the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the model once at import time and freeze it for inference.
# NOTE(review): weights are cast to float16 unconditionally; some torch builds
# lack half-precision CPU kernels, so the CPU fallback may fail -- confirm.
model = (
    model_factory.load_model('vitb_8x8patch_3frames')
    .to(device)
    .requires_grad_(False)  # no autograd bookkeeping for any parameter
    .eval()                 # inference-mode behaviour for dropout/norm layers
    .to(torch.float16)      # cast all weights to half precision
)
|
|
|
|
|
import matplotlib.pyplot as plt |
|
from matplotlib.patches import FancyArrowPatch |
|
from PIL import Image |
|
import numpy as np |
|
|
|
def draw_arrows_matplotlib(img, selected_points, zero_length):
    """
    Draw arrows on the image using matplotlib for better quality arrows and dots.

    ``selected_points`` is consumed pairwise as (start, end). A pair whose start
    equals its end -- or every pair when ``zero_length`` is true -- is drawn as a
    single red dot at the start point. Every other pair is drawn as a green arrow
    with green dots at both ends. A trailing unpaired start point (odd number of
    points) is drawn as a single green dot instead of raising ``IndexError``.

    Args:
        img: anything ``Axes.imshow`` accepts (e.g. an (H, W, 3) array).
        selected_points: flat sequence of (x, y) points.
        zero_length: when true, draw every pair as a dot only.

    Returns:
        An (H', W', 3) uint8 RGB array of the rendered figure, axes included.
        Its size is set by the matplotlib figure, not by ``img``.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)

        # zip over even/odd slices yields only complete (start, end) pairs.
        for start_point, end_point in zip(selected_points[0::2], selected_points[1::2]):
            if start_point == end_point or zero_length:
                ax.scatter(start_point[0], start_point[1], color='red', s=100)
            else:
                arrow = FancyArrowPatch((start_point[0], start_point[1]), (end_point[0], end_point[1]),
                                        color='green', linewidth=2, arrowstyle='->', mutation_scale=15)
                ax.add_patch(arrow)
                ax.scatter(start_point[0], start_point[1], color='green', s=100)
                ax.scatter(end_point[0], end_point[1], color='green', s=100)

        if len(selected_points) % 2 == 1:
            # A start point still waiting for its end point: show it as a dot.
            pending = selected_points[-1]
            ax.scatter(pending[0], pending[1], color='green', s=100)

        fig.canvas.draw()
        # FigureCanvas.tostring_rgb() was deprecated in Matplotlib 3.8 and removed
        # in 3.10. buffer_rgba() is the supported accessor and already carries the
        # true (H, W, 4) pixel shape, including on HiDPI canvases. Drop alpha, and
        # copy so the result is contiguous and independent of the canvas buffer.
        img_array = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Always release the figure, even if drawing failed, to avoid leaking figures.
        plt.close(fig)

    return img_array
|
|
|
|
|
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown('''# Generate interventions!π
Upload an image and click to select the start and end points for arrows. Dots will be shown at the beginning and end of each arrow. You can also create zero-length vectors (just a dot) by enabling the toggle below.
''')

    with gr.Tab(label='Image'):
        with gr.Row():
            with gr.Column():
                # Un-annotated copy of the uploaded (resized) image, used for redraws.
                original_image = gr.State(value=None)
                input_image = gr.Image(type="numpy", label="Upload Image")

                # Flat list of clicked [x, y] points; consecutive pairs are (start, end).
                selected_points = gr.State([])
                zero_length_toggle = gr.Checkbox(label="Enable zero-length vectors", value=False)
                with gr.Row():
                    gr.Markdown('Click on the image to select the start and end points for each arrow. If zero-length vectors are enabled, clicking once will draw a dot.')
                    undo_button = gr.Button('Undo last action')
                    clear_button = gr.Button('Clear All')

                run_model_button = gr.Button('Run Model')

    with gr.Tab(label='Intervention'):
        output_image = gr.Image(type='numpy')

    def resize_to_square(img, size=512):
        """Return ``img`` resized (without preserving aspect ratio) to ``size`` x ``size``."""
        img = Image.fromarray(img)
        img = img.resize((size, size))
        return np.array(img)

    def store_img(img):
        """On upload: resize, keep an unannotated copy, and reset the selected points."""
        resized_img = resize_to_square(img)
        print(f"Image uploaded with shape: {resized_img.shape}")
        return resized_img, resized_img, []

    input_image.upload(store_img, [input_image], [input_image, original_image, selected_points])

    def _draw_points(img, sel_pix):
        """Draw every stored point onto ``img`` in place, using get_point's rendering rules.

        Complete pairs with start == end are drawn as fixed-colour dots. Other
        pairs are drawn as an arrow plus endpoint dots. A trailing unpaired start
        point is drawn as a single dot.
        """
        for start_point, end_point in zip(sel_pix[0::2], sel_pix[1::2]):
            # OpenCV's Python bindings parse points most reliably from tuples.
            start_point, end_point = tuple(start_point), tuple(end_point)
            if start_point == end_point:
                color = dot_color_fixed
            else:
                cv2.arrowedLine(img, start_point, end_point, arrow_color, thickness, tipLength=tip_length)
                color = dot_color
            cv2.circle(img, start_point, dot_radius, color, dot_thickness)
            cv2.circle(img, end_point, dot_radius, color, dot_thickness)

        if len(sel_pix) % 2 == 1:
            cv2.circle(img, tuple(sel_pix[-1]), dot_radius, dot_color, dot_thickness)

    def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
        """Record one click in ``sel_pix`` (mutated in place) and draw it onto ``img``.

        With ``zero_length`` enabled, the click is stored twice (start == end) and
        drawn as a fixed-colour dot. Otherwise clicks alternate between start and
        end: a start gets a dot, and an end gets an arrow from the previous start
        plus an end dot.
        """
        sel_pix.append(evt.index)
        point = tuple(evt.index)

        if zero_length:
            cv2.circle(img, point, dot_radius, dot_color_fixed, dot_thickness)
            sel_pix.append(evt.index)
        elif len(sel_pix) % 2 == 1:
            # First click of a pair: mark the start point.
            cv2.circle(img, point, dot_radius, dot_color, dot_thickness)
        else:
            # Second click of a pair: draw the arrow and mark its end point.
            start_point = tuple(sel_pix[-2])
            cv2.arrowedLine(img, start_point, point, arrow_color, thickness, tipLength=tip_length)
            cv2.circle(img, point, dot_radius, dot_color, dot_thickness)

        return img if isinstance(img, np.ndarray) else np.array(img)

    input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])

    def undo_arrows(orig_img, sel_pix, zero_length):
        """Remove the most recent action from ``sel_pix`` and redraw from the original image.

        A pending start point (odd number of points) is one action, so only it is
        removed. Otherwise the last complete pair (arrow or zero-length dot) is
        removed. ``zero_length`` is accepted for the event signature but not needed:
        zero-length dots are recognised by start == end.
        """
        if orig_img is None:
            # Nothing has been uploaded yet, so there is nothing to redraw.
            return None

        if sel_pix:
            n_undo = 1 if len(sel_pix) % 2 else 2
            del sel_pix[-n_undo:]

        temp = orig_img.copy()
        _draw_points(temp, sel_pix)
        return temp if isinstance(temp, np.ndarray) else np.array(temp)

    undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image])

    def clear_all_points(orig_img, sel_pix):
        """Forget all selected points and show the unannotated image again."""
        sel_pix.clear()
        return orig_img

    clear_button.click(clear_all_points, [original_image, selected_points], [input_image])

    def run_model_on_points(points, input_image, original_image):
        """Run the model's counterfactual on the uploaded image with the selected arrows.

        Only complete (start, end) pairs are used; a pending start point is ignored
        instead of breaking the reshape. Returns an (H, W, C) array clamped to
        [0, 1], or None when no image has been uploaded.
        """
        if original_image is None:
            return None

        # Clicks are in the display image's pixel space. store_img gives the
        # displayed and original images the same shape, and original_image stays
        # available even if the displayed image is cleared.
        H, W = original_image.shape[:2]
        complete = points[: len(points) - len(points) % 2]

        # Each row is (x0, y0, x1, y1); rescale each axis to the 224x224 model input.
        scale = torch.tensor([224 / W, 224 / H, 224 / W, 224 / H])
        pts = torch.from_numpy(np.array(complete).reshape(-1, 4)) * scale
        # Reorder to (y0, x0, y1, x1) -- presumably the (row, col) order the model expects; TODO confirm.
        pts = pts[:, [1, 0, 3, 2]]
        print(pts)

        img = np.array(Image.fromarray(original_image).resize((224, 224)))
        # NOTE(review): debug dump written to the working directory on every run.
        np.save("img.npy", original_image)

        # (H, W, C) uint8 -> (1, C, 3, H, W) float16: the same frame repeated 3 times.
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        x = img[None, :, None].expand(-1, -1, 3, -1, -1).to(torch.float16)
        x = utils.imagenet_normalize(x).to(device)

        with torch.no_grad():
            counterfactual = model.get_counterfactual(x, pts)

        return counterfactual.squeeze().clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()

    run_model_button.click(run_model_on_points, [selected_points, input_image, original_image], [output_image])
|
|
|
|
|
# Enable request queuing, then serve the app: open a local browser tab and
# also request a public share link.
demo.queue().launch(share=True, inbrowser=True)
|
|