import gradio as gr
import os
import numpy as np
import trimesh as tm
from src.model import DinoV2
from src.shape_model import CSE
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import hashlib


def image_hash(image):
    """Generate a hash for an image."""
    image_bytes = image.tobytes()
    hash_function = hashlib.sha256()
    hash_function.update(image_bytes)
    return hash_function.hexdigest()


# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

models = {}
for class_name in ['bear', 'horse', 'elephant']:
    print(f'Loading model weights for {class_name}')
    models[class_name] = {
        'image_encoder': DinoV2(16),
        'cse': CSE(class_name=class_name, num_basis=64, device=device)
    }
    models[class_name]['image_encoder'].load_state_dict(torch.load(f'./models/weights/{class_name}.pth', map_location=device))
    models[class_name]['cse'].load_state_dict(torch.load(f'./models/weights/{class_name}_cse.pth', map_location=device))
    models[class_name]['cse'].functional_basis = torch.load(f'./models/weights/{class_name}_lbo.pth', map_location=device)
    models[class_name]['image_encoder'] = models[class_name]['image_encoder'].to(device)
    models[class_name]['cse'] = models[class_name]['cse'].to(device)
    models[class_name]['cse'].functional_basis = models[class_name]['cse'].functional_basis.to(device)
    models[class_name]['cse'].weight_matrix = models[class_name]['cse'].weight_matrix.to(device)
    models[class_name]['shape_feats'] = models[class_name]['cse']().to(device)

# Convert a PIL image to the normalized tensor the image encoder expects
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Per-class cache of image features, keyed by a hash of the image bytes
cached_features = {'bear': {}, 'horse': {}, 'elephant': {}}

text_description = """
# Demo for SHIC: Shape-Image Correspondences with no Keypoint Supervision (ECCV 2024)

Project website: https://www.robots.ox.ac.uk/~vgg/research/shic/

- **Step 1:** Select a class (it defaults to 'bear')
- **Step 2:** Upload an image of an animal from that class (or select one of the provided examples)
- **Step 3:** Click on the image (somewhere over the object) to see the image-to-shape correspondences. You can keep clicking on the input image to see new correspondences.

Notes:
- You can click and drag to rotate the 3D shape
- Currently the demo supports bears, horses, and elephants. Other classes coming soon!
- Make sure you have selected the correct class for your image (it works cross-class too though!)
"""

example_images_dir = './gradio_example_images/'
example_images_names = [
    'bear1.png', 'bear2.png', 'bear3.png', 'winnie.png',
    'horse1.png', 'horse2.png', 'mylittlepony.png', 'ponyta.png',
    'elephant1.png', 'elephant2.png', 'dumbo.png', 'phanpy.png'
]
example_images = [os.path.join(example_images_dir, img) for img in example_images_names]

# Unit sphere used to mark the best-matching vertex on the output mesh
sphere_verts_ = torch.load('./models/weights/sphere_verts.pth', map_location=device)
sphere_faces_ = torch.load('./models/weights/sphere_faces.pth', map_location=device)
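# Expected contents of ./models/weights/, as referenced by the torch.load calls
# above (per class: image-encoder weights, CSE weights, and the LBO basis):
#   {bear,horse,elephant}.pth          - DinoV2 image encoder state dict
#   {bear,horse,elephant}_cse.pth      - CSE state dict
#   {bear,horse,elephant}_lbo.pth      - functional (LBO) basis tensor
#   sphere_verts.pth, sphere_faces.pth - marker-sphere geometry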
""" width, height = img.size # Get dimensions # Calculate the target size for center cropping target_size = min(width, height) # Calculate the coordinates for center cropping left = (width - target_size) // 2 top = (height - target_size) // 2 right = left + target_size bottom = top + target_size # Perform center cropping cropped_img = img.crop((left, top, right, bottom)) return cropped_img def draw_point_on_image(image, x_, y_): """Draws a red dot on a copy of the image at the specified point.""" # Make a copy of the image to avoid altering the original image_copy = image.copy() draw = ImageDraw.Draw(image_copy) x, y = x_, y_ # Adjust these based on the actual structure of `point` dot_radius = image.size[0] // 40 # Draw a red dot draw.ellipse([(y-dot_radius, x-dot_radius), (y+dot_radius, x+dot_radius)], fill='red') return image_copy def rotate_y(vertices, angle_degrees): angle_radians = np.radians(angle_degrees) rotation_matrix = np.array([ [np.cos(angle_radians), 0, np.sin(angle_radians)], [0, 1, 0], [-np.sin(angle_radians), 0, np.cos(angle_radians)] ]) # Assuming vertices is a numpy array of shape (N, 3) rotated_vertices = np.dot(vertices, rotation_matrix) return rotated_vertices def make_final_mesh(verts, faces, similarities): vert_argmax = similarities.argmax(dim=1) vertex = verts[vert_argmax] color=[255, 0, 0] vertex_colors=similarities.transpose(1,0).cpu().detach().numpy() # to viridis color map vertex_colors = plt.cm.viridis(vertex_colors)[:, 0, :3] num_verts_so_far = len(verts) # Create a sphere mesh # Scale and translate the sphere to the desired location and size scale_dot = 0.015 # radius of the sphere translation = torch.tensor(vertex, device=device).unsqueeze(0) # desired location verts_sphere = sphere_verts_ * scale_dot + translation # scale and translate vertices faces_sphere = sphere_faces_ + num_verts_so_far # faces are the same verts_rgb_sphere = torch.tensor([color], device=device).expand(verts_sphere.shape[0], -1)[None] / 255 # [1, N, 3] # verts and all sphere verts # concat np arrays verts + verts_sphere.cpu().numpy() (4936,3) (2562,3) all_verts = np.concatenate([verts, verts_sphere.cpu().numpy()], axis=0) all_faces = np.concatenate([faces, faces_sphere.cpu().numpy()], axis=0) all_textures = np.concatenate([vertex_colors, verts_rgb_sphere.cpu().numpy()[0]], axis=0) return tm.Trimesh(vertices=all_verts, faces=all_faces, vertex_colors=all_textures) def process_mesh(image, class_name, x_, y_): x_, y_ = x_, y_ h, w = image.size x = torch.tensor(x_ * 224 / w) y = torch.tensor(y_ * 224 / h) hashed_image = image_hash(image) if hashed_image in cached_features[class_name]: feats = cached_features[class_name][hashed_image] else: image_tensor = transform(image).unsqueeze(0) # Predict texture feats = models[class_name]['image_encoder'](image_tensor.to(device)) cached_features[class_name][hashed_image] = feats # print('feats shape', feats.shape) sampled_feats = feats[:, :, x.long(), y.long()] similarities = torch.einsum('ik, lk -> il', sampled_feats, models[class_name]['shape_feats']) # normalize similarities similarities = (similarities - similarities.min()) / (similarities.max() - similarities.min()) faces = models[class_name]['cse'].shape['faces'].cpu().numpy().copy() verts = models[class_name]['cse'].shape['verts'].cpu().numpy().copy() # rotate the shape 235 verts = rotate_y(verts, 145) mesh = make_final_mesh(verts, faces, similarities) # save as obj mesh_path = './mesh.obj' mesh.export(mesh_path) return mesh_path def update_output(image, class_name, evt: gr.SelectData): if 
def update_output(image, class_name, evt: gr.SelectData):
    """Triggered when the input image is clicked; evt carries the click coordinates."""
    if class_name is None:
        class_name = 'bear'
    # evt.index is (column, row); the helpers above take (row, column)
    x_, y_ = evt.index[1], evt.index[0]
    modified_image = draw_point_on_image(image, x_, y_)
    mesh_path = process_mesh(image, class_name, x_, y_)
    return modified_image, mesh_path


with gr.Blocks() as demo:
    gr.Markdown(text_description)
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            # Class selector and input image
            class_name = gr.Dropdown(choices=['bear', 'horse', 'elephant'], label="Select a class (defaults to 'bear')")
            input_img = gr.Image(label="Input", type="pil", width=256)
            gr.Examples(
                examples=example_images,
                inputs=[input_img],
                cache_examples=False,
                label='Feel free to use one of our provided examples!',
                examples_per_page=30
            )
        with gr.Column(scale=1):
            output_img = gr.Image(label="Selected Point", interactive=False, height=512)
        with gr.Column(scale=1):
            output = gr.Model3D(label='Pixel to Vertex Similarities', height=512)

    input_img.select(update_output, [input_img, class_name], [output_img, output])

if __name__ == "__main__":
    demo.launch(share=True)
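# Note: share=True also requests a temporary public gradio.live link; drop the
# argument to serve on the local URL only.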