File size: 5,022 Bytes
8b0ae10
72ff821
 
6148b7c
 
72ff821
82f1bf5
 
9279c83
 
 
8b0ae10
40a8f4e
966795b
87a0e23
82f1bf5
 
40a8f4e
 
 
72ff821
40a8f4e
82f1bf5
fd15ecb
8b0ae10
03b5741
49ce4b9
8b0ae10
 
49ce4b9
9ee06c7
8b0ae10
9ee06c7
883e16a
8b0ae10
 
883e16a
87a0e23
8b0ae10
90c428d
a1771a7
 
87a0e23
9279c83
 
 
 
 
 
a5d7977
40a8f4e
 
 
 
 
 
 
8b0ae10
 
 
 
 
 
72ff821
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9279c83
8b0ae10
 
6148b7c
9279c83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6148b7c
 
 
 
 
8b0ae10
 
6148b7c
8b0ae10
 
6148b7c
 
 
9279c83
 
 
6148b7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import importlib
import os
import subprocess
import psutil
import math

from typing import Any, Dict, List, Optional, Tuple, Union

from numba import cuda
import nvidia_smi

from .dynamic_import import dynamic_import
from .config import Config
from .utils.lru_cache import LRUCache


class Global:
    """
    Singleton namespace for process-wide application state.

    Every field is a class attribute and, within this module, is read and
    written through the class itself (``Global.<name>``) rather than through
    an instance.
    """

    version: Optional[str] = None

    base_model_name: str = ""
    tokenizer_name: Optional[str] = None

    # Entry points assigned by initialize_global(). They have no default
    # value, so reading them before that assignment raises AttributeError.
    inference_generate_fn: Any
    finetune_train_fn: Any

    # Set to request that training stop.
    should_stop_training: bool = False

    # Set to request that generation stop; the float is presumably the time
    # a forced stop happened (TODO confirm against the code that writes it).
    should_stop_generating: bool = False
    generation_force_stopped_at: Optional[float] = None

    # Model / tokenizer caches start with capacity 1 and may be replaced with
    # larger caches at start-up; plus slots for a base model that is ready to
    # be used and its name.
    loaded_models = LRUCache(1)
    loaded_tokenizers = LRUCache(1)
    new_base_model_that_is_ready_to_be_used = None
    name_of_new_base_model_that_is_ready_to_be_used = None

    # GPU details; all stay None until load_gpu_info() fills them in.
    gpu_cc = None  # compute capability
    gpu_sms = None  # number of streaming multiprocessors
    gpu_total_cores = None  # total number of cores
    gpu_total_memory = None  # total memory, in bytes


def initialize_global():
    """
    Populate the ``Global`` namespace at application start-up.

    Always sets the default base model name and, when a git commit hash is
    available, a short (8-character) version string. Unless the UI runs in
    development mode, it also installs the model cache, resolves the
    inference and fine-tuning entry points, and loads GPU information.
    """
    Global.base_model_name = Config.default_base_model_name

    git_hash = get_git_commit_hash()
    if git_hash:
        Global.version = git_hash[:8]

    # UI development mode skips all model / GPU wiring.
    if Config.ui_dev_mode:
        return

    model_cache_cls = dynamic_import('.utils.model_lru_cache').ModelLRUCache
    Global.loaded_models = model_cache_cls(1)
    Global.inference_generate_fn = dynamic_import('.lib.inference').generate
    Global.finetune_train_fn = dynamic_import('.lib.finetune').train
    load_gpu_info()


def get_package_dir():
    """Return the absolute path of the directory containing this module."""
    this_file = os.path.abspath(__file__)
    return os.path.abspath(os.path.dirname(this_file))


def get_git_commit_hash():
    """
    Return the git commit hash of the checkout that contains this package.

    Runs ``git rev-parse HEAD`` with the package directory as the child
    process's working directory. This is best-effort: any failure (git not
    installed, not a repository, ...) is printed and ``None`` is returned.

    Returns:
        The full commit hash as a string, or ``None`` if it cannot be read.
    """
    try:
        # Give git its own working directory via cwd= instead of os.chdir():
        # chdir changes the working directory of the whole process, which
        # every other thread would observe while git runs.
        output = subprocess.check_output(
            ['git', 'rev-parse', 'HEAD'], cwd=get_package_dir())
        return output.strip().decode('utf-8')
    except Exception as e:
        print(f"Cannot get git commit hash: {e}")
        return None


