cocktailpeanut committed on
Commit
e95b42e
·
1 Parent(s): d471634
Files changed (2) hide show
  1. app.py +18 -9
  2. requirements.txt +2 -2
app.py CHANGED
@@ -12,8 +12,9 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  import gradio as gr
16
- import spaces
17
 
18
  import argparse
19
  import inspect
@@ -1196,7 +1197,8 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1196
  if use_md_prompt or save_attention_map:
1197
  self.recover_attention_control(ori_attn_processors=ori_attn_processors) # recover attention controller
1198
  del self.controller
1199
- torch.cuda.empty_cache()
 
1200
  else:
1201
  print("### Encoding Real Image ###")
1202
  latents = self.vae.encode(image_lr)
@@ -1206,7 +1208,8 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1206
  anchor_std = latents.std()
1207
  if self.lowvram:
1208
  latents = latents.cpu()
1209
- torch.cuda.empty_cache()
 
1210
  if not output_type == "latent":
1211
  # make sure the VAE is in float32 mode, as it overflows in float16
1212
  needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
@@ -1227,7 +1230,8 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1227
  # cast back to fp16 if needed
1228
  if needs_upcasting:
1229
  self.vae.to(dtype=torch.float16)
1230
- torch.cuda.empty_cache()
 
1231
 
1232
  image = self.image_processor.postprocess(image, output_type=output_type)
1233
  if not os.path.exists(f'{result_path}'):
@@ -1250,7 +1254,8 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1250
  if self.lowvram:
1251
  latents = latents.to(device)
1252
  self.unet.to(device)
1253
- torch.cuda.empty_cache()
 
1254
 
1255
  current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
1256
  current_width = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
@@ -1549,7 +1554,8 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1549
  latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean
1550
  if self.lowvram:
1551
  latents = latents.cpu()
1552
- torch.cuda.empty_cache()
 
1553
  if not output_type == "latent":
1554
  # make sure the VAE is in float32 mode, as it overflows in float16
1555
  needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
@@ -1620,12 +1626,14 @@ if __name__ == "__main__":
1620
 
1621
  args = parser.parse_args()
1622
 
1623
- pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")
 
 
1624
 
1625
 
1626
  # GRADIO MODE
1627
 
1628
- @spaces.GPU()
1629
  def infer(prompt, resolution, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
1630
  set_seed(seed)
1631
  width,height = list(map(int, resolution.split(',')))
@@ -1634,7 +1642,8 @@ if __name__ == "__main__":
1634
  "n_cross_replace": {"default_": 1.0, "confetti": 0.8},
1635
  }
1636
  seed = seed
1637
- generator = torch.Generator(device='cuda')
 
1638
  generator = generator.manual_seed(seed)
1639
 
1640
  print(f"Prompt: {prompt}")
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import devicetorch
16
  import gradio as gr
17
+ #import spaces
18
 
19
  import argparse
20
  import inspect
 
1197
  if use_md_prompt or save_attention_map:
1198
  self.recover_attention_control(ori_attn_processors=ori_attn_processors) # recover attention controller
1199
  del self.controller
1200
+ devicetorch.empty_cache(torch)
1201
+ #torch.cuda.empty_cache()
1202
  else:
1203
  print("### Encoding Real Image ###")
1204
  latents = self.vae.encode(image_lr)
 
1208
  anchor_std = latents.std()
1209
  if self.lowvram:
1210
  latents = latents.cpu()
1211
+ #torch.cuda.empty_cache()
1212
+ devicetorch.empty_cache(torch)
1213
  if not output_type == "latent":
1214
  # make sure the VAE is in float32 mode, as it overflows in float16
1215
  needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
 
1230
  # cast back to fp16 if needed
1231
  if needs_upcasting:
1232
  self.vae.to(dtype=torch.float16)
1233
+ #torch.cuda.empty_cache()
1234
+ devicetorch.empty_cache(torch)
1235
 
1236
  image = self.image_processor.postprocess(image, output_type=output_type)
1237
  if not os.path.exists(f'{result_path}'):
 
1254
  if self.lowvram:
1255
  latents = latents.to(device)
1256
  self.unet.to(device)
1257
+ #torch.cuda.empty_cache()
1258
+ devicetorch.empty_cache(torch)
1259
 
1260
  current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
1261
  current_width = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
 
1554
  latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean
1555
  if self.lowvram:
1556
  latents = latents.cpu()
1557
+ #torch.cuda.empty_cache()
1558
+ devicetorch.empty_cache(torch)
1559
  if not output_type == "latent":
1560
  # make sure the VAE is in float32 mode, as it overflows in float16
1561
  needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
 
1626
 
1627
  args = parser.parse_args()
1628
 
1629
+ #pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")
1630
+ device = devicetorch.get(torch)
1631
+ pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to(device)
1632
 
1633
 
1634
  # GRADIO MODE
1635
 
1636
+ # @spaces.GPU()
1637
  def infer(prompt, resolution, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
1638
  set_seed(seed)
1639
  width,height = list(map(int, resolution.split(',')))
 
1642
  "n_cross_replace": {"default_": 1.0, "confetti": 0.8},
1643
  }
1644
  seed = seed
1645
+ #generator = torch.Generator(device='cuda')
1646
+ generator = torch.Generator(device=device)
1647
  generator = generator.manual_seed(seed)
1648
 
1649
  print(f"Prompt: {prompt}")
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  diffusers~=0.21.4
2
- torch~=2.1.0
3
  scipy~=1.11.3
4
  omegaconf~=2.3.0
5
  accelerate~=0.23.0
@@ -10,4 +10,4 @@ matplotlib
10
  gradio
11
  gradio_imageslider
12
  opencv-python
13
- torchvision
 
1
  diffusers~=0.21.4
2
+ #torch~=2.1.0
3
  scipy~=1.11.3
4
  omegaconf~=2.3.0
5
  accelerate~=0.23.0
 
10
  gradio
11
  gradio_imageslider
12
  opencv-python
13
+ #torchvision