Ahsen Khaliq committed on
Commit 2b08e86
1 Parent(s): 28b55ac

Update app.py

Files changed (1):
  1. app.py +65 -61
app.py CHANGED
@@ -1,8 +1,6 @@
 import os
-os.system("pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html")
 os.system("git clone https://github.com/openai/CLIP")
 os.system("pip install -e ./CLIP")
-os.system("pip install einops ninja scipy numpy Pillow tqdm imageio-ffmpeg imageio")
 import sys
 sys.path.append('./CLIP')
 import io
@@ -105,65 +103,71 @@ zs = torch.randn([10000, G.mapping.z_dim], device=device)
 w_stds = G.mapping(zs, None).std(0)
 
 
-def inference(text,steps,image):
-  all_frames = []
-  target = clip_model.embed_text(text)
-  if image:
-    target = embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
-  else:
+def inference(text,steps,image,mode):
+  if mode == "CLIP+StyleGAN3":
+    all_frames = []
     target = clip_model.embed_text(text)
-  steps = steps
-  #seed = 2
-  seed = -1
-  if seed == -1:
-    seed = np.random.randint(0,2**32 - 1)
-  tf = Compose([
-    Resize(224),
-    lambda x: torch.clamp((x+1)/2,min=0,max=1),
-  ])
-  torch.manual_seed(seed)
-  timestring = time.strftime('%Y%m%d%H%M%S')
-  with torch.no_grad():
-    qs = []
-    losses = []
-    for _ in range(8):
-      q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
-      images = G.synthesis(q * w_stds + G.mapping.w_avg)
-      embeds = embed_image(images.add(1).div(2))
-      loss = spherical_dist_loss(embeds, target).mean(0)
-      i = torch.argmin(loss)
-      qs.append(q[i])
-      losses.append(loss[i])
-    qs = torch.stack(qs)
-    losses = torch.stack(losses)
-    print(losses)
-    print(losses.shape, qs.shape)
-    i = torch.argmin(losses)
-  q = qs[i].unsqueeze(0)
-  q.requires_grad_()
-  q_ema = q
-  opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
-  loop = tqdm(range(steps))
-  for i in loop:
-    opt.zero_grad()
-    w = q * w_stds
-    image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
-    embed = embed_image(image.add(1).div(2))
-    loss = spherical_dist_loss(embed, target).mean()
-    loss.backward()
-    opt.step()
-    loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
-    q_ema = q_ema * 0.9 + q * 0.1
-    image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
-    pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))
-    all_frames.append(pil_image)
-    #os.makedirs(f'samples/{timestring}', exist_ok=True)
-    #pil_image.save(f'samples/{timestring}/{i:04}.jpg')
-  writer = imageio.get_writer('test.mp4', fps=15)
-  for im in all_frames:
-    writer.append_data(np.array(im))
-  writer.close()
-  return pil_image, "test.mp4"
+    if image:
+      target = embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
+    else:
+      target = clip_model.embed_text(text)
+    steps = steps
+    #seed = 2
+    seed = -1
+    if seed == -1:
+      seed = np.random.randint(0,2**32 - 1)
+    tf = Compose([
+      Resize(224),
+      lambda x: torch.clamp((x+1)/2,min=0,max=1),
+    ])
+    torch.manual_seed(seed)
+    timestring = time.strftime('%Y%m%d%H%M%S')
+    with torch.no_grad():
+      qs = []
+      losses = []
+      for _ in range(8):
+        q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
+        images = G.synthesis(q * w_stds + G.mapping.w_avg)
+        embeds = embed_image(images.add(1).div(2))
+        loss = spherical_dist_loss(embeds, target).mean(0)
+        i = torch.argmin(loss)
+        qs.append(q[i])
+        losses.append(loss[i])
+      qs = torch.stack(qs)
+      losses = torch.stack(losses)
+      print(losses)
+      print(losses.shape, qs.shape)
+      i = torch.argmin(losses)
+    q = qs[i].unsqueeze(0)
+    q.requires_grad_()
+    q_ema = q
+    opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
+    loop = tqdm(range(steps))
+    for i in loop:
+      opt.zero_grad()
+      w = q * w_stds
+      image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
+      embed = embed_image(image.add(1).div(2))
+      loss = spherical_dist_loss(embed, target).mean()
+      loss.backward()
+      opt.step()
+      loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
+      q_ema = q_ema * 0.9 + q * 0.1
+      image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
+      pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))
+      all_frames.append(pil_image)
+      #os.makedirs(f'samples/{timestring}', exist_ok=True)
+      #pil_image.save(f'samples/{timestring}/{i:04}.jpg')
+    writer = imageio.get_writer('test.mp4', fps=15)
+    for im in all_frames:
+      writer.append_data(np.array(im))
+    writer.close()
+    return pil_image, "test.mp4"
+  else:
+    os.system("python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \
+      --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl")
+    img = Image.new("RGB", (800, 1280), (255, 255, 255))
+    return img, "lerp.mp4"
 
 
 title = "StyleGAN3+CLIP"
@@ -172,7 +176,7 @@ article = "<p style='text-align: center'><a href='https://colab.research.google.
 examples = [['mario',150,None]]
 gr.Interface(
 inference,
-["text",gr.inputs.Slider(minimum=50, maximum=200, step=1, default=150, label="steps"),gr.inputs.Image(type="pil", label="Image (Optional)", optional=True)],
+["text",gr.inputs.Slider(minimum=50, maximum=200, step=1, default=150, label="steps"),gr.inputs.Image(type="pil", label="Image (Optional)", optional=True),gr.inputs.Radio(choices=["CLIP+StyleGAN3","Stylegan3 interpolation"], type="value", default="CLIP+StyleGAN3", label="mode")],
 [gr.outputs.Image(type="pil", label="Output"),"playable_video"],
 title=title,
 description=description,
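
Note: the hunks above call spherical_dist_loss, which is defined elsewhere in app.py and not shown in this diff. As context, here is a minimal self-contained sketch of the loss the optimization loop minimizes, assuming the standard definition from the CLIP-guidance notebooks this Space is based on; the random tensors and the 512-dim size below are illustrative stand-ins, not part of the commit.

import torch
import torch.nn.functional as F

def spherical_dist_loss(x, y):
    # Project both embeddings onto the unit sphere and return the squared
    # geodesic (great-circle) distance between them.
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).div(2).norm(dim=-1).arcsin().pow(2).mul(4)

# Illustrative sanity check with random "CLIP-like" embeddings:
# app.py scores candidate latents the same way and keeps the argmin.
embeds = torch.randn(8, 512)   # stand-in for embed_image(...) outputs
target = torch.randn(512)      # stand-in for clip_model.embed_text(text)
losses = spherical_dist_loss(embeds, target)
print(losses.shape, losses.argmin().item())

With the new mode radio, inference(text, steps, image, mode) only runs this CLIP-guided loop when mode == "CLIP+StyleGAN3"; choosing "Stylegan3 interpolation" bypasses it and shells out to gen_video.py to render lerp.mp4 instead.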