" Memory profiling callbacks "

import tracemalloc, threading, torch, time
from ..utils.mem import *
from ..basic_train import *
from ..torch_core import *
from ..utils.pynvml_gate import *

if use_gpu: pynvml = load_pynvml_env()

class PeakMemMetric(LearnerCallback):
    "Callback that measures used and peaked general and GPU memory."

    _order=-20 # Needs to run before the recorder

    def __init__(self, learn:Learner):
        super().__init__(learn)
        assert torch.cuda.is_available(), "pytorch CUDA is required"
        preload_pytorch()

    def peak_monitor_start(self):
        self.peak_monitoring = True

        # start RAM tracing
        tracemalloc.start()

        # this thread samples GPU memory usage for as long as the current epoch of the fit loop is running
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()

    def peak_monitor_stop(self):
        tracemalloc.stop()
        self.peak_monitoring = False

    def peak_monitor_func(self):
        self.gpu_mem_used_peak = -1

        gpu_id = torch.cuda.current_device()
        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)

        while True:
            gpu_mem_used = gpu_mem_get_used_fast(gpu_handle)
            self.gpu_mem_used_peak = max(gpu_mem_used, self.gpu_mem_used_peak)
            if not self.peak_monitoring: break
            time.sleep(0.001) # 1msec

    def on_train_begin(self, **kwargs):
        self.learn.recorder.add_metric_names(['cpu used', 'peak', 'gpu used', 'peak'])

    def on_epoch_begin(self, **kwargs):
        self.peak_monitor_start()
        self.gpu_before = gpu_mem_get_used_no_cache()

    def on_epoch_end(self, last_metrics, **kwargs):
        # tracemalloc reports bytes; convert the (current, peak) pair to MBs
        cpu_used, cpu_peak = map(lambda x: int(x/2**20), tracemalloc.get_traced_memory())
        self.peak_monitor_stop()
        gpu_used = gpu_mem_get_used_no_cache() - self.gpu_before
        gpu_peak = self.gpu_mem_used_peak      - self.gpu_before
        # can be negative, due to unreliable peak monitor thread
        if gpu_peak < 0:   gpu_peak = 0
        # since we want the overhead only, subtract delta used if it's positive
        elif gpu_used > 0: gpu_peak -= gpu_used
        # the reported numbers are deltas in MBs between the beginning and the end of the epoch
        return add_metrics(last_metrics, [cpu_used, cpu_peak, gpu_used, gpu_peak])
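
# Usage sketch (assumptions: a fastai v1 `DataBunch` named `data` and a CUDA GPU
# are available; `cnn_learner` and `models` come from `fastai.vision`). Passing
# the class via `callback_fns` re-instantiates the callback on every `fit` call:
#
#   from fastai.vision import cnn_learner, models
#   learn = cnn_learner(data, models.resnet18, callback_fns=PeakMemMetric)
#   learn.fit_one_cycle(1)
#
# Four extra columns are appended to the per-epoch metrics table: cpu used/peak
# and gpu used/peak, each a delta in MBs relative to the start of that epoch.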