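# Hugging Face Space app: semantic segmentation with DINOv2 (ViT-g/14 backbone
# with a Mask2Former head, trained on ADE20K), served through a Gradio interface.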
import os

# Swap out any preinstalled mmcv/mmsegmentation for the pinned wheel this demo was built against.
os.system("pip uninstall -y mmcv-full")
os.system("pip uninstall -y mmsegmentation")
os.system("pip install ./mmcv_full-1.5.0-cp310-cp310-linux_x86_64.whl")
os.system("pip install -r requirements-extras.txt")
# os.system("cp /home/user/data/dinov2_vitg14_ade20k_m2f.pth /home/user/.cache/torch/hub/checkpoints/dinov2_vitg14_ade20k_m2f.pth")
import gradio as gr
import base64
import cv2
import math
import itertools
from functools import partial
from PIL import Image
import numpy as np
import pandas as pd
import dinov2.eval.segmentation.utils.colormaps as colormaps
import torch
import torch.nn.functional as F
from mmseg.apis import init_segmentor, inference_segmentor
import dinov2.eval.segmentation.models
import dinov2.eval.segmentation_m2f.models.segmentors
import urllib.request
import mmcv
from mmcv.runner import load_checkpoint
model = None
model_loaded = False

DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
CONFIG_URL = f"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f_config.py"
CHECKPOINT_URL = f"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f.pth"
def load_config_from_url(url: str) -> str:
    with urllib.request.urlopen(url) as f:
        return f.read().decode()
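# The mmseg config is fetched from FAIR's public bucket and parsed in memory
# by mmcv, so no config file needs to ship with the Space.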
cfg_str = load_config_from_url(CONFIG_URL)
cfg = mmcv.Config.fromstring(cfg_str, file_format=".py")
DATASET_COLORMAPS = {
    "ade20k": colormaps.ADE20K_COLORMAP,
    "voc2012": colormaps.VOC2012_COLORMAP,
}
colormap = DATASET_COLORMAPS["ade20k"]
flattened = np.array(colormap).flatten()
# A PIL palette holds exactly 256 colors (256 * 3 = 768 values), so zero-pad
# the flattened colormap to that length.
zeros = np.zeros(768)
zeros[: flattened.shape[0]] = flattened
colorMap = list(zeros.astype("uint8"))
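# Build the segmenter from the downloaded config, load the matching Mask2Former
# checkpoint onto the CPU, then move the model to the GPU (a CUDA device is required).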
model = init_segmentor(cfg)
load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
model.cuda()
model.eval()
class CenterPadding(torch.nn.Module):
    """Pads H and W symmetrically so that each becomes a multiple of `multiple`."""

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    def forward(self, x):
        # x.shape[:1:-1] walks the trailing dims in reverse (W, then H), which is
        # the order F.pad expects its padding pairs in.
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        output = F.pad(x, pads)
        return output
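# Quick sanity check (illustrative only, not part of the app): with patch size 14,
# a (1, 3, 518, 522) input pads to (1, 3, 518, 532), since 518 is already a
# multiple of 14 while 522 rounds up to 532.
# assert CenterPadding(14)(torch.zeros(1, 3, 518, 522)).shape == (1, 3, 518, 532)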
def create_segmenter(cfg, backbone_model):
    model = init_segmentor(cfg)
    model.backbone.forward = partial(
        backbone_model.get_intermediate_layers,
        n=cfg.model.backbone.out_indices,
        reshape=True,
    )
    if hasattr(backbone_model, "patch_size"):
        model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone_model.patch_size)(x[0]))
    model.init_weights()
    return model
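# create_segmenter grafts a standalone DINOv2 backbone onto an mmseg segmenter:
# get_intermediate_layers replaces the backbone's forward pass, and the pre-hook
# pads inputs to a multiple of the patch size. It is not called in this script,
# since the downloaded checkpoint already bundles the full model; kept for reference.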
def render_segmentation(segmentation_logits, dataset):
    colormap_array = np.array(colormap, dtype=np.uint8)
    # Shift labels by one so that palette entry 0 is reserved for "unlabeled".
    segmentation_logits += 1
    segmented_image = Image.fromarray(segmentation_logits)
    segmented_image.putpalette(colorMap)
    unique_labels = np.unique(segmentation_logits)
    colormap_array = colormap_array[unique_labels]
    df = pd.read_csv("labelmap.txt", sep="\t")
    # Build an HTML legend: one base64-embedded JPEG swatch per label present in the image.
    html_output = '<div style="display: flex; flex-wrap: wrap;">'
    for idx, color in enumerate(colormap_array):
        color_box = np.zeros((20, 20, 3), dtype=np.uint8)
        color_box[:, :] = color
        color_box = cv2.cvtColor(color_box, cv2.COLOR_RGB2BGR)
        _, img_data = cv2.imencode(".jpg", color_box)
        img_base64 = base64.b64encode(img_data).decode("utf-8")
        img_data_uri = f"data:image/jpg;base64,{img_base64}"
        html_output += f'<div style="margin: 10px;"><img src="{img_data_uri}" /><p>{df.iloc[unique_labels[idx] - 1]["Name"]}</p></div>'
    html_output += "</div>"
    return segmented_image, html_output
def predict(image_file):
    # Gradio hands us an RGB image; mmseg expects BGR, so reverse the channel order.
    array = np.array(image_file)[:, :, ::-1]
    segmentation_logits = inference_segmentor(model, array)[0]
    segmentation_logits = segmentation_logits.astype(np.uint8)
    segmented_image, html_output = render_segmentation(segmentation_logits, "ade20k")
    return segmented_image, html_output
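# Local smoke test, useful when debugging outside Gradio (uncomment to run):
# img, legend = predict(Image.open("example_1.jpg"))
# img.save("segmented.png")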
description = "Gradio demo for semantic segmentation with DINOv2. To use it, simply upload an image."
demo = gr.Interface(
    title="Semantic Segmentation - DinoV2",
    fn=predict,
    # gr.inputs/gr.outputs were removed in Gradio 4; use the top-level components.
    inputs=gr.Image(),
    outputs=[gr.Image(type="pil"), gr.HTML()],
    examples=["example_1.jpg", "example_2.jpg"],
    cache_examples=False,
    description=description,
)

demo.launch()