"""Gradio demo: sample images from a pretrained progressive-GAN generator.

Generator weights are loaded from ``CHECKPOINT_DIR/g.model`` at import time.
The web UI draws random latent vectors, runs them through the generator and
shows the outputs as a single image grid.
"""

import os

import gradio as gr
import numpy as np
import torch
from torchvision import transforms, utils

from progan_modules import Generator

CHECKPOINT_DIR = "./model"
Z_DIM, CHANNEL_SIZE = 128, 128
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Forwarded unchanged as the generator's `step` and `alpha` arguments.
# NOTE(review): presumably the progressive-growing stage and the fade-in
# blend factor -- confirm against progan_modules.Generator.
FIXED_STEP = 6
FIXED_ALPHA = 0.0

# Images per grid row; the slider below moves in steps of the same size.
_GRID_COLUMNS = 10
# np.random.seed only accepts seeds in [0, 2**32 - 1].
_NUMPY_SEED_MODULUS = 2**32

g_running = Generator(CHANNEL_SIZE, Z_DIM, pixel_norm=False, tanh=False).to(DEVICE)
# NOTE(review): torch.load unpickles the checkpoint file; only load files from
# a trusted source (consider weights_only=True where the installed torch
# version supports it).
g_running.load_state_dict(
    torch.load(os.path.join(CHECKPOINT_DIR, "g.model"), map_location=DEVICE)
)
g_running.eval()

to_pil = transforms.ToPILImage()


@torch.inference_mode()
def sample_images(n_images: int = 50, seed: int | None = None):
    """Generate ``n_images`` samples and return them as one grid image.

    Args:
        n_images: Number of latent vectors to sample. Coerced to ``int`` and
            clamped to at least 1, because UI widgets may deliver floats and
            an empty batch cannot be turned into a grid.
        seed: Non-negative value for reproducible output. ``None`` or a
            negative value reseeds torch non-deterministically instead.

    Returns:
        The grid produced by ``make_grid`` (rows of ``_GRID_COLUMNS`` images,
        values normalized from ``(-1, 1)``), converted with ``to_pil``.
    """
    n_images = max(1, int(n_images))

    if seed is not None and seed >= 0:
        seed = int(seed)
        torch.manual_seed(seed)
        # torch accepts 64-bit seeds but NumPy raises ValueError above
        # 2**32 - 1, so wrap the value for NumPy only.
        np.random.seed(seed % _NUMPY_SEED_MODULUS)
    else:
        torch.seed()

    z = torch.randn(n_images, Z_DIM, device=DEVICE)
    imgs = g_running(z, step=FIXED_STEP, alpha=FIXED_ALPHA).cpu()
    grid = utils.make_grid(
        imgs, nrow=_GRID_COLUMNS, normalize=True, value_range=(-1, 1)
    )
    return to_pil(grid)


demo = gr.Interface(
    fn=sample_images,
    inputs=[
        # Minimum is 10 (not 1) so every slider position, including the
        # default of 50, is a multiple of 10 as the label states.
        gr.Slider(10, 200, value=50, step=10, label="Jumlah Gambar (kelipatan 10)"),
        gr.Number(value=-1, precision=0, label="Seed (‑1 = acak)"),
    ],
    outputs=gr.Image(type="pil", label="Grid Hasil"),
    title="Progressive Growing Generative Adversarial Network",
    description="contoh implementasi PGGAN untuk dataset jerawat",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.queue()
    demo.launch(show_api=False, share=True)