#!/usr/bin/env python3
"""Allocate and release large ``nn.Linear`` layers on the GPU in a loop.

Each iteration builds a ``shape x shape`` linear layer on the target device,
runs one forward pass over a ``shape x shape`` input of ones, then tries to
hand the device memory back via :func:`clear_memory`.
"""
import gc

import torch

# Square dimension of both the input tensor and the Linear layer.
# NOTE: this used to read ``(10,000)``, which Python parses as the tuple
# ``(10, 0)`` rather than ten thousand, so every tensor/layer construction
# below failed with a TypeError.
shape = 10_000


def clear_memory(model):
    """Move *model* to the CPU and release cached GPU memory.

    Args:
        model: a ``torch.nn.Module``; it is moved to the CPU in place
            (``Module.to`` mutates the module and returns it).

    The CUDA-specific cleanup calls are skipped when CUDA is unavailable,
    because ``torch.cuda.ipc_collect`` initializes CUDA and raises on
    CPU-only machines.
    """
    model.to("cpu")
    # Collect unreachable Python objects first so that any tensors they own
    # are released before the caching allocator is asked to free memory.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    torch.clear_autocast_cache()


def main(iterations=6, size=shape, device="cuda"):
    """Run the allocate / forward / release cycle *iterations* times.

    Args:
        iterations: number of layers to build and release.
        size: square dimension of the input and of each Linear layer.
        device: device on which the input and the layers are allocated.
    """
    inputs = torch.ones((size, size), device=device)
    for _ in range(iterations):
        linear = torch.nn.Linear(size, size).to(device)
        output = linear(inputs)
        # Drop the output (and the autograd graph hanging off it) before
        # cleanup; otherwise it stays alive on the device and empty_cache()
        # cannot return its memory.
        del output
        clear_memory(linear)


if __name__ == "__main__":
    main()