Spaces: Runtime error
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torchvision.models as models | |
| from huggingface_hub import hf_hub_url, cached_download | |
| from configs.default import get_cfg_defaults | |
| from modeling.build import build_model | |
| from utils.data_utils import linear_scaling | |
# Fetch the pretrained CIFR checkpoint from the Hugging Face Hub (cached locally).
# hf_hub_download replaces the deprecated hf_hub_url + cached_download pair.
# NOTE: cached_download was removed in huggingface_hub 0.26, so the
# `cached_download` name in the top-of-file import must also be dropped before
# upgrading past that version.
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(repo_id="birdortyedi/cifr", filename="cifr.pth")

# Build the filter-removal network from the project's default config, pointing
# it at the downloaded checkpoint, and switch it to inference mode.
cfg = get_cfg_defaults()
cfg.MODEL.CKPT = model_path
net, _ = build_model(cfg)
net = net.eval()

# VGG16 convolutional feature extractor; its output is fed to `net` alongside
# the image. NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in
# favour of `weights=`; left as-is because the pinned torchvision is not visible.
vgg16 = models.vgg16(pretrained=True).features.eval()

# Preferred device; used as torch.load's map_location when loading weights.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def load_checkpoints_from_ckpt(ckpt_path, device, model=None):
    """Load the ``"ifr"`` state dict from a checkpoint file into a model.

    Args:
        ckpt_path: Path to a file readable by ``torch.load``. The loaded object
            must be a mapping with an ``"ifr"`` entry holding a state dict.
        device: Passed to ``torch.load`` as ``map_location``, i.e. where the
            loaded tensors are placed before being copied into the model.
        model: Module to load the weights into. Defaults to the module-level
            ``net`` so existing two-argument calls behave exactly as before.

    Returns:
        The module the weights were loaded into.

    Raises:
        KeyError: If the checkpoint mapping has no ``"ifr"`` entry.
    """
    target = net if model is None else model
    checkpoints = torch.load(ckpt_path, map_location=device)
    target.load_state_dict(checkpoints["ifr"])
    return target
# Load the trained weights from the checkpoint configured above into `net`.
load_checkpoints_from_ckpt(ckpt_path=cfg.MODEL.CKPT, device=device)
def filter_removal(img):
    """Remove an Instagram-style filter from a single image.

    Args:
        img: Image from the Gradio image input. It is indexed as H x W x C
            (transposed to C x H x W below) and divided by 255, so presumably
            a uint8 RGB array in [0, 255] (TODO confirm). ``None``, which is
            what Gradio passes when nothing was uploaded, is returned as-is.

    Returns:
        An H x W x C numpy array with values clamped to [0, 1], or ``None``
        when no image was supplied.
    """
    if img is None:
        return None

    # Run on whatever device the models' weights live on, so inference works
    # whether or not they have been moved to the GPU.
    net_device = next(net.parameters()).device
    vgg_device = next(vgg16.parameters()).device

    # HWC -> 1 x C x H x W, scaled to [0, 1] before the project's scaling step.
    arr = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
    arr = torch.tensor(arr).float() / 255.
    arr = linear_scaling(arr)

    with torch.no_grad():
        feat = vgg16(arr.to(vgg_device))
        out, _ = net(arr.to(net_device), feat.to(net_device))
    out = torch.clamp(out, max=1., min=0.)
    # .cpu() is required before .numpy() when inference ran on a GPU.
    return out.squeeze(0).permute(1, 2, 0).cpu().numpy()
# Text shown around the Gradio demo: page title, short usage description, and
# a footer linking to the paper and the code repository.
title = "Contrastive Instagram Filter Removal (CIFR)"
description = (
    "This is the demo for CIFR, contrastive strategy for filter removal "
    "on fashionable images on Instagram. "
    "To use it, simply upload your filtered image, "
    "or click one of the examples to load them."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2204.07486'>Contrastive Instagram Filter Removal (CIFR)</a>"
    " | "
    "<a href='https://github.com/birdortyedi/cifr-pytorch'>Github Repo</a>"
    "</p>"
)
# Example inputs shown under the demo, in display order. Each stem is an index
# plus (apparently) the name of the Instagram filter applied to that image.
_EXAMPLE_STEMS = (
    "98_He-Fe", "2_Brannan", "12_Toaster", "18_Gingham", "11_Sutro",
    "9_Lo-Fi", "3_Mayfair", "4_Hudson", "5_Amaro", "6_1977",
    "8_Valencia", "16_Lo-Fi", "10_Nashville", "15_X-ProII", "14_Willow",
    "30_Perpetua", "1_Clarendon",
)

# NOTE(review): gr.inputs / gr.outputs and a boolean `allow_flagging` are
# deprecated in Gradio 3.x, and gr.inputs / gr.outputs were removed in
# Gradio 4 — a likely cause of the Space's runtime error. Confirm the pinned
# gradio version before migrating to gr.Image / allow_flagging="never".
demo = gr.Interface(
    fn=filter_removal,
    inputs=gr.inputs.Image(shape=(256, 256)),
    outputs=gr.outputs.Image(),
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    # Show every example on a single page.
    examples_per_page=len(_EXAMPLE_STEMS),
    examples=[["images/examples/" + stem + ".jpg"] for stem in _EXAMPLE_STEMS],
)
demo.launch()