import os
import hashlib

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh as tm
from PIL import Image, ImageDraw
from torchvision import transforms

from src.model import DinoV2
from src.shape_model import CSE


def image_hash(image):
    """Return a SHA-256 hex digest of the image's raw bytes (used as a cache key)."""
    hash_function = hashlib.sha256()
    hash_function.update(image.tobytes())
    return hash_function.hexdigest()
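

# Hypothetical sanity check (not part of the demo flow): the hash is computed
# over raw pixel bytes, so an identical copy of an image hits the same cache
# entry as the original:
#
#   img = Image.open('./gradio_example_images/bear1.png').convert('RGB')
#   assert image_hash(img) == image_hash(img.copy())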


device = torch.device('cpu')

# Load one DINOv2 image encoder and one CSE shape model per supported class.
models = {}
for class_name in ['bear', 'horse', 'elephant']:
    print(f'Loading model weights for {class_name}')
    models[class_name] = {
        'image_encoder': DinoV2(16),
        'cse': CSE(class_name=class_name, num_basis=64, device=device),
    }
    models[class_name]['image_encoder'].load_state_dict(
        torch.load(f'./models/weights/{class_name}.pth', map_location=device))
    models[class_name]['cse'].load_state_dict(
        torch.load(f'./models/weights/{class_name}_cse.pth', map_location=device))
    models[class_name]['cse'].functional_basis = torch.load(
        f'./models/weights/{class_name}_lbo.pth', map_location=device)

    models[class_name]['image_encoder'] = models[class_name]['image_encoder'].to(device)
    models[class_name]['cse'] = models[class_name]['cse'].to(device)
    models[class_name]['cse'].functional_basis = models[class_name]['cse'].functional_basis.to(device)
    models[class_name]['cse'].weight_matrix = models[class_name]['cse'].weight_matrix.to(device)

    # The per-vertex shape features are input-independent, so compute them once.
    models[class_name]['shape_feats'] = models[class_name]['cse']().to(device)
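
# At this point models[c]['shape_feats'] holds one embedding per template
# vertex (assumed layout: [num_vertices, feat_dim], matching the einsum in
# process_mesh below, which compares a clicked pixel's embedding against all
# vertex embeddings).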

# Standard ImageNet preprocessing for the DINOv2 encoder.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Per-class cache of image features keyed by image hash, so repeated clicks
# on the same image skip the encoder forward pass.
cached_features = {'bear': {}, 'horse': {}, 'elephant': {}}
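# Note: entries are never evicted, which is fine for a short-lived demo
# process but would grow without bound in a long-running deployment.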

text_description = """
# Demo for SHIC: Shape-Image Correspondences with no Keypoint Supervision (ECCV 2024)
Project website: https://www.robots.ox.ac.uk/~vgg/research/shic/

- **Step 1:** Select a class (it defaults to 'bear')
- **Step 2:** Upload an image of an animal from that class (or pick one of the provided examples)
- **Step 3:** Click on the image (somewhere on the object) to see the image-to-shape correspondences. You can keep clicking on the input image to see new correspondences.

Notes:
- You can click and drag to rotate the 3D shape
- The demo currently supports bears, horses, and elephants. Other classes coming soon!
- Make sure you have selected the correct class for your image (although it works cross-class too!)
"""

example_images_dir = './gradio_example_images/'
example_images_names = [
    'bear1.png', 'bear2.png', 'bear3.png', 'winnie.png',
    'horse1.png', 'horse2.png', 'mylittlepony.png', 'ponyta.png',
    'elephant1.png', 'elephant2.png', 'dumbo.png', 'phanpy.png',
]
example_images = [os.path.join(example_images_dir, img) for img in example_images_names]

# Template sphere geometry used to mark the best-matching vertex on the mesh.
sphere_verts_ = torch.load('./models/weights/sphere_verts.pth', map_location=device)
sphere_faces_ = torch.load('./models/weights/sphere_faces.pth', map_location=device)


def center_crop(img):
    """Center-crop an image to a square of side min(width, height)."""
    width, height = img.size
    target_size = min(width, height)

    left = (width - target_size) // 2
    top = (height - target_size) // 2
    right = left + target_size
    bottom = top + target_size

    return img.crop((left, top, right, bottom))
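
# Note: center_crop is not called anywhere in this script; uploads go through
# `transform`, whose Resize squashes non-square images to 224x224 instead of
# cropping them.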


def draw_point_on_image(image, x_, y_):
    """Draw a red dot on a copy of the image at the specified point.

    Here x_ is the row (vertical) coordinate and y_ the column (horizontal)
    one, so the PIL call below, which expects (x, y), receives (y_, x_).
    """
    image_copy = image.copy()
    draw = ImageDraw.Draw(image_copy)
    x, y = x_, y_
    dot_radius = image.size[0] // 40
    draw.ellipse([(y - dot_radius, x - dot_radius), (y + dot_radius, x + dot_radius)], fill='red')
    return image_copy


def rotate_y(vertices, angle_degrees):
    """Rotate an (N, 3) array of vertices about the y-axis by angle_degrees."""
    angle_radians = np.radians(angle_degrees)
    rotation_matrix = np.array([
        [np.cos(angle_radians), 0, np.sin(angle_radians)],
        [0, 1, 0],
        [-np.sin(angle_radians), 0, np.cos(angle_radians)],
    ])
    rotated_vertices = np.dot(vertices, rotation_matrix)
    return rotated_vertices
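
# Worked example: vertices are row vectors, so `v @ R` applies the transpose
# of the usual column-vector rotation. Rotating (0, 0, 1) by 90 degrees:
#
#   rotate_y(np.array([[0.0, 0.0, 1.0]]), 90)  # -> approximately [[-1, 0, 0]]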


def make_final_mesh(verts, faces, similarities):
    """Build a trimesh coloured by pixel-to-vertex similarity, with a red
    marker sphere placed at the best-matching vertex."""
    vert_argmax = similarities.argmax(dim=1)
    vertex = verts[vert_argmax]
    color = [255, 0, 0]

    # Map the similarity scores to RGB colours with the viridis colormap.
    vertex_colors = similarities.transpose(1, 0).cpu().detach().numpy()
    vertex_colors = plt.cm.viridis(vertex_colors)[:, 0, :3]

    num_verts_so_far = len(verts)

    # Scale the template sphere down and translate it to the argmax vertex.
    scale_dot = 0.015
    translation = torch.tensor(vertex, device=device).unsqueeze(0)
    verts_sphere = sphere_verts_ * scale_dot + translation
    faces_sphere = sphere_faces_ + num_verts_so_far
    verts_rgb_sphere = torch.tensor([color], device=device).expand(verts_sphere.shape[0], -1)[None] / 255

    # Concatenate the animal mesh and the marker sphere into a single mesh.
    all_verts = np.concatenate([verts, verts_sphere.cpu().numpy()], axis=0)
    all_faces = np.concatenate([faces, faces_sphere.cpu().numpy()], axis=0)
    all_textures = np.concatenate([vertex_colors, verts_rgb_sphere.cpu().numpy()[0]], axis=0)

    return tm.Trimesh(vertices=all_verts, faces=all_faces, vertex_colors=all_textures)
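
# faces_sphere is offset by num_verts_so_far so the sphere's faces index into
# the concatenated vertex array rather than into the sphere's own vertices.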


def process_mesh(image, class_name, x_, y_):
    """Compute pixel-to-vertex similarities for the clicked point and export
    the coloured mesh as an .obj file."""
    # PIL's Image.size is (width, height); x_ is the row index and y_ the
    # column index, so each is rescaled by the matching dimension to map the
    # click onto the 224x224 feature grid.
    width, height = image.size
    x = torch.tensor(x_ * 224 / height)
    y = torch.tensor(y_ * 224 / width)

    # Reuse cached features if this exact image was processed before.
    hashed_image = image_hash(image)
    if hashed_image in cached_features[class_name]:
        feats = cached_features[class_name][hashed_image]
    else:
        image_tensor = transform(image).unsqueeze(0)
        feats = models[class_name]['image_encoder'](image_tensor.to(device))
        cached_features[class_name][hashed_image] = feats

    # Dot-product similarity between the clicked pixel's feature and every
    # vertex feature, min-max normalised to [0, 1] for colouring.
    sampled_feats = feats[:, :, x.long(), y.long()]
    similarities = torch.einsum('ik, lk -> il', sampled_feats, models[class_name]['shape_feats'])
    similarities = (similarities - similarities.min()) / (similarities.max() - similarities.min())

    faces = models[class_name]['cse'].shape['faces'].cpu().numpy().copy()
    verts = models[class_name]['cse'].shape['verts'].cpu().numpy().copy()

    # Rotate the template so the default camera view shows the animal nicely.
    verts = rotate_y(verts, 145)

    mesh = make_final_mesh(verts, faces, similarities)

    mesh_path = './mesh.obj'
    mesh.export(mesh_path)

    return mesh_path
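
# Caveat: every request writes to the same './mesh.obj', so concurrent users
# of a shared deployment can overwrite each other's result; a unique temporary
# file per request would avoid that.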


def update_output(image, class_name, evt: gr.SelectData):
    if class_name is None:
        class_name = 'bear'

    # Gradio reports the click as (x, y) = (column, row); swap to (row, column)
    # for the downstream functions.
    x_, y_ = evt.index[1], evt.index[0]
    modified_image = draw_point_on_image(image, x_, y_)
    mesh_path = process_mesh(image, class_name, x_, y_)
    return modified_image, mesh_path


with gr.Blocks() as demo:
    gr.Markdown(text_description)

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            class_name = gr.Dropdown(choices=['bear', 'horse', 'elephant'],
                                     label="Select a class (defaults to 'bear')")
            input_img = gr.Image(label="Input", type="pil", width=256)
            gr.Examples(
                examples=example_images,
                inputs=[input_img],
                cache_examples=False,
                label='Feel free to use one of our provided examples!',
                examples_per_page=30,
            )
        with gr.Column(scale=1):
            output_img = gr.Image(label="Selected Point", interactive=False, height=512)
        with gr.Column(scale=1):
            output = gr.Model3D(label='Pixel to Vertex Similarities', height=512)

    # Clicking anywhere on the input image triggers the correspondence lookup.
    input_img.select(update_output, [input_img, class_name], [output_img, output])


if __name__ == "__main__":
    # share=True also serves the demo through a temporary public Gradio link.
    demo.launch(share=True)