File size: 5,130 Bytes
fb83c5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import sys
import shutil
import argparse
import setup_common

# Absolute path of the directory containing this file. When the script lives in
# the repository's "setup" sub-directory, the kohya_ss project root is its parent.
project_directory = os.path.dirname(os.path.abspath(__file__))

# Step up only when this file's own directory is named "setup". A substring test
# on the whole path would also fire for unrelated ancestors such as
# "/home/setupuser/...".
if os.path.basename(project_directory) == "setup":
    project_directory = os.path.dirname(project_directory)

# Put the project root first on the import path so `kohya_gui` is resolved from
# this checkout before any other location on sys.path.
sys.path.insert(0, project_directory)

from kohya_gui.custom_logging import setup_logging  # noqa: E402  (needs sys.path above)

# Module-level logger used by the checks in this file.
log = setup_logging()

def check_path_with_space():
    """Abort the setup if the current working directory contains a space.

    Running the kohya_ss GUI from a path with spaces is not supported, so this
    logs an explanatory error and terminates the process.

    Raises:
        SystemExit: with code 1 when ``os.getcwd()`` contains a space.
    """
    cwd = os.getcwd()
    if " " not in cwd:
        return

    log.error("The path in which this python code is executed contain one or many spaces. This is not supported for running kohya_ss GUI.")
    log.error("Please move the repo to a path without spaces, delete the venv folder and run setup.sh again.")
    log.error(f"The current working directory is: {cwd}")
    # Use sys.exit rather than the builtin `exit`: that name is injected by the
    # `site` module and is not guaranteed to exist (e.g. `python -S`, frozen apps).
    sys.exit(1)

def _detect_toolkit():
    """Log which GPU vendor toolkit, if any, appears to be installed.

    Heuristic only: looks for vendor command-line tools on PATH or at their
    usual install locations, and for the ``ONEAPI_ROOT`` environment variable.
    """
    nvidia_smi_windows = os.path.join(
        os.environ.get('SystemRoot') or r'C:\Windows',
        'System32',
        'nvidia-smi.exe',
    )
    if shutil.which('nvidia-smi') is not None or os.path.exists(nvidia_smi_windows):
        log.info('nVidia toolkit detected')
    elif shutil.which('rocminfo') is not None or os.path.exists('/opt/rocm/bin/rocminfo'):
        log.info('AMD toolkit detected')
    elif (
        shutil.which('sycl-ls') is not None
        or os.environ.get('ONEAPI_ROOT') is not None
        or os.path.exists('/opt/intel/oneapi')
    ):
        log.info('Intel OneAPI toolkit detected')
    else:
        log.info('Using CPU-only Torch')


def _log_cuda_devices(torch):
    """Log the CUDA/ROCm backend version and every device ``torch.cuda`` reports."""
    if torch.version.cuda:
        cudnn = torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"
        log.info(f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {cudnn}')
    elif torch.version.hip:
        log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
    else:
        log.warning('Unknown Torch backend')

    for index in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(index)
        log.info(
            f'Torch detected GPU: {torch.cuda.get_device_name(index)} '
            f'VRAM {round(props.total_memory / 1024 / 1024)} '
            f'Arch {torch.cuda.get_device_capability(index)} '
            f'Cores {props.multi_processor_count}'
        )


def _log_xpu_devices(torch, ipex):
    """Log the Intel XPU backend and every device ``torch.xpu`` reports.

    ``ipex`` is the imported intel_extension_for_pytorch module, or None when
    that import failed. ``torch.xpu`` may still be available without IPEX
    (e.g. torch builds with native XPU support), so None must be tolerated.
    """
    if ipex is not None:
        log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
    else:
        log.info('Torch backend: Intel XPU (IPEX not installed)')

    for index in range(torch.xpu.device_count()):
        props = torch.xpu.get_device_properties(index)
        log.info(
            f'Torch detected GPU: {torch.xpu.get_device_name(index)} '
            f'VRAM {round(props.total_memory / 1024 / 1024)} '
            f'Compute Units {props.max_compute_units}'
        )


def _major_version(version):
    """Return the leading integer of a version string, e.g. 2 for '2.1.0+cu118'.

    Raises:
        ValueError: if ``version`` does not start with a digit.
    """
    import re

    match = re.match(r'\d+', str(version))
    if match is None:
        raise ValueError(f'Unrecognised version string: {version!r}')
    return int(match.group(0))


def check_torch():
    """Report the GPU toolkit and torch backend, and return torch's major version.

    Logs which vendor toolkit appears installed. It then imports torch (and
    IPEX when available) and logs the active backend plus every visible device.

    Returns:
        int: the major component of ``torch.__version__`` (e.g. 2 for "2.1.0").

    Raises:
        SystemExit: with code 1 if torch cannot be imported or inspected.
    """
    _detect_toolkit()

    try:
        import torch

        try:
            # Optional: IPEX provides Intel XPU support for torch.
            import intel_extension_for_pytorch as ipex
        except Exception:
            # Best-effort import. Record the absence so the XPU branch does not
            # hit an unbound name (which would be reported as "Could not load torch").
            ipex = None

        log.info(f'Torch {torch.__version__}')

        if torch.cuda.is_available():
            _log_cuda_devices(torch)
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            _log_xpu_devices(torch, ipex)
        else:
            log.warning('Torch reports GPU not available')

        # Parse the whole leading integer; indexing [0] breaks for majors >= 10.
        return _major_version(torch.__version__)
    except Exception as e:
        log.error(f'Could not load torch: {e}')
        sys.exit(1)
        
def main():
    """Validate the environment, then install the requirement files.

    Runs the repository-version, path, submodule, torch and Python-version
    checks, then installs either the file passed with ``-r/--requirements`` or,
    by default, ``requirements_pytorch_windows.txt`` followed by
    ``requirements_windows.txt``.

    Exits with status 1 when the Python version check fails. The other checks
    called here may also terminate the process.
    """
    setup_common.check_repo_version()

    check_path_with_space()

    parser = argparse.ArgumentParser(
        description='Validate that requirements are satisfied.'
    )
    parser.add_argument(
        '-r',
        '--requirements',
        type=str,
        help='Path to the requirements file.',
    )
    # NOTE(review): --debug is accepted but never read in this file; confirm intent.
    parser.add_argument('--debug', action='store_true', help='Debug on')
    args = parser.parse_args()

    setup_common.update_submodule()

    # Called for its logging and exit-on-failure; the returned major version is unused here.
    check_torch()

    if not setup_common.check_python_version():
        # sys.exit instead of the site-provided `exit`, which may be absent.
        sys.exit(1)

    if args.requirements:
        setup_common.install_requirements(args.requirements, check_no_verify_flag=True)
    else:
        setup_common.install_requirements('requirements_pytorch_windows.txt', check_no_verify_flag=True)
        setup_common.install_requirements('requirements_windows.txt', check_no_verify_flag=True)


if __name__ == '__main__':
    main()