import contextlib
import datetime
import gc
import inspect
import sys

import torch
import numpy as np
# Bytes per element for each dtype
dtype_memory_size_dict = {
    torch.float64: 64 / 8,
    torch.double: 64 / 8,
    torch.float32: 32 / 8,
    torch.float: 32 / 8,
    torch.float16: 16 / 8,
    torch.half: 16 / 8,
    torch.int64: 64 / 8,
    torch.long: 64 / 8,
    torch.int32: 32 / 8,
    torch.int: 32 / 8,
    torch.int16: 16 / 8,
    torch.short: 16 / 8,
    torch.uint8: 8 / 8,
    torch.int8: 8 / 8,
}
# Guard for older torch versions (e.g. 1.0) that lack these dtypes
if getattr(torch, "bfloat16", None) is not None:
    dtype_memory_size_dict[torch.bfloat16] = 16 / 8
if getattr(torch, "bool", None) is not None:
    # PyTorch uses 1 byte for a bool, see https://github.com/pytorch/pytorch/issues/41571
    dtype_memory_size_dict[torch.bool] = 8 / 8
def get_mem_space(x):
    try:
        return dtype_memory_size_dict[x]
    except KeyError:
        # Fall back to 0 so tracking keeps working for unknown dtypes
        print(f"dtype {x} is not supported!")
        return 0
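# Illustrative check: get_mem_space(torch.float32) returns 4.0, the number of
# bytes one float32 element occupies.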
@contextlib.contextmanager
def file_writer(file_name=None):
    # Create the writer: an append-mode file handle if a name is given, else stdout
    writer = open(file_name, "a") if file_name is not None else sys.stdout
    try:
        # Yield the writer object for the actual use
        yield writer
    finally:
        # If it is a file, close it; never close stdout
        if file_name is not None:
            writer.close()
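# Usage sketch for file_writer (illustrative only; "mem_track.log" is a
# hypothetical file name, not one the tracker itself uses):
#
#   with file_writer("mem_track.log") as f:   # appends to the named file
#       f.write("hello\n")
#   with file_writer() as f:                  # no name -> writes to stdout
#       f.write("hello\n")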
class MemTracker(object):
    """
    Class used to track PyTorch GPU memory usage.

    Arguments:
        detail (bool, default True): whether to show per-tensor GPU memory usage
        path (str): directory in which to save the log file
        verbose (bool, default False): whether to show trivial exceptions
        device (int, default 0): index of the GPU to track
        log_to_disk (bool, default False): whether to write the log to disk instead of stdout
    """

    def __init__(self, detail=True, path='', verbose=False, device=0, log_to_disk=False):
        self.print_detail = detail
        self.last_tensor_sizes = set()
        self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt'
        self.verbose = verbose
        self.begin = True
        self.device = device
        self.log_to_disk = log_to_disk
    def get_tensors(self):
        # Walk every object tracked by the garbage collector and yield the CUDA tensors
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                    tensor = obj
                else:
                    continue
                if tensor.is_cuda:
                    yield tensor
            except Exception as e:
                if self.verbose:
                    print('A trivial exception occurred: {}'.format(e))
    def get_tensor_usage(self):
        # Total size of all live CUDA tensors, in MB
        sizes = [np.prod(np.array(tensor.size())) * get_mem_space(tensor.dtype) for tensor in self.get_tensors()]
        return np.sum(sizes) / 1024 ** 2

    def get_allocate_usage(self):
        # Memory the CUDA caching allocator reports as allocated on the tracked device, in MB
        return torch.cuda.memory_allocated(self.device) / 1024 ** 2

    def clear_cache(self):
        gc.collect()
        torch.cuda.empty_cache()

    def print_all_gpu_tensor(self, file=None):
        for x in self.get_tensors():
            print(x.size(), x.dtype, np.prod(np.array(x.size())) * get_mem_space(x.dtype) / 1024 ** 2, file=file)
    def track(self):
        """
        Take a snapshot of the current GPU memory usage and log it.
        """
        frameinfo = inspect.stack()[1]
        where_str = frameinfo.filename + ' line ' + str(frameinfo.lineno) + ': ' + frameinfo.function

        if self.log_to_disk:
            file_name = self.gpu_profile_fn
        else:
            file_name = None

        with file_writer(file_name) as f:
            if self.begin:
                f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
                        f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                        f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
                self.begin = False

            if self.print_detail is True:
                ts_list = [(tensor.size(), tensor.dtype) for tensor in self.get_tensors()]
                new_tensor_sizes = {(type(x),
                                     tuple(x.size()),
                                     ts_list.count((x.size(), x.dtype)),
                                     np.prod(np.array(x.size())) * get_mem_space(x.dtype) / 1024 ** 2,
                                     x.dtype) for x in self.get_tensors()}
                # Tensors that appeared since the last snapshot
                for t, s, n, m, data_type in new_tensor_sizes - self.last_tensor_sizes:
                    f.write(f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m * n)[:6]} M | {str(t):<20} | {data_type}\n')
                # Tensors that disappeared since the last snapshot
                for t, s, n, m, data_type in self.last_tensor_sizes - new_tensor_sizes:
                    f.write(f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m * n)[:6]} M | {str(t):<20} | {data_type}\n')
                self.last_tensor_sizes = new_tensor_sizes

            f.write(f"\nAt {where_str:<50}"
                    f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                    f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")