Ahsen Khaliq committed cde81bb
Parent(s): cc017e3
Update app.py
app.py CHANGED
@@ -1,50 +1,115 @@
Old app.py (removed lines are marked "-"; old lines 7-50 were also removed, but their content was not captured in the extracted diff):

 import os
 os.system('pip install gradio==2.3.0a0')
-os.system('pip freeze')
-os.system('nvidia-smi')
-import torch
 import gradio as gr
New app.py:

import os
import sys
os.system('pip install gradio==2.3.0a0')
import gradio as gr
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/openai/guided-diffusion')
os.system('pip install -e ./CLIP')
os.system('pip install -e ./guided-diffusion')
os.system('pip install kornia')
os.system("curl -OL 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'")
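# The os.system calls above install Gradio, clone and install OpenAI's CLIP and
# guided-diffusion repos plus kornia, and download the 256x256 unconditional
# diffusion checkpoint used below.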
# Imports
import math
import sys
#from IPython import display
from kornia import augmentation, filters
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
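# The cloned repos are also appended to sys.path so `clip` and `guided_diffusion`
# import cleanly alongside the editable installs.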
# Model settings
model_config = model_and_diffusion_defaults()
model_config.update({
    'attention_resolutions': '32, 16, 8',
    'class_cond': False,
    'diffusion_steps': 1000,
    'rescale_timesteps': False,
    'timestep_respacing': '500',
    'image_size': 256,
    'learn_sigma': True,
    'noise_schedule': 'linear',
    'num_channels': 256,
    'num_head_channels': 64,
    'num_res_blocks': 2,
    'resblock_updown': True,
    'use_fp16': True,
    'use_scale_shift_norm': True,
})
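# These settings match the 256x256 unconditional guided-diffusion checkpoint,
# respaced from 1000 diffusion steps to 500 sampling steps and run in fp16.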
# Load models and define necessary functions
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
model.eval().requires_grad_(False).to(device)
if model_config['use_fp16']:
    model.convert_to_fp16()
clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
clip_size = clip_model.visual.input_resolution
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
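# Both networks stay frozen; `normalize` applies CLIP's image preprocessing
# mean/std so encode_image sees inputs in the expected range.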
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
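# spherical_dist_loss is the squared great-circle (angular) distance between
# L2-normalized embeddings; it serves as the CLIP guidance loss below.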

def inference(text):
    prompt = text
    batch_size = 1
    clip_guidance_scale = 2750
    seed = 0

    if seed is not None:
        torch.manual_seed(seed)

    text_embed = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()

    translate_by = 8 / clip_size
    if translate_by:
        aug = augmentation.RandomAffine(0, (translate_by, translate_by),
                                        padding_mode='border', p=1)
    else:
        aug = nn.Identity()

    cur_t = diffusion.num_timesteps - 1

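    # cond_fn returns the gradient of the negated CLIP loss (scaled by
    # clip_guidance_scale) with respect to the current sample; guided-diffusion
    # adds it at each denoising step to steer generation toward the text prompt.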
    def cond_fn(x, t, y=None):
        with torch.enable_grad():
            x_in = x.detach().requires_grad_()
            sigma = min(24, diffusion.sqrt_recipm1_alphas_cumprod[cur_t] / 4)
            kernel_size = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3)
            x_blur = filters.gaussian_blur2d(x_in, (kernel_size, kernel_size), (sigma, sigma))
            clip_in = F.interpolate(aug(x_blur.add(1).div(2)), (clip_size, clip_size),
                                    mode='bilinear', align_corners=False)
            image_embed = clip_model.encode_image(normalize(clip_in)).float()
            losses = spherical_dist_loss(image_embed, text_embed)
            grad = -torch.autograd.grad(losses.sum(), x_in)[0]
            return grad * clip_guidance_scale

    samples = diffusion.p_sample_loop_progressive(
        model,
        (batch_size, 3, model_config['image_size'], model_config['image_size']),
        clip_denoised=True,
        model_kwargs={},
        cond_fn=cond_fn,
        progress=True,
    )
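
    # p_sample_loop_progressive yields intermediate samples; every 100 steps (and
    # at the final step) the current pred_xstart estimate is saved to disk, and the
    # image for batch element 0 is returned to Gradio.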
    for i, sample in enumerate(samples):
        cur_t -= 1
        if i % 100 == 0 or cur_t == -1:
            print()
            for j, image in enumerate(sample['pred_xstart']):
                filename = f'progress_{j:05}.png'
                TF.to_pil_image(image.add(1).div(2).clamp(0, 1)).save(filename)
                tqdm.write(f'Step {i}, output {j}:')
                #display.display(display.Image(filename))
    return 'progress_00000.png'
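# Expose inference as a Gradio app: a text prompt in, the generated image out.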
iface = gr.Interface(inference, inputs="text", outputs="image")
iface.launch()