import os
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import Repository
from hydra import compose, initialize
from PIL import Image
from torchvision import transforms as T

from models.builder import build_model
from segmentation.datasets import PascalVOCDataset
from visualization import mask2rgb
# Suppress warnings
warnings.filterwarnings("ignore")
# Constants
CHECKPOINT_PATH = "clip-dinoiser/checkpoints/last.pt"
CONFIG_PATH = "configs"
DINOCLIP_CONFIG = "clip_dinoiser.yaml"
COLORS = [
    (0, 255, 0),
    (255, 0, 0),
    (0, 255, 255),
    (255, 0, 255),
    (255, 255, 0),
    (250, 128, 114),
    (255, 165, 0),
    (0, 128, 0),
    (144, 238, 144),
    (175, 238, 238),
    (0, 191, 255),
    (0, 128, 0),
    (138, 43, 226),
    (255, 0, 255),
    (255, 215, 0),
    (0, 0, 255),
]
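# Note: run_clip_dinoiser() truncates this palette to the number of prompts,
# so at most len(COLORS) == 16 classes can be visualized at once.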
# Initialize Hydra
initialize(config_path=CONFIG_PATH, version_base=None)
# Configuration and Model Initialization
def load_model():
    # Clone the checkpoint repository; the Hugging Face access token is read
    # from the environment variable (Space secret) named "token".
    Repository(
        local_dir="clip-dinoiser",
        clone_from="ariG23498/clip-dinoiser",
        use_auth_token=os.environ.get("token"),
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    cfg = compose(config_name=DINOCLIP_CONFIG)
    model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device)
    model.clip_backbone.decode_head.use_templates = False
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    return model.eval()
def run_clip_dinoiser(input_image, text_prompts, model, device, colors):
    # Resize the input image and make sure it is RGB
    image = input_image.resize((350, 350))
    image = image.convert("RGB")

    # One palette color per comma-separated prompt
    text_prompts = text_prompts.split(",")
    palette = colors[: len(text_prompts)]

    # Update the open-vocabulary head with the new class names
    model.clip_backbone.decode_head.update_vocab(text_prompts)
    model.to(device)

    img_tens = T.PILToTensor()(image).unsqueeze(0).to(device) / 255.0
    h, w = img_tens.shape[-2:]
    output = model(img_tens).cpu()
    # Upsample the patch-level predictions back to pixel resolution
    output = F.interpolate(
        output,
        scale_factor=model.clip_backbone.backbone.patch_size,
        mode="bilinear",
        align_corners=False,
    )[..., :h, :w]
    output = output[0].argmax(dim=0)

    # Colorize the segmentation map and blend it with the input image
    mask = mask2rgb(output, palette)
    alpha = 0.5
    blend = (alpha * np.array(image) / 255.0) + ((1 - alpha) * mask / 255.0)
    h_text = [(text, f"{idx}") for idx, text in enumerate(text_prompts)]
    return blend, mask, h_text
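# A minimal programmatic usage sketch (hypothetical, outside the Gradio UI),
# assuming a local "vintage_bike.jpeg" like the example wired up below:
#
#   model = load_model()
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   blend, mask, labels = run_clip_dinoiser(
#       Image.open("vintage_bike.jpeg"),
#       "background, vintage bike, leather bag",
#       model, device, COLORS,
#   )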
def create_color_map(colors):
    return {
        f"{color_id}": f"#{color[0]:02x}{color[1]:02x}{color[2]:02x}"
        for color_id, color in enumerate(colors)
    }
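# For the default COLORS this yields e.g. {"0": "#00ff00", "1": "#ff0000", ...},
# matching the string class ids returned as h_text by run_clip_dinoiser().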
def setup_gradio_interface(model, device, colors, color_map):
    block = gr.Blocks()
    with block:
        gr.Markdown("<h1><center>CLIP-DINOiser</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                text_prompts = gr.Textbox(label="Enter comma-separated prompts")
                run_button = gr.Button(value="Run")
            with gr.Column():
                with gr.Row():
                    overlay_mask = gr.Image(type="numpy", label="Overlay Mask")
                    only_mask = gr.Image(type="numpy", label="Segmentation Mask")
                h_text = gr.HighlightedText(
                    label="Labels",
                    combine_adjacent=False,
                    show_legend=False,
                    color_map=color_map,
                )

        run_button.click(
            fn=lambda img, prompts: run_clip_dinoiser(
                img, prompts, model, device, colors
            ),
            inputs=[input_image, text_prompts],
            outputs=[overlay_mask, only_mask, h_text],
        )

        gr.Examples(
            examples=[["vintage_bike.jpeg", "background, vintage bike, leather bag"]],
            inputs=[input_image, text_prompts],
            outputs=[overlay_mask, only_mask, h_text],
            fn=lambda img, prompts: run_clip_dinoiser(
                img, prompts, model, device, colors
            ),
            cache_examples=True,
            label="Try this example input!",
        )

    return block
| if __name__ == "__main__": | |
| model = load_model() | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| color_map = create_color_map(COLORS) | |
| gradio_interface = setup_gradio_interface(model, device, COLORS, color_map) | |
| gradio_interface.launch(share=False, show_api=False, show_error=True) | |
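# To run locally, this sketch assumes the clip-dinoiser repository layout
# (models/, segmentation/, configs/, visualization.py) is on the import path and
# that torch, torchvision, gradio, hydra-core, huggingface_hub, numpy, and Pillow
# are installed:
#
#     python app.py   # or whatever name this file is saved under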