from os import getenv
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import v2 as T

from dreamsim import DreamsimBackbone, DreamsimEnsemble, DreamsimModel

_ = torch.set_grad_enabled(False)
torchdev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_float32_matmul_precision("high")

HF_TOKEN = getenv("HF_TOKEN", None)

MODEL_REPO = "neggles/dreamsim"
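
# UI label -> branch of MODEL_REPO; the branch name is passed as `revision` in load_model() below.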
MODEL_VARIANTS: dict[str, str] = {
    "Ensemble": "ensemble_vitb16",
    "CLIP ViT-B/32": "clip_vitb32",
    "OpenCLIP ViT-B/32": "open_clip_vitb32",
    "DINO ViT-B/16": "dino_vitb16",
}
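
# Lazy cache: each entry is filled in by load_model() the first time that variant is requested.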
loaded_models: dict[str, Optional[DreamsimBackbone]] = {
    "ensemble_vitb16": None,
    "clip_vitb32": None,
    "open_clip_vitb32": None,
    "dino_vitb16": None,
}


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    w, h = image.size
    # get the largest dimension so we can pad to a square
    px = max(image.size)
    # pad to square with white background
    canvas = Image.new("RGB", (px, px), fill)
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


def load_model(variant: str) -> DreamsimBackbone:
    global loaded_models
    if variant in MODEL_VARIANTS:
        # resolve the repo branch for the model variant
        variant = MODEL_VARIANTS[variant]

    match variant:
        case "ensemble_vitb16":
            if loaded_models[variant] is None:
                model: DreamsimEnsemble = DreamsimEnsemble.from_pretrained(
                    MODEL_REPO,
                    token=HF_TOKEN,
                    revision=variant,
                )
                model.do_resize = False
                loaded_models[variant] = model
        case "clip_vitb32" | "open_clip_vitb32" | "dino_vitb16":
            if loaded_models[variant] is None:
                model: DreamsimModel = DreamsimModel.from_pretrained(
                    MODEL_REPO,
                    token=HF_TOKEN,
                    revision=variant,
                )
                model.do_resize = False
                loaded_models[variant] = model
        case _:
            raise ValueError(f"Unknown model variant: {variant}")

    return loaded_models[variant]
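
# e.g. load_model("Ensemble") downloads the ensemble_vitb16 branch on first call,
# then returns the cached instance on subsequent calls.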


def predict(
    variant: str,
    resize_to: Optional[int],
    image_a: Image.Image,
    image_b: Image.Image,
):
    # Load model (cached after first use) and move it to the active device
    model: DreamsimModel | DreamsimEnsemble = load_model(variant)
    model = model.eval().to(torchdev)

    # yeet alpha, make white background
    image_a, image_b = pil_ensure_rgb(image_a), pil_ensure_rgb(image_b)
    # pad to square
    image_a, image_b = pil_pad_square(image_a), pil_pad_square(image_b)

    # Resize both images to the same target size so they can be stacked below
    if resize_to is not None:
        image_a = image_a.resize((resize_to, resize_to), resample=Image.Resampling.BICUBIC)
        image_b = image_b.resize((resize_to, resize_to), resample=Image.Resampling.BICUBIC)

    # Preprocess images and stack them into a single (2, 1, C, H, W) batch
    transforms = T.Compose([T.ToImage(), T.ToDtype(torch.float32, scale=True)])
    batch = torch.stack([transforms(image_a).unsqueeze(0), transforms(image_b).unsqueeze(0)], dim=0)

    # DreamSim returns a perceptual distance; convert it to a similarity score
    loss = model(batch.to(model.device, model.dtype)).cpu().item()
    score = 1.0 - loss
    return score, variant
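
# Standalone usage sketch (assumes the bundled example images referenced below are present):
#   a = Image.open("examples/img_a_1.png")
#   ref = Image.open("examples/ref_1.png")
#   similarity, used_variant = predict("Ensemble", 224, a, ref)
#   print(f"{used_variant}: {similarity:.4f}")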


def main():
    with gr.Blocks(title="DreamSim Perceptual Similarity") as demo:
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(label="Input", type="pil", image_mode="RGB", scale=1)
            with gr.Column():
                img_target = gr.Image(label="Target", type="pil", image_mode="RGB", scale=1)
        with gr.Row(equal_height=True):
            with gr.Column():
                variant = gr.Radio(
                    choices=list(MODEL_VARIANTS.keys()), label="Model Variant", value="Ensemble"
                )
                resize_to = gr.Dropdown(label="Resize To", choices=[224, 384, 512, None], value=224)
            with gr.Column():
                score = gr.Number(label="Similarity Score", precision=8, minimum=0, maximum=1)
                variant_out = gr.Textbox(label="Variant", interactive=False)
        with gr.Row():
            clear = gr.ClearButton(
                components=[img_input, img_target, score], variant="secondary", size="lg"
            )
            submit = gr.Button(value="Submit", variant="primary", size="lg")

        submit.click(
            predict,
            inputs=[variant, resize_to, img_input, img_target],
            outputs=[score, variant_out],
            api_name=False,
        )

        examples = gr.Examples(
            [
                ["examples/img_a_1.png", "examples/ref_1.png", "Ensemble", 224],
                ["examples/img_b_1.png", "examples/ref_1.png", "Ensemble", 224],
            ],
            inputs=[img_input, img_target, variant, resize_to],
        )

    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()