Ahsen Khaliq committed
Commit 408fc5d • 1 Parent(s): 821b58d

Update app.py

Files changed (1):
  app.py  +18 -14
app.py CHANGED
@@ -6,7 +6,7 @@ import math
 from pathlib import Path
 import sys
 sys.path.insert(1, './taming-transformers')
-#from IPython import display
+# from IPython import display
 from base64 import b64encode
 from omegaconf import OmegaConf
 from PIL import Image
@@ -25,6 +25,7 @@ from PIL import ImageFile, Image
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 import gradio as gr
 torch.hub.download_url_to_file('https://i.imgur.com/WEHmKef.jpg', 'gpu.jpg')
+
 def sinc(x):
     return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
 def lanczos(x, a):
@@ -189,7 +190,8 @@ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 print('Using device:', device)
 model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
 perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
-def inference(text, seed, step_size, max_iterations):
+def inference(text, seed, step_size, max_iterations, width, height):
+    size=[width, height]
     texts = text
     target_images = ""
     max_iterations = max_iterations
@@ -221,7 +223,7 @@ def inference(text, seed, step_size, max_iterations):
     cut_size = perceptor.visual.input_resolution
     f = 2**(model.decoder.num_resolutions - 1)
     make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
-    toksX, toksY = args.size[0] // f, args.size[1] // f
+    toksX, toksY = size[0] // f, size[1] // f
     sideX, sideY = toksX * f, toksY * f
     if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':
         e_dim = 256
@@ -237,11 +239,11 @@ def inference(text, seed, step_size, max_iterations):
     # z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
     # normalize_imagenet = transforms.Normalize(mean=[0.485, 0.456, 0.406],
     #                                           std=[0.229, 0.224, 0.225])
-    if args.init_image:
-        if 'http' in args.init_image:
-            img = Image.open(urlopen(args.init_image))
+    if init_image:
+        if 'http' in init_image:
+            img = Image.open(urlopen(init_image))
         else:
-            img = Image.open(args.init_image)
+            img = Image.open(init_image)
         pil_image = img.convert('RGB')
         pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
         pil_tensor = TF.to_tensor(pil_image)
@@ -288,8 +290,8 @@ def inference(text, seed, step_size, max_iterations):
         losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
         tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
         out = synth(z)
-        #TF.to_pil_image(out[0].cpu()).save('progress.png')
-        #display.display(display.Image('progress.png'))
+        # TF.to_pil_image(out[0].cpu()).save('progress.png')
+        # display.display(display.Image('progress.png'))
     def ascend_txt():
         # global i
         out = synth(z)
@@ -335,7 +337,7 @@ def load_image( infilename ) :
     img.load()
     data = np.asarray( img, dtype="int32" )
     return data
-def throttled_inference(text, seed, step_size, max_iterations):
+def throttled_inference(text, seed, step_size, max_iterations, width, height):
     global inferences_running
     current = inferences_running
     if current >= 2:
@@ -344,7 +346,7 @@ def throttled_inference(text, seed, step_size, max_iterations):
     print(f"Inference starting when we already had {current} running")
     inferences_running += 1
     try:
-        return inference(text, seed, step_size, max_iterations)
+        return inference(text, seed, step_size, max_iterations, width, height)
     finally:
         print("Inference finished")
         inferences_running -= 1
@@ -357,14 +359,16 @@ gr.Interface(
         gr.inputs.Number(default=42, label="seed"),
         gr.inputs.Slider(minimum=0.1, maximum=0.9, default=0.23, label='step size'),
         gr.inputs.Slider(minimum=100, maximum=150, default=100, label='max iterations', step=1),
+        gr.inputs.Slider(minimum=200, maximum=280, default=256, label='width', step=1),
+        gr.inputs.Slider(minimum=200, maximum=280, default=256, label='height', step=1),
     ],
     gr.outputs.Image(type="numpy", label="Output"),
     title=title,
     description=description,
     article=article,
     examples=[
-        ['a garden by james gurney',42,0.23, 100],
-        ['coral reef city artstationHQ',1000,0.6, 110],
-        ['a cabin in the mountains unreal engine',98,0.3, 120]
+        ['a garden by james gurney',42,0.16, 100, 256, 256],
+        ['coral reef city artstationHQ',1000,0.6, 110, 200, 200],
+        ['a cabin in the mountains unreal engine',98,0.3, 120, 280, 280]
     ]
 ).launch(debug=True)
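
For context on the new width and height sliders: the diff floors each requested dimension to the VQGAN token grid, so the rendered image size is the nearest lower multiple of the model's downsampling factor f (the toksX/toksY and sideX/sideY lines above). A minimal, hypothetical sketch of that sizing logic, not part of the commit; the helper name snapped_size and the num_resolutions default are illustrative assumptions:

# Illustrative sketch only (not from app.py): how width/height map to the
# latent token grid for an f16 VQGAN (num_resolutions=5, so f = 2**4 = 16).
def snapped_size(width, height, num_resolutions=5):
    f = 2 ** (num_resolutions - 1)          # VQGAN downsampling factor
    toksX, toksY = width // f, height // f  # latent grid size in tokens
    sideX, sideY = toksX * f, toksY * f     # actual pixel size, a multiple of f
    return (toksX, toksY), (sideX, sideY)

# Example: the 280x280 example above renders at 272x272 (a 17x17 token grid).
print(snapped_size(280, 280))  # -> ((17, 17), (272, 272))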