jcmc committed
Commit d61193c
1 Parent(s): 9c80c48

Adjusting inference params

Files changed (2)
  1. app.py +31 -27
  2. requirements.txt +2 -1
app.py CHANGED
@@ -34,32 +34,14 @@ from utils.sr_utils import *
 
 device = torch.device('cuda')
 
-# torch.hub.download_url_to_file('https://images.pexels.com/photos/68767/divers-underwater-ocean-swim-68767.jpeg', 'coralreef.jpeg')
-
-# def fetch(url_or_path):
-#     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
-#         r = requests.get(url_or_path)
-#         r.raise_for_status()
-#         fd = io.BytesIO()
-#         fd.write(r.content)
-#         fd.seek(0)
-#         return fd
-#     return open(url_or_path, 'rb')
-
-# def parse_prompt(prompt):
-#     if prompt.startswith('http://') or prompt.startswith('https://'):
-#         vals = prompt.rsplit(':', 2)
-#         vals = [vals[0] + ':' + vals[1], *vals[2:]]
-#     else:
-#         vals = prompt.rsplit(':', 1)
-#     vals = vals + ['', '1'][len(vals):]
-#     return vals[0], float(vals[1])
 clip_model_vit_b_32 = clip.load('ViT-B/32', device=device)[0].eval().requires_grad_(False)
 clip_model_vit_b_16 = clip.load('ViT-B/16', device=device)[0].eval().requires_grad_(False)
 clip_models = {'ViT-B/32': clip_model_vit_b_32, 'ViT-B/16': clip_model_vit_b_16}
 
 clip_normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
-
+clip_model = 'ViT-B/16'
+sideX, sideY = 512, 512 # Resolution
+inv_color_scale = 1.6
 
 class MakeCutouts(torch.nn.Module):
     def __init__(self, cut_size, cutn):
@@ -149,7 +131,18 @@ class CLIPActivationLoss(nn.Module):
         return -loss if self.maximize else loss
 
 
-def optimize_network(seed, num_iterations, optimizer_type, lr):
+def optimize_network(
+    seed,
+    opt_type,
+    lr,
+    num_iterations,
+    cutn,
+    layer,
+    neuron,
+    class_token,
+    maximize,
+    display_rate = 20
+):
     global itt
     itt = 0
 
@@ -183,9 +176,9 @@ def optimize_network(seed, num_iterations, optimizer_type, lr):
     # Initialize input noise
     net_input = torch.zeros([1, input_depth, sideY, sideX], device=device).normal_().div(10).detach()
 
-    if optimizer_type == 'Adam':
+    if opt_type == 'Adam':
         optimizer = torch.optim.Adam(net.parameters(), lr)
-    elif optimizer_type == 'MADGRAD':
+    elif opt_type == 'MADGRAD':
         optimizer = MADGRAD(net.parameters(), lr, momentum=0.9)
     scaler = torch.cuda.amp.GradScaler()
 
@@ -234,7 +227,6 @@ def inference(
     lr,
     num_iterations,
     cutn,
-    clip_model,
     layer,
     neuron,
     class_token,
@@ -249,14 +241,26 @@ def inference(
     # Begin optimization / generation
     gc.collect()
     torch.cuda.empty_cache()
-    out = optimize_network(seed, num_iterations, opt_type, lr)
+    out = optimize_network(
+        seed,
+        opt_type,
+        lr,
+        num_iterations,
+        cutn,
+        layer,
+        neuron,
+        class_token,
+        maximize,
+        display_rate
+    )
+
     out.save(f'dip_{timestring}.png', quality=100)
     if save_progress_video:
         video_writer.close()
     return out
 
 iface = gr.Interface(fn=inference,
-    inputs=["number", "text", "number", "number", "number", "text", "number", "number",
+    inputs=["number", "text", "number", "number", "number", "number", "number",
         gr.inputs.Checkbox(default=False, label="class_token"),
         gr.inputs.Checkbox(default=True, label="maximise"),
         "number"],
requirements.txt CHANGED
@@ -8,4 +8,5 @@ lpips
 numpy
 imageio
 einops
-madgrad
+madgrad
+cv2
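
One caveat on the new dependency: cv2 is the import name of OpenCV's Python bindings, while the PyPI package is normally published as opencv-python (or opencv-python-headless for server-side use), so a bare cv2 requirement may not resolve at build time. A sketch of the tail of requirements.txt under that assumption:

einops
madgrad
opencv-python-headless  # provides the cv2 module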