import os
import sys
import shutil
import argparse
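
# setup_common provides the shared setup helpers used below (repo version check, submodule update, Python version check, requirements install).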
import setup_common
# Get the absolute path of the current file's directory (kohya_ss project directory)
project_directory = os.path.dirname(os.path.abspath(__file__))

# Check if the "setup" directory is present in the project_directory
if "setup" in project_directory:
    # If the "setup" directory is present, move one level up to the parent directory
    project_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Add the project directory to the beginning of the Python search path
sys.path.insert(0, project_directory)

from kohya_gui.custom_logging import setup_logging
# Set up logging
log = setup_logging()
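

# Refuse to run from a path that contains spaces; the venv setup does not support it.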
def check_path_with_space():
    # Get the current working directory
    cwd = os.getcwd()

    # Check if the current working directory contains a space
    if " " in cwd:
        log.error("The path from which this python code is executed contains one or more spaces. This is not supported for running the kohya_ss GUI.")
        log.error("Please move the repo to a path without spaces, delete the venv folder and run setup.sh again.")
        log.error("The current working directory is: " + cwd)
        sys.exit(1)
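

# Detect the installed GPU toolkit and report which backend and devices Torch can use.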
def check_torch():
    # Check for toolkit
    if shutil.which('nvidia-smi') is not None or os.path.exists(
        os.path.join(
            os.environ.get('SystemRoot') or r'C:\Windows',
            'System32',
            'nvidia-smi.exe',
        )
    ):
        log.info('nVidia toolkit detected')
    elif shutil.which('rocminfo') is not None or os.path.exists(
        '/opt/rocm/bin/rocminfo'
    ):
        log.info('AMD toolkit detected')
    elif (
        shutil.which('sycl-ls') is not None
        or os.environ.get('ONEAPI_ROOT') is not None
        or os.path.exists('/opt/intel/oneapi')
    ):
        log.info('Intel OneAPI toolkit detected')
    else:
        log.info('Using CPU-only Torch')
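
    # Try to import Torch and log backend/GPU details; abort setup if Torch cannot be loaded.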
    try:
        import torch

        ipex = None
        try:
            # Import IPEX / XPU support (optional; ignored if it is not installed)
            import intel_extension_for_pytorch as ipex
        except Exception:
            pass

        log.info(f'Torch {torch.__version__}')
        if torch.cuda.is_available():
            if torch.version.cuda:
                # Log nVidia CUDA and cuDNN versions
                log.info(
                    f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
                )
            elif torch.version.hip:
                # Log AMD ROCm HIP version
                log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
            else:
                log.warning('Unknown Torch backend')

            # Log information about detected GPUs
            for device in [
                torch.cuda.device(i) for i in range(torch.cuda.device_count())
            ]:
                log.info(
                    f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
                )
        # Check if XPU is available
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            # Log Intel IPEX version when the extension was imported successfully
            if ipex is not None:
                log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
            for device in [
                torch.xpu.device(i) for i in range(torch.xpu.device_count())
            ]:
                log.info(
                    f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
                )
        else:
            log.warning('Torch reports GPU not available')

        # Return the major Torch version (e.g. 2 for 2.x)
        return int(torch.__version__[0])
    except Exception as e:
        log.error(f'Could not load torch: {e}')
        sys.exit(1)
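

# Entry point: validate the environment, then install the required Python packages.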
def main():
    setup_common.check_repo_version()
    check_path_with_space()

    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description='Validate that requirements are satisfied.'
    )
    parser.add_argument(
        '-r',
        '--requirements',
        type=str,
        help='Path to the requirements file.',
    )
    parser.add_argument('--debug', action='store_true', help='Debug on')
    args = parser.parse_args()
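
    # Bring git submodules up to date before checking Torch and installing requirements.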
    setup_common.update_submodule()

    torch_ver = check_torch()

    if not setup_common.check_python_version():
        sys.exit(1)

    if args.requirements:
        setup_common.install_requirements(args.requirements, check_no_verify_flag=True)
    else:
        setup_common.install_requirements('requirements_pytorch_windows.txt', check_no_verify_flag=True)
        setup_common.install_requirements('requirements_windows.txt', check_no_verify_flag=True)


if __name__ == '__main__':
    main()