Tony Lian committed
Commit 1aae4d6
1 Parent(s): 4875c08

Empty GPU cache at the end of run

Files changed (2):
  1. baseline.py +4 -0
  2. generation.py +4 -0
baseline.py CHANGED
@@ -4,6 +4,7 @@ import torch
 import models
 from models import pipelines
 from shared import model_dict, DEFAULT_OVERALL_NEGATIVE_PROMPT
+import gc
 
 vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
 
@@ -36,5 +37,8 @@ def run(prompt, scheduler_key='dpm_scheduler', bg_seed=1, num_inference_steps=20
         model_dict, latents, input_embeddings, num_inference_steps,
         guidance_scale=guidance_scale, scheduler_key=scheduler_key
     )
+
+    gc.collect()
+    torch.cuda.empty_cache()
 
     return images[0]
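For review context, here is the pattern this commit adds, shown in isolation as a minimal, self-contained sketch. The run_once workload and the tensor sizes below are placeholders standing in for the actual pipeline call; they are not code from this repo.

import gc

import torch

def run_once():
    # Allocate throwaway CUDA tensors, roughly as an inference pass would.
    intermediates = [torch.randn(1024, 1024, device="cuda") for _ in range(8)]
    result = sum(t.sum() for t in intermediates).item()

    # Drop Python references, then return cached blocks to the CUDA driver.
    del intermediates
    gc.collect()
    torch.cuda.empty_cache()
    return result

if torch.cuda.is_available():
    run_once()
    # memory_reserved() reports what the caching allocator still holds from the
    # driver; after empty_cache() it should drop back close to zero.
    print(f"reserved after cleanup: {torch.cuda.memory_reserved()} bytes")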
generation.py CHANGED
@@ -7,6 +7,7 @@ import utils
 from models import pipelines, sam
 from utils import parse, latents
 from shared import model_dict, sam_model_dict, DEFAULT_SO_NEGATIVE_PROMPT, DEFAULT_OVERALL_NEGATIVE_PROMPT
+import gc
 
 verbose = False
 
@@ -194,5 +195,8 @@ def run(
 
     # display(Image.fromarray(images[0]), "img", run_ind)
 
+    gc.collect()
+    torch.cuda.empty_cache()
+
     return images[0], so_img_list
 
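As a design note on the two calls: gc.collect() drops lingering Python references to intermediate tensors so their blocks return to PyTorch's caching allocator, and torch.cuda.empty_cache() then releases those cached blocks back to the driver. That does not make later allocations in the same process faster, but it does make the freed memory visible to other processes and to tools such as nvidia-smi, which is presumably why it runs at the end of each run().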