File size: 5,602 Bytes
b585c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import sys
import time
import traceback
import webbrowser

# uncomment below to ensure CPU install only uses CPU
# os.environ['CUDA_VISIBLE_DEVICES'] = ''

print('__file__: %s' % __file__)
# path1: directory containing this launcher script; base_path: its parent.
# Both are made importable so bundled project modules can be found.
path1 = os.path.dirname(os.path.abspath(__file__))
sys.path.append(path1)
base_path = os.path.dirname(path1)
sys.path.append(base_path)
os.environ['PYTHONPATH'] = path1
print('path1', path1, flush=True)

# Point NLTK and the executable search PATH at tools shipped under base_path
# (poppler, Tesseract-OCR, Playwright browsers/ffmpeg, rubberband).
os.environ['NLTK_DATA'] = os.path.join(base_path, './nltk_data')
path_list = [os.environ['PATH'],
             os.path.join(base_path, 'poppler/Library/bin/'),
             os.path.join(base_path, 'poppler/Library/lib/'),
             os.path.join(base_path, 'Tesseract-OCR'),
             os.path.join(base_path, 'ms-playwright'),
             os.path.join(base_path, 'ms-playwright/chromium-1076/chrome-win'),
             os.path.join(base_path, 'ms-playwright/ffmpeg-1009'),
             os.path.join(base_path, 'ms-playwright/firefox-1422/firefox'),
             os.path.join(base_path, 'ms-playwright/webkit-1883'),
             os.path.join(base_path, 'rubberband/')]
# os.pathsep is ';' on Windows (unchanged behavior there) and ':' on POSIX;
# a hard-coded ';' would merge the last original PATH entry with the first
# bundled directory on non-Windows hosts.
os.environ['PATH'] = os.pathsep.join(path_list)
print(os.environ['PATH'])

import shutil, errno


def copy_tree(src, dst):
    """Copy *src* to *dst*, whether *src* is a directory tree or a single file.

    ``shutil.copytree`` is attempted first. If it fails with an ``OSError``
    whose errno is ENOTDIR or EINVAL (e.g. *src* is a plain file), the path
    is copied with ``shutil.copy`` instead. Any other ``OSError`` (for
    example an already-existing destination directory) is re-raised.
    """
    fallback_errnos = (errno.ENOTDIR, errno.EINVAL)
    try:
        shutil.copytree(src, dst)
    except OSError as err:
        if err.errno not in fallback_errnos:
            raise
        shutil.copy(src, dst)


def _register_search_dir(path2, sub):
    """Append *path2* to ``sys.path`` if it is a directory, and echo it.

    When *sub* is ``'models'`` and the directory contains ``human.jpg``, the
    directory is also recorded in ``H2OGPT_MODEL_BASE``. The path is printed
    whether or not it exists.
    """
    if os.path.isdir(path2):
        if sub == 'models' and os.path.isfile(os.path.join(path2, 'human.jpg')):
            os.environ['H2OGPT_MODEL_BASE'] = path2
        sys.path.append(path2)
    print(path2, flush=True)


def setup_paths():
    """Extend ``sys.path`` with project sub-directories and locate the models dir.

    For each known sub-directory name, ``<base_path>/../<sub>`` and then
    ``<path1>/../<sub>`` are registered via ``_register_search_dir``. If the
    resulting ``H2OGPT_MODEL_BASE`` lies under a path containing ``Programs``,
    it is redirected to a ``Temp/gradio/`` mirror. Any stale mirror is removed,
    and the models are copied there when the source contains ``human.jpg``.
    """
    for sub in ['src', 'iterators', 'gradio_utils', 'metrics', 'models', '.']:
        for root in (base_path, path1):
            _register_search_dir(os.path.join(root, '..', sub), sub)

    # for app, avoid forbidden for web access
    # (presumably the web UI refuses to serve files from under "Programs";
    # the models are mirrored into a Temp location instead -- confirm)
    model_base = os.getenv('H2OGPT_MODEL_BASE')
    if model_base and 'Programs' in model_base:
        temp_base = model_base.replace('Programs', 'Temp/gradio/')
        os.environ['H2OGPT_MODEL_BASE'] = temp_base
        # Remove a stale mirror only if one exists: on a first run the target
        # is absent and an unconditional rmtree raises FileNotFoundError.
        if os.path.isdir(temp_base):
            shutil.rmtree(temp_base)
        if os.path.isfile(os.path.join(model_base, 'human.jpg')):
            copy_tree(model_base, temp_base)


from importlib.metadata import distribution, PackageNotFoundError

