from argparse import Namespace
from glob import glob
import yaml
import os
import spaces  # Hugging Face Spaces SDK (ZeroGPU support)
import gradio as gr
import torch
import torchvision
import safetensors.torch  # `import safetensors` alone does not expose the torch submodule
from diffusers import AutoencoderKL, ConsistencyDecoderVAE
from peft import get_peft_model, LoraConfig, set_peft_model_state_dict
from huggingface_hub import snapshot_download
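
# Fetch the pretrained ReVP LoRA weights and training hyper-parameters from the Hub.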
pretrained_model_path = snapshot_download(repo_id="revp2024/revp-censorship")
with open(glob(os.path.join(pretrained_model_path, 'hparams.yml'), recursive=True)[0]) as f:
    args = Namespace(**yaml.safe_load(f))
with open('examples/add_censorship.yaml') as f:
    add_censor_examples = yaml.safe_load(f)
with open('examples/remove_censorship.yaml') as f:
    remove_censor_examples = yaml.safe_load(f)
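
# Load the Stable Diffusion VAE, wrap it with the ReVP LoRA adapter, and load
# OpenAI's Consistency Decoder as an optional higher-quality latent decoder.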
def prepare_model():
    print('Loading model ...')
    vae_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["conv", "conv1", "conv2",
                        "to_q", "to_k", "to_v", "to_out.0"],
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae"
    )
    vae = get_peft_model(vae, vae_lora_config)
    lora_weights_path = os.path.join(pretrained_model_path, "pytorch_lora_weights.safetensors")
    state_dict = {}
    with safetensors.torch.safe_open(lora_weights_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    set_peft_model_state_dict(vae, state_dict)
    print('Done.')
    cd_vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    vae = vae.to(device)
    cd_vae = cd_vae.to(device)
    return vae, cd_vae

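
# NOTE: on ZeroGPU Spaces hardware the two handlers below would additionally need the
# @spaces.GPU decorator to be granted a GPU; regular GPU/CPU Spaces run them as-is.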
def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size):
    background, layers, _ = input_image.values()
    input_images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
    mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
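    # Cap the longer side at 1024 px and snap both sides to multiples of 8,
    # matching the VAE's downsampling factor.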
    H, W = input_images.shape[-2:]
    if H > 1024 or W > 1024:
        H_t, W_t = H, W
        if H > W:
            H, W = 1024, int(1024 * W_t / H_t)
        else:
            H, W = int(1024 * H_t / W_t), 1024
    H_q8 = (H // 8) * 8
    W_q8 = (W // 8) * 8
    input_images = torch.nn.functional.interpolate(input_images, (H_q8, W_q8), mode='bilinear')
    mask = torch.nn.functional.interpolate(mask, (H_q8, W_q8))
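    # Optionally feather the mask with a Gaussian so the censored patch blends smoothly.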
    if soft_edges:
        mask = torchvision.transforms.functional.gaussian_blur(mask, int(soft_edge_kernel_size))[0][0]
    input_images = input_images.to(vae.device)
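    # Censor the whole frame with the selected operation; the mask blends it in below.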
    if mode == 'Pixelation':
        censored = torch.nn.functional.avg_pool2d(
            input_images, int(pixelation_block_size))
        censored = torch.nn.functional.interpolate(censored, input_images.shape[-2:])
    elif mode == 'Gaussian blur':
        censored = torchvision.transforms.functional.gaussian_blur(
            input_images, int(blur_kernel_size))
    elif mode == 'Black':
        censored = torch.zeros_like(input_images)
    else:
        raise ValueError("mode has to be one of `Pixelation', `Gaussian blur' or `Black'")
    mask = mask.to(input_images.device)
    censored_images = input_images * (1 - mask) + censored * mask
    censored_images *= 255
    input_images = input_images * 2 - 1
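    # Round-trip the original through the frozen (adapter-disabled) VAE and hide the
    # reconstruction residual, clamped to +/- args.budget, inside the censored image;
    # this is the signal the LoRA-adapted encoder inverts later.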
    with vae.disable_adapter():
        latents = vae.encode(input_images).latent_dist.mean
        images = vae.decode(latents, return_dict=False)[0]
    # denormalize
    images = images / 2 + 0.5
    images *= 255
    residuals = (images - censored_images).clamp(-args.budget, args.budget)
    images = (censored_images + residuals).clamp(0, 255).to(torch.uint8)
    gr.Info("Download or copy the censored image to the `Remove censorship' tab")
    return images[0].permute(1, 2, 0).cpu().numpy()

def remove_censorship(input_image, use_cd, x1, y1, x2, y2):
    background, layers, _ = input_image.values()
    images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
    mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
    images = images * (1 - mask)
    # gr.Number yields floats; cast before slicing
    images = images[..., int(y1):int(y2), int(x1):int(x2)]
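    # The LoRA-adapted encoder recovers the original latents from the residual-carrying
    # image; decoding then uses either the Consistency Decoder or the frozen VAE.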
    latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
    if use_cd:
        images = cd_vae.decode(latents.to(cd_vae.dtype), return_dict=False)[0]
    else:
        with vae.disable_adapter():
            images = vae.decode(latents, return_dict=False)[0]
    # denormalize
    images = images / 2 + 0.5
    images *= 255
    images = images.clamp(0, 255).to(torch.uint8)
    return images[0].permute(1, 2, 0).cpu().numpy()

# @@@@@@@ Start of the program @@@@@@@@
vae, cd_vae = prepare_model()
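
# Grey out the mirrored, read-only coordinate fields in the manual-cropping grid.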
css = '''
.my-disabled {
    background-color: #eee;
}
.my-disabled input {
    background-color: #eee;
}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown('# ReVP: Reversible Visual Processing with Latent Models')
    gr.Markdown('### Check out our project page for more info: https://revp2024.github.io')
    with gr.Tab('Add censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor(brush=gr.Brush(default_size=100))
                with gr.Accordion('Options', open=False) as options_accord:
                    mode = gr.Radio(label='Mode', choices=['Pixelation', 'Gaussian blur', 'Black'],
                                    value='Pixelation', interactive=True)
                    pixelation_block_size = gr.Slider(label='Block size', minimum=10, maximum=40, value=25, step=1, interactive=True)
                    blur_kernel_size = gr.Slider(label='Blur kernel size', minimum=21, maximum=151, value=85, step=2, interactive=True, visible=False)

                    # Show only the slider that belongs to the selected censoring mode.
                    def change_mode(mode):
                        if mode == 'Gaussian blur':
                            return gr.Slider(visible=False), gr.Slider(visible=True), gr.Accordion(open=True)
                        elif mode == 'Pixelation':
                            return gr.Slider(visible=True), gr.Slider(visible=False), gr.Accordion(open=True)
                        elif mode == 'Black':
                            return gr.Slider(visible=False), gr.Slider(visible=False), gr.Accordion(open=True)
                        else:
                            raise NotImplementedError
                    mode.select(change_mode, mode, [pixelation_block_size, blur_kernel_size, options_accord])

                    with gr.Row(variant='panel'):
                        soft_edges = gr.Checkbox(label='Soft edges', value=True, interactive=True, scale=1)
                        soft_edge_kernel_size = gr.Slider(label='Soft edge kernel size', minimum=21, maximum=49, value=35, step=2, interactive=True, visible=True, scale=2)

                    def change_soft_edges(soft_edges):
                        return gr.Slider(visible=soft_edges), gr.Accordion(open=True)
                    soft_edges.change(change_soft_edges, soft_edges, [soft_edge_kernel_size, options_accord])
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Censored', show_download_button=True)
        submit_btn.click(
            fn=add_censorship,
            inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size],
            outputs=output_image
        )
        gr.Examples(
            examples=add_censor_examples,
            fn=add_censorship,
            inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size],
            outputs=output_image,
            cache_examples=False,
        )
    with gr.Tab('Remove censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor()
                use_cd = gr.Checkbox(label='Use Consistency Decoder (slower)')
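                # 2x2 grid showing the crop corners; only (x1, y1) and (x2, y2) are
                # editable, their greyed-out twins mirror them for readability.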
                with gr.Accordion('Manual cropping', open=False):
                    with gr.Row():
                        with gr.Row():
                            x1 = gr.Number(value=0, label='x1')
                            y1 = gr.Number(value=0, label='y1')
                        with gr.Row():
                            x2_ = gr.Number(value=10000, label='x2', interactive=False, elem_classes='my-disabled')
                            y1_ = gr.Number(value=0, label='y1', interactive=False, elem_classes='my-disabled')
                    with gr.Row():
                        with gr.Row():
                            x1_ = gr.Number(value=0, label='x1', interactive=False, elem_classes='my-disabled')
                            y2_ = gr.Number(value=10000, label='y2', interactive=False, elem_classes='my-disabled')
                        with gr.Row():
                            x2 = gr.Number(value=10000, label='x2')
                            y2 = gr.Number(value=10000, label='y2')
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Uncensored')
        submit_btn.click(
            fn=remove_censorship,
            inputs=[input_image, use_cd, x1, y1, x2, y2],
            outputs=output_image
        )
        gr.Examples(
            examples=remove_censor_examples,
            fn=remove_censorship,
            inputs=[input_image, use_cd, x1, y1, x2, y2],
            outputs=output_image,
            cache_examples=False,
        )
        # Mirror the editable coordinates into their read-only twins on change.
        x1.change(lambda x: x, x1, x1_)
        x2.change(lambda x: x, x2, x2_)
        y1.change(lambda x: x, y1, y1_)
        y2.change(lambda x: x, y2, y2_)

if __name__ == '__main__':
    demo.queue(default_concurrency_limit=4)
    demo.launch()