johnowhitaker commited on
Commit
fee9d2e
·
1 Parent(s): 32a5a14

Added light and dark models

Browse files
Files changed (1) hide show
  1. app.py +17 -3
app.py CHANGED
@@ -465,12 +465,26 @@ class GeneratorWithPyTorchModelHubMixin(gan_new.__class__, PyTorchModelHubMixin)
465
  gan_new.__class__ = GeneratorWithPyTorchModelHubMixin
466
  gan_new = gan_new.from_pretrained('johnowhitaker/orbgan_e1', latent_dim=256, image_size=256, attn_res_layers = [32])
467
 
468
- def gen_ims(n_rows):
469
- ims = gan_new(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  grid = torchvision.utils.make_grid(ims, nrow=int(n_rows)).permute(1, 2, 0).detach().cpu().numpy()
471
  return (grid*255).astype(np.uint8)
472
  iface = gr.Interface(fn=gen_ims,
473
- inputs=[gr.inputs.Slider(minimum=1, maximum=6, step=1, default=3,label="N rows")],
 
474
  outputs=[gr.outputs.Image(type="numpy", label="Generated Images")],
475
  title='Demo for https://huggingface.co/johnowhitaker/orbgan_e1'
476
  )
 
465
  gan_new.__class__ = GeneratorWithPyTorchModelHubMixin
466
  gan_new = gan_new.from_pretrained('johnowhitaker/orbgan_e1', latent_dim=256, image_size=256, attn_res_layers = [32])
467
 
468
+ gan_light = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])
469
+ gan_light.__class__ = GeneratorWithPyTorchModelHubMixin
470
+ gan_light = gan_light.from_pretrained('johnowhitaker/orbgan_ligt', latent_dim=256, image_size=256, attn_res_layers = [32])
471
+
472
+ gan_dark = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])
473
+ gan_dark.__class__ = GeneratorWithPyTorchModelHubMixin
474
+ gan_dark = gan_dark.from_pretrained('johnowhitaker/orbgan_ligt', latent_dim=256, image_size=256, attn_res_layers = [32])
475
+
476
+ def gen_ims(n_rows, model='both'):
477
+ if model == "both":
478
+ ims = gan_new(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
479
+ if model == "light":
480
+ ims = gan_light(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
481
+ if model == "dark":
482
+ ims = gan_dark(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
483
  grid = torchvision.utils.make_grid(ims, nrow=int(n_rows)).permute(1, 2, 0).detach().cpu().numpy()
484
  return (grid*255).astype(np.uint8)
485
  iface = gr.Interface(fn=gen_ims,
486
+ inputs=[gr.inputs.Slider(minimum=1, maximum=6, step=1, default=3,label="N rows"),
487
+ gradio.inputs.Dropdown(["both", "light", "dark"], type="value", default="dark", label="Orb Type (model)", optional=False)],
488
  outputs=[gr.outputs.Image(type="numpy", label="Generated Images")],
489
  title='Demo for https://huggingface.co/johnowhitaker/orbgan_e1'
490
  )