tools / clear_mem.py
patrickvonplaten's picture
rename
27dfa17
raw
history blame contribute delete
665 Bytes
#!/usr/bin/env python3
import torch
import gc
from diffusers import DiffusionPipeline
# Allocate a 30,000 x 30,000 random tensor directly on the CUDA device.
# NOTE(review): nothing in the visible script reads this tensor afterwards;
# presumably it only exists to keep a large chunk of GPU memory occupied
# while the pipelines below are loaded -- confirm.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
shape = 30_000, 30_000
input = torch.randn(*shape, device=torch.device("cuda"))
def clear_memory(model):
    """Move ``model`` to the CPU and release cached accelerator memory.

    The object is only moved, not deleted: the caller's reference stays
    valid and keeps the (now CPU-resident) model alive.

    Args:
        model: Any object exposing ``to(device)``, e.g. a pipeline or a
            ``torch.nn.Module``.

    Returns:
        None.
    """
    model.to("cpu")
    # Run the garbage collector before emptying the allocator cache so that
    # memory owned by unreachable Python objects is released first.
    gc.collect()
    # The CUDA-specific cleanup is skipped when no CUDA device is usable;
    # otherwise `torch.cuda.ipc_collect()` fails while initialising CUDA and
    # the helper could not be used on CPU-only hosts at all.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    torch.clear_autocast_cache()
# Load two Stable Diffusion checkpoints alternately, five loads in total.
# Each round moves the pipeline to the GPU, runs it for a single inference
# step on the prompt "hey" (the output is discarded), and then hands the
# pipeline to clear_memory() before the next checkpoint is loaded.
checkpoints = (
    "runwayml/stable-diffusion-v1-5",
    "CompVis/stable-diffusion-v1-4",
    "runwayml/stable-diffusion-v1-5",
    "CompVis/stable-diffusion-v1-4",
    "runwayml/stable-diffusion-v1-5",
)
for model_id in checkpoints:
    pipe = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
    pipe = pipe.to("cuda")
    pipe("hey", num_inference_steps=1)
    print("finished...")
    clear_memory(pipe)