Update app.py
Browse files
app.py
CHANGED
|
@@ -46,10 +46,11 @@ class ModelWrapper:
|
|
| 46 |
self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
| 47 |
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
| 48 |
self.num_step = num_step
|
| 49 |
-
|
|
|
|
| 50 |
def create_generator(self, model_id, checkpoint_path):
|
| 51 |
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
| 52 |
-
state_dict = torch.load(checkpoint_path, map_location="
|
| 53 |
generator.load_state_dict(state_dict, strict=True)
|
| 54 |
generator.requires_grad_(False)
|
| 55 |
return generator
|
|
@@ -108,7 +109,7 @@ class ModelWrapper:
|
|
| 108 |
eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
|
| 109 |
return eval_images
|
| 110 |
|
| 111 |
-
|
| 112 |
@torch.no_grad()
|
| 113 |
def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
|
| 114 |
print("Running model inference...")
|
|
@@ -196,9 +197,6 @@ def create_demo():
|
|
| 196 |
num_step = 4
|
| 197 |
revision = None
|
| 198 |
|
| 199 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 200 |
-
torch.backends.cudnn.allow_tf32 = True
|
| 201 |
-
|
| 202 |
accelerator = Accelerator()
|
| 203 |
|
| 204 |
model = ModelWrapper(model_id, checkpoint_path, precision, image_resolution, latent_resolution, num_train_timesteps, conditioning_timestep, num_step, revision, accelerator)
|
|
@@ -211,10 +209,10 @@ def create_demo():
|
|
| 211 |
run_button = gr.Button("Run")
|
| 212 |
with gr.Accordion(label="Advanced options", open=True):
|
| 213 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=1000000, step=1, value=0)
|
| 214 |
-
num_images = gr.Slider(label="Number of generated images", minimum=1, maximum=16, step=1, value=
|
| 215 |
fast_vae_decode = gr.Checkbox(label="Use Tiny VAE for faster decoding", value=True)
|
| 216 |
-
height = gr.Slider(label="Image Height", minimum=512, maximum=1536, step=64, value=
|
| 217 |
-
width = gr.Slider(label="Image Width", minimum=512, maximum=1536, step=64, value=
|
| 218 |
with gr.Column():
|
| 219 |
result = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=1024)
|
| 220 |
error_message = gr.Text(label="Job Status")
|
|
|
|
| 46 |
self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
| 47 |
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
| 48 |
self.num_step = num_step
|
| 49 |
+
|
| 50 |
+
@spaces.GPU(enable_queue=True)
|
| 51 |
def create_generator(self, model_id, checkpoint_path):
|
| 52 |
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
| 53 |
+
state_dict = torch.load(checkpoint_path, map_location="cuda")
|
| 54 |
generator.load_state_dict(state_dict, strict=True)
|
| 55 |
generator.requires_grad_(False)
|
| 56 |
return generator
|
|
|
|
| 109 |
eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
|
| 110 |
return eval_images
|
| 111 |
|
| 112 |
+
|
| 113 |
@torch.no_grad()
|
| 114 |
def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
|
| 115 |
print("Running model inference...")
|
|
|
|
| 197 |
num_step = 4
|
| 198 |
revision = None
|
| 199 |
|
|
|
|
|
|
|
|
|
|
| 200 |
accelerator = Accelerator()
|
| 201 |
|
| 202 |
model = ModelWrapper(model_id, checkpoint_path, precision, image_resolution, latent_resolution, num_train_timesteps, conditioning_timestep, num_step, revision, accelerator)
|
|
|
|
| 209 |
run_button = gr.Button("Run")
|
| 210 |
with gr.Accordion(label="Advanced options", open=True):
|
| 211 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=1000000, step=1, value=0)
|
| 212 |
+
num_images = gr.Slider(label="Number of generated images", minimum=1, maximum=16, step=1, value=1)
|
| 213 |
fast_vae_decode = gr.Checkbox(label="Use Tiny VAE for faster decoding", value=True)
|
| 214 |
+
height = gr.Slider(label="Image Height", minimum=512, maximum=1536, step=64, value=512)
|
| 215 |
+
width = gr.Slider(label="Image Width", minimum=512, maximum=1536, step=64, value=512)
|
| 216 |
with gr.Column():
|
| 217 |
result = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=1024)
|
| 218 |
error_message = gr.Text(label="Job Status")
|