# Cores per streaming multiprocessor (SM), keyed by CUDA compute capability
# (major, minor). Capabilities missing from this table map to None via .get();
# extend it as new devices become available.
_CC_CORES_PER_SM = {
    (2, 0): 32,
    (2, 1): 48,
    (3, 0): 192,
    (3, 2): 192,
    (3, 5): 192,
    (3, 7): 192,
    (5, 0): 128,
    (5, 2): 128,
    (5, 3): 128,
    (6, 0): 64,
    (6, 1): 128,
    (6, 2): 128,
    (7, 0): 64,
    (7, 2): 64,
    (7, 5): 64,
    (8, 0): 64,
    (8, 6): 128,
    (8, 7): 128,
    (8, 9): 128,
    (9, 0): 128,
}


def _probe_cuda_cores():
    """Record the current CUDA device's compute capability, SM count and core count in Global."""
    device = cuda.get_current_device()
    device_sms = getattr(device, 'MULTIPROCESSOR_COUNT')
    device_cc = device.compute_capability
    cores_per_sm = _CC_CORES_PER_SM.get(device_cc)
    # An unlisted compute capability leaves the core count unknown instead of
    # raising (None * int) and aborting the memory probe that follows.
    total_cores = None if cores_per_sm is None else cores_per_sm * device_sms

    print("GPU compute capability: ", device_cc)
    print("GPU total number of SMs: ", device_sms)
    print("GPU total cores: ",
          "unknown" if total_cores is None else total_cores)

    Global.gpu_cc = device_cc
    Global.gpu_sms = device_sms
    Global.gpu_total_cores = total_cores


def _probe_gpu_memory():
    """Record in Global, and return, the total memory in bytes of NVML device 0."""
    nvidia_smi.nvmlInit()
    # NOTE(review): NVML index 0 is not necessarily the device numba reports
    # as current; confirm the intended device on multi-GPU hosts.
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    total_memory = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).total

    total_memory_mb = total_memory / (1024 ** 2)
    total_memory_gb = total_memory / (1024 ** 3)
    print(
        f"GPU total memory: {total_memory} bytes ({total_memory_mb:.2f} MB) ({total_memory_gb:.2f} GB)")
    Global.gpu_total_memory = total_memory
    return total_memory


def _resize_model_caches(total_gpu_memory):
    """Enlarge the model/tokenizer LRU caches based on available CPU RAM."""
    available_cpu_ram = psutil.virtual_memory().available
    available_cpu_ram_mb = available_cpu_ram / (1024 ** 2)
    available_cpu_ram_gb = available_cpu_ram / (1024 ** 3)
    print(
        f"CPU available memory: {available_cpu_ram} bytes ({available_cpu_ram_mb:.2f} MB) ({available_cpu_ram_gb:.2f} GB)")

    if total_gpu_memory <= 0:
        # Nothing sensible to divide by; keep the default cache sizes.
        return

    # Uses 80% of available CPU RAM, sizing each slot at the GPU's total
    # memory; one slot is held back (presumably for the model on the GPU).
    preserve_loaded_models_count = math.floor(
        (available_cpu_ram * 0.8) / total_gpu_memory) - 1
    if preserve_loaded_models_count > 1:
        print(
            f"Will keep {preserve_loaded_models_count} offloaded models in CPU RAM.")
        # Resolve ModelLRUCache here: it is only a local name inside
        # initialize_global(), not a module-level global.
        ModelLRUCache = dynamic_import('.utils.model_lru_cache').ModelLRUCache
        Global.loaded_models = ModelLRUCache(preserve_loaded_models_count)
        Global.loaded_tokenizers = LRUCache(preserve_loaded_models_count)


def load_gpu_info():
    """
    Probe the GPU and CPU memory, store the results in ``Global``, and size
    the model/tokenizer caches accordingly.

    Best-effort: an exception from any step is caught and printed as a
    notice, and the remaining steps are skipped.
    """
    print("")
    try:
        _probe_cuda_cores()
        total_memory = _probe_gpu_memory()
        _resize_model_caches(total_memory)
    except Exception as e:
        print(f"Notice: cannot get GPU info: {e}")

    print("")