# Detect an installed torch distribution without importing torch itself.
# distribution() raises PackageNotFoundError when the package is absent; it
# never returns None, so no extra check (formerly an assert, which -O strips)
# is needed.
try:
    dtorch = distribution('torch')
except PackageNotFoundError:
    have_torch = False
    torch_version = ''
else:
    have_torch = True
    # Version string, e.g. containing "+cu118" for CUDA builds.
    torch_version = dtorch.version


def _count_nvidia_gpus():
    """Return the GPU count reported by NVML, or 0 if NVML cannot be used."""
    try:
        from pynvml import nvmlInit, nvmlDeviceGetCount
        nvmlInit()
        return nvmlDeviceGetCount()
    except Exception as e:
        # pynvml missing, no NVIDIA driver, etc.: treat the machine as CPU-only.
        print("No GPUs detected by NVML: %s" % str(e))
        return 0


def _install_torch_wheel(need_get_gpu_torch):
    """pip-install torch into the running interpreter (one-click install path).

    Uses the wheel named by the ``TORCH_WHEEL`` environment variable if set.
    Otherwise, when *need_get_gpu_torch* is true, installs the CUDA 11.8
    wheel. In any other case nothing is installed, on the assumption that
    CPU torch is already part of the install.
    """
    import importlib
    import subprocess

    def install(package):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

    torch_wheel = os.getenv('TORCH_WHEEL')
    if torch_wheel:
        print("Installing Torch from %s" % torch_wheel)
        install(torch_wheel)
    elif need_get_gpu_torch:
        wheel_file = "https://h2o-release.s3.amazonaws.com/h2ogpt/torch-2.1.2%2Bcu118-cp310-cp310-win_amd64.whl"
        print("Installing Torch from %s" % wheel_file)
        install(wheel_file)
    # else: assume cpu torch part of install

    importlib.invalidate_caches()
    import pkg_resources
    importlib.reload(pkg_resources)  # re-load because otherwise cache would be bad


def _main():
    """Prepare the environment, start h2oGPT, open the UI and block forever.

    Steps:
      1. register search paths (``setup_paths``) and set h2oGPT env flags;
      2. count NVIDIA GPUs and report whether a CUDA torch build is missing;
      3. optionally pip-install torch (disabled: ``auto_install_torch_gpu``);
      4. run ``generate.entrypoint_main`` and open the UI URL in a browser.

    Never returns normally: it ends in an endless sleep loop.
    """
    setup_paths()
    os.environ['h2ogpt_block_gradio_exit'] = 'False'
    os.environ['h2ogpt_score_model'] = ''

    device_count = _count_nvidia_gpus()

    # A GPU exists but torch is either absent or not a CUDA build ("+cu" tag).
    need_get_gpu_torch = device_count > 0 and (not have_torch or '+cu' not in torch_version)

    print("Torch Status: have torch: %s need get gpu torch: %s CVD: %s GPUs: %s" % (have_torch, need_get_gpu_torch, os.getenv('CUDA_VISIBLE_DEVICES'), device_count))

    # Hard-wired off: set True to let the Windows launcher pip-install torch.
    auto_install_torch_gpu = False
    if auto_install_torch_gpu and (not have_torch or need_get_gpu_torch) and sys.platform == "win32":
        print("Installing Torch")
        # for one-click, don't have torch installed, install now
        _install_torch_wheel(need_get_gpu_torch)

    from generate import entrypoint_main as main_h2ogpt
    # presumably returns once the server is running (the browser is opened
    # right after) -- confirm against generate.entrypoint_main
    main_h2ogpt()

    server_name = os.getenv('h2ogpt_server_name', os.getenv('H2OGPT_SERVER_NAME', 'localhost'))
    server_port = os.getenv('GRADIO_SERVER_PORT', str(7860))

    url = "http://%s:%s" % (server_name, server_port)
    webbrowser.open(url)

    # Keep the launcher process alive indefinitely.
    while True:
        time.sleep(10000)


def main():
    """Run ``_main``, logging any fatal traceback to ``h2ogpt_exception.log``.

    Because ``BaseException`` is caught, this also covers KeyboardInterrupt and
    SystemExit. The traceback is appended to the log in the current working
    directory, the process waits 10 seconds and the exception is re-raised.
    The waits presumably keep a console window readable before it closes.
    If ``_main`` ever returns, the process also waits 10 seconds.
    """
    try:
        _main()
    except BaseException:
        trace_text = traceback.format_exc()
        with open('h2ogpt_exception.log', 'at') as log_file:
            log_file.write(trace_text)
        time.sleep(10)
        raise
    else:
        time.sleep(10)


# Script entry point: run the launcher only when executed directly, not on import.
if __name__ == "__main__":
    main()