Ahsen Khaliq committed
Commit cde81bb
1 Parent(s): cc017e3

Update app.py

Files changed (1)
  1. app.py +112 -47
app.py CHANGED
@@ -1,50 +1,115 @@
  import os
  os.system('pip install gradio==2.3.0a0')
- os.system('pip freeze')
- os.system('nvidia-smi')
- import torch
  import gradio as gr
- from moviepy.editor import *
-
- model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"
-
- convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
-
- def inference(video):
-     #clip = VideoFileClip(video).subclip(0, 5)
-     #clip.write_videofile("output.mp4")
-     #os.system('ffmpeg -ss 00:00:00 -i '+ video +' -to 00:00:05 -c copy -y output.mp4')
-     clip = VideoFileClip(video)
-     print(clip.duration)
-     if clip.duration > 10:
-         return 'trim.mp4',"trim.mp4","trim.mp4"
-     convert_video(
-         model, # The loaded model, can be on any device (cpu or cuda).
-         input_source=video, # A video file or an image sequence directory.
-         input_resize=(512,512), # [Optional] Resize the input (also the output).
-         downsample_ratio=None, # [Optional] If None, make downsampled max size be 512px.
-         output_type='video', # Choose "video" or "png_sequence"
-         output_composition='com.mp4', # File path if video; directory path if png sequence.
-         output_alpha="pha.mp4", # [Optional] Output the raw alpha prediction.
-         output_foreground="fgr.mp4", # [Optional] Output the raw foreground prediction.
-         output_video_mbps=4, # Output video mbps. Not needed for png sequence.
-         seq_chunk=8, # Process n frames at once for better parallelism.
-         num_workers=1, # Only for image sequence input. Reader threads.
-         progress=True # Print conversion progress.
-     )
-     return 'com.mp4',"pha.mp4","fgr.mp4"
-
- title = "Robust Video Matting"
- description = "Gradio demo for Robust Video Matting. To use it, simply upload your video, currently only mp4 and ogg formats are supported. Please trim video to 10 seconds or less. Read more at the links below."
-
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.11515'>Robust High-Resolution Video Matting with Temporal Guidance</a> | <a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a></p>"
-
- gr.Interface(
-     inference,
-     gr.inputs.Video(label="Input"),
-     [gr.outputs.Video(label="Output Composition"),gr.outputs.Video(label="Output Alpha"),gr.outputs.Video(label="Output Foreground")],
-     title=title,
-     description=description,
-     article=article,
-     enable_queue=True).launch(debug=True)
-
  import os
+ import sys
  os.system('pip install gradio==2.3.0a0')
  import gradio as gr
+ os.system('git clone https://github.com/openai/CLIP')
+ os.system('git clone https://github.com/openai/guided-diffusion')
+ os.system('pip install -e ./CLIP')
+ os.system('pip install -e ./guided-diffusion')
+ os.system('pip install kornia')
+ os.system("curl -OL 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'")
+ # Imports
+ import math
+ import sys
+ #from IPython import display
+ from kornia import augmentation, filters
+ from PIL import Image
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ from tqdm.notebook import tqdm
+ sys.path.append('./CLIP')
+ sys.path.append('./guided-diffusion')
+ import clip
+ from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
+ # Model settings
+ model_config = model_and_diffusion_defaults()
+ model_config.update({
+     'attention_resolutions': '32, 16, 8',
+     'class_cond': False,
+     'diffusion_steps': 1000,
+     'rescale_timesteps': False,
+     'timestep_respacing': '500',
+     'image_size': 256,
+     'learn_sigma': True,
+     'noise_schedule': 'linear',
+     'num_channels': 256,
+     'num_head_channels': 64,
+     'num_res_blocks': 2,
+     'resblock_updown': True,
+     'use_fp16': True,
+     'use_scale_shift_norm': True,
+ })
+ # Load models and define necessary functions
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+ print('Using device:', device)
+ model, diffusion = create_model_and_diffusion(**model_config)
+ model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
+ model.eval().requires_grad_(False).to(device)
+ if model_config['use_fp16']:
+     model.convert_to_fp16()
+ clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
+ clip_size = clip_model.visual.input_resolution
+ normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                  std=[0.26862954, 0.26130258, 0.27577711])
+ def spherical_dist_loss(x, y):
+     x = F.normalize(x, dim=-1)
+     y = F.normalize(y, dim=-1)
+     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+ def inference(text):
+     prompt = text
+     batch_size = 1
+     clip_guidance_scale = 2750
+     seed = 0
+
+     if seed is not None:
+         torch.manual_seed(seed)
+
+     text_embed = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()
+
+     translate_by = 8 / clip_size
+     if translate_by:
+         aug = augmentation.RandomAffine(0, (translate_by, translate_by),
+                                         padding_mode='border', p=1)
+     else:
+         aug = nn.Identity()
+
+     cur_t = diffusion.num_timesteps - 1
+
+     def cond_fn(x, t, y=None):
+         with torch.enable_grad():
+             x_in = x.detach().requires_grad_()
+             sigma = min(24, diffusion.sqrt_recipm1_alphas_cumprod[cur_t] / 4)
+             kernel_size = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3)
+             x_blur = filters.gaussian_blur2d(x_in, (kernel_size, kernel_size), (sigma, sigma))
+             clip_in = F.interpolate(aug(x_blur.add(1).div(2)), (clip_size, clip_size),
+                                     mode='bilinear', align_corners=False)
+             image_embed = clip_model.encode_image(normalize(clip_in)).float()
+             losses = spherical_dist_loss(image_embed, text_embed)
+             grad = -torch.autograd.grad(losses.sum(), x_in)[0]
+             return grad * clip_guidance_scale
+
+     samples = diffusion.p_sample_loop_progressive(
+         model,
+         (batch_size, 3, model_config['image_size'], model_config['image_size']),
+         clip_denoised=True,
+         model_kwargs={},
+         cond_fn=cond_fn,
+         progress=True,
+     )
+
+     for i, sample in enumerate(samples):
+         cur_t -= 1
+         if i % 100 == 0 or cur_t == -1:
+             print()
+             for j, image in enumerate(sample['pred_xstart']):
+                 filename = f'progress_{j:05}.png'
+                 TF.to_pil_image(image.add(1).div(2).clamp(0, 1)).save(filename)
+                 tqdm.write(f'Step {i}, output {j}:')
+                 #display.display(display.Image(filename))
+     return 'progress_00000.png'
+ iface = gr.Interface(inference, inputs="text", outputs="image")
+ iface.launch()
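
For context, the new app.py steers the unconditional diffusion sampler with CLIP by backpropagating spherical_dist_loss between the image and text embeddings inside cond_fn. The sketch below isolates that loss function on its own; it is not part of the commit, and the toy tensors and printed values are illustrative only.

# Illustrative sketch, not part of the commit: the same spherical_dist_loss applied to toy
# 2-D vectors standing in for CLIP embeddings. The loss equals
# 2 * arcsin(||x_hat - y_hat|| / 2) ** 2, i.e. half the squared great-circle angle
# between the normalized embeddings.
import torch
import torch.nn.functional as F

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)   # project both embeddings onto the unit hypersphere
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

a = torch.tensor([[1.0, 0.0]])
b = torch.tensor([[0.0, 1.0]])    # orthogonal to a: angle pi/2
c = torch.tensor([[-1.0, 0.0]])   # antipodal to a:  angle pi
print(spherical_dist_loss(a, a))  # tensor([0.])  identical directions
print(spherical_dist_loss(a, b))  # ~1.2337 = (pi/2)**2 / 2
print(spherical_dist_loss(a, c))  # ~4.9348 = pi**2 / 2

Because the loss keeps growing with the angle between the embeddings, cond_fn can take its gradient with respect to the noisy image, scale it by clip_guidance_scale, and nudge every denoising step toward the text prompt.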