AdritRao's picture
Upload 62 files
a3290d1
raw
history blame contribute delete
No virus
2.72 kB
import re
import subprocess

from keras import Model
# from keras.utils import multi_gpu_model
# from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
def get_available_gpus(num_gpus: int = None):
    """Get gpu ids for gpus that are >95% free.

    Tensorflow does not support checking free memory on gpus.
    This is a crude method that relies on `nvidia-smi` to
    determine which gpus are occupied and which are free.

    Args:
        num_gpus: Number of requested gpus. If not specified,
            ids of all available gpu(s) are returned. ``0`` returns
            ``[-1]`` without querying ``nvidia-smi``.

    Returns:
        List[int] | None: List of gpu ids that are free. Length
        will equal `num_gpus`, if specified. ``[-1]`` when
        ``num_gpus == 0``. ``None`` when ``nvidia-smi`` is missing,
        exits with an error, or its memory table cannot be parsed.

    Raises:
        TypeError: If `num_gpus` is neither ``None`` nor an int.
        ValueError: If `num_gpus` is negative, or if more gpus are
            requested than are free.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if num_gpus is not None and not isinstance(num_gpus, int):
        raise TypeError(
            "num_gpus must be an int or None, got {}".format(type(num_gpus).__name__)
        )
    if num_gpus is not None and num_gpus < 0:
        # A negative value would otherwise slice ids off the end of the list.
        raise ValueError("num_gpus must be non-negative, got {}".format(num_gpus))
    # Built-in tensorflow gpu id.
    if num_gpus == 0:
        return [-1]

    try:
        # Invoke nvidia-smi directly (no shell). A missing binary raises
        # OSError (FileNotFoundError); a failing one raises CalledProcessError.
        list_out = subprocess.check_output(["nvidia-smi", "--list-gpus"]).decode()
        smi_out = subprocess.check_output(["nvidia-smi"]).decode()
    except (OSError, subprocess.CalledProcessError):
        return None

    # `--list-gpus` prints one line per gpu; count non-empty lines so the
    # result does not depend on a trailing newline.
    num_detected = sum(1 for line in list_out.splitlines() if line.strip())

    # First 2 * num_detected "<n>MiB" values correspond to memory for gpus.
    # Order: (occupied-0, total-0, occupied-1, total-1, ...). Any later
    # matches in the output are ignored.
    mems = [float(value) for value in re.findall(r"(\d+(?:\.\d+)?)MiB", smi_out)]
    if len(mems) < 2 * num_detected:
        # Output format not understood; mirror the "can't tell" result above.
        return None

    occupied_fractions = []
    for gpu_id in range(num_detected):
        used, total = mems[2 * gpu_id], mems[2 * gpu_id + 1]
        # Treat a gpu reporting zero total memory as fully occupied.
        occupied_fractions.append(used / total if total else 1.0)

    available_gpus = [
        gpu_id for gpu_id, frac in enumerate(occupied_fractions) if frac < 0.05
    ]
    if num_gpus and num_gpus > len(available_gpus):
        raise ValueError(
            "Requested {} gpus, only {} are free".format(
                num_gpus, len(available_gpus)
            )
        )
    return available_gpus[:num_gpus] if num_gpus else available_gpus
class ModelMGPU(Model):
    """Wrapper for distributing model across multiple gpus.

    Builds a parallel model from ``ser_model`` with ``multi_gpu_model``,
    adopts that model's instance state, and routes load/save attribute
    lookups back to the original serial model.
    """

    def __init__(self, ser_model, gpus):
        """Wrap ``ser_model`` for execution on ``gpus``.

        Args:
            ser_model: The serial keras model to parallelize.
            gpus: Passed through unchanged as the second argument of
                ``multi_gpu_model``.

        Raises:
            ImportError: If ``multi_gpu_model`` cannot be imported from
                either ``keras.utils`` or tensorflow's bundled keras.
        """
        # `multi_gpu_model` was previously referenced without being imported
        # (NameError on every call). Import it lazily so this module still
        # loads on keras/tensorflow versions that no longer ship it.
        try:
            from keras.utils import multi_gpu_model
        except ImportError:
            try:
                from tensorflow.python.keras.utils.multi_gpu_utils import (
                    multi_gpu_model,
                )
            except ImportError as err:
                raise ImportError(
                    "ModelMGPU requires `multi_gpu_model`, which is not available "
                    "in the installed keras/tensorflow version."
                ) from err
        pmodel = multi_gpu_model(ser_model, gpus)
        # Copy the parallel model's instance state instead of calling
        # Model.__init__, so this object acts as the parallel model.
        self.__dict__.update(pmodel.__dict__)
        self._smodel = ser_model

    def __getattribute__(self, attrname):
        """Override load and save methods to be used from the serial-model. The
        serial-model holds references to the weights in the multi-gpu model.
        """
        # Any attribute whose name contains "load" or "save" is looked up on
        # the serial model. NOTE(review): this substring match also catches
        # non-method attributes with those substrings; confirm that is intended.
        if "load" in attrname or "save" in attrname:
            return getattr(self._smodel, attrname)
        return super(ModelMGPU, self).__getattribute__(attrname)