zetavg
improve speed of switching models by offloading unused ones to cpu ram instead of unloading
6148b7c unverified
raw
history blame contribute delete
No virus
2.12 kB
from collections import OrderedDict
import gc
import torch
from ..lib.get_device import get_device
# Device type resolved once at import time by the project's get_device helper;
# ModelLRUCache.get() compares it against `model.device.type` and passes it
# to `model.to(...)`.
device_type = get_device()
class ModelLRUCache:
    """Least-recently-used cache of models that parks idle models on the CPU.

    Cached values must expose ``.device.type`` and ``.to(device)`` for
    ``get()`` and ``prepare_to_set()`` to work. Instead of dropping models
    that are not currently in use, the cache moves them to ``'cpu'`` so that
    switching back to them only requires a device transfer.
    """

    def __init__(self, capacity=5):
        """Create an empty cache.

        Args:
            capacity: Maximum number of entries kept before the least
                recently used one is evicted. A value below 1 behaves
                like a capacity of 1.
        """
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key):
        """Return the model stored under ``key``, placed on ``device_type``.

        The entry is marked as most recently used and every *other* cached
        model that is not on the CPU is moved there first.

        Args:
            key: Cache key of the wanted model.

        Returns:
            The cached model, or ``None`` when ``key`` is not in the cache.
        """
        if key not in self.cache:
            return None

        # Mark the accessed entry as most recently used.
        self.cache.move_to_end(key)

        # Make room on the accelerator before bringing the wanted model in.
        self._offload_to_cpu([k for k in self.cache if key != k])

        model = self.cache[key]
        # Some cached objects wrap another model in `.model`; move the
        # wrapper if either it or the wrapped model is on the wrong device.
        if (model.device.type != device_type or
                (hasattr(model, "model") and
                 model.model.device.type != device_type)):
            model = model.to(device_type)
            # Keep the cache pointing at the object that was actually
            # moved, in case `.to()` returned a new object.
            self.cache[key] = model

        return model

    def set(self, key, value):
        """Store ``value`` under ``key`` as the most recently used entry.

        When ``key`` is new and the cache is full, the least recently used
        entry is evicted first.
        """
        if key in self.cache:
            # Replace the value and refresh its recency so that it is not
            # the next eviction candidate.
            self.cache[key] = value
            self.cache.move_to_end(key)
            return

        self._evict_if_full()
        self.cache[key] = value

    def clear(self):
        """Remove every entry from the cache."""
        self.cache.clear()

    def prepare_to_set(self):
        """Free a slot and accelerator memory ahead of loading a new model.

        Evicts the least recently used entry when the cache is full, then
        moves every remaining cached model that is not on the CPU to the CPU.
        """
        self._evict_if_full()
        self._offload_to_cpu(list(self.cache))

    def _evict_if_full(self):
        """Drop the least recently used entry if the cache is at capacity."""
        # The emptiness check avoids a KeyError from popitem() when the
        # configured capacity is zero or negative.
        if self.cache and len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)

    def _offload_to_cpu(self, keys):
        """Move the models stored under ``keys`` to the CPU.

        If at least one model was moved, runs the garbage collector and asks
        torch to release its cached CUDA memory.
        """
        moved_any = False
        # `keys` is a snapshot, so reassigning entries below never mutates
        # the mapping while it is being iterated.
        for k in keys:
            cached = self.cache[k]
            if cached.device.type != 'cpu':
                self.cache[k] = cached.to('cpu')
                moved_any = True

        if moved_any:
            gc.collect()
            torch.cuda.empty_cache()