# tools/clear_mem.py
# Last commit e402ae6 ("up") by patrickvonplaten · 397 bytes · security scan: No virus
#!/usr/bin/env python3
import torch
import gc
# Side length of the square input matrix, and the in/out feature count of each
# Linear layer built below. Written with a digit separator: the previous
# `(10,000)` parsed as the tuple (10, 0), which torch.ones and nn.Linear reject.
shape = 10_000
# A (shape x shape) tensor of ones on the GPU, reused as the input of every
# iteration. With the default dtype (float32 unless changed) this is about
# 400 MB of device memory.
# NOTE: this name shadows the builtin `input()`; it is kept because the loop
# below refers to it.
input = torch.ones((shape, shape), device="cuda")
def clear_memory(model):
    """Move ``model`` off the GPU and release cached device memory.

    First the parameters are moved to the CPU. Then Python garbage is
    collected, and finally PyTorch's caches are cleared. Only memory that is
    no longer referenced can be released: tensors the caller still holds
    (e.g. a forward-pass output) stay allocated.

    Args:
        model: object with a ``.to(device)`` method; presumably a
            ``torch.nn.Module``, whose ``.to`` moves parameters in place.
            The return value of ``.to`` is ignored here.
    """
    model.to("cpu")
    # Collect first so tensors kept alive only by reference cycles drop their
    # storage before the allocator cache is emptied.
    gc.collect()
    # torch.cuda.ipc_collect() initialises CUDA and raises when no GPU is
    # usable, so only touch the CUDA caches when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    torch.clear_autocast_cache()
# Six times: build a (shape x shape) Linear layer on the GPU, run one forward
# pass over the shared `input`, then hand the layer to clear_memory.
# Presumably the goal is to watch whether GPU memory use stays flat across
# iterations. TODO confirm.
for _ in range(6):
    linear = torch.nn.Linear(shape, shape).to("cuda")
    output = linear(input)
    # Drop the result before clearing. It is a (shape x shape) GPU tensor and
    # keeps its autograd graph (plus anything that graph saved for backward)
    # alive. While it is referenced, clear_memory cannot release that memory,
    # and it would otherwise still be allocated during the next iteration.
    del output
    clear_memory(linear)