bala1802 commited on
Commit
46985f9
1 Parent(s): 5b3113b

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +2 -2
prediction.py CHANGED
@@ -10,9 +10,9 @@ def predict(prompt, pipe, loss_function=None):
10
  latents = []
11
 
12
  for seed_number, sd_concept in zip(config.SEEDS, config.STABLE_DIFUSION_CONCEPTS):
13
- torch.mps.empty_cache()
14
  gc.collect()
15
- torch.mps.empty_cache()
16
 
17
  prompt = [f'{prompt} {sd_concept}']
18
  latent = generator.generate_images(pipe=pipe, seed_number=seed_number, prompt=prompt, loss_function=loss_function)
 
10
  latents = []
11
 
12
  for seed_number, sd_concept in zip(config.SEEDS, config.STABLE_DIFUSION_CONCEPTS):
13
+ torch.cuda.empty_cache()
14
  gc.collect()
15
+ torch.cuda.empty_cache()
16
 
17
  prompt = [f'{prompt} {sd_concept}']
18
  latent = generator.generate_images(pipe=pipe, seed_number=seed_number, prompt=prompt, loss_function=loss_function)