kohya_ss/setup/validate_requirements.py
import os
import re
import sys
import shutil
import argparse
import setup_common
# Get the absolute path of the current file's directory (Kohya_SS project directory)
project_directory = os.path.dirname(os.path.abspath(__file__))
# Check if the "setup" directory is present in the project_directory
if "setup" in project_directory:
    # If the "setup" directory is present, move one level up to the parent directory
    project_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Add the project directory to the beginning of the Python search path
sys.path.insert(0, project_directory)
from library.custom_logging import setup_logging
# Set up logging
log = setup_logging()
def check_torch():
    # Check for nVidia toolkit or AMD toolkit
    if shutil.which('nvidia-smi') is not None or os.path.exists(
        os.path.join(
            os.environ.get('SystemRoot') or r'C:\Windows',
            'System32',
            'nvidia-smi.exe',
        )
    ):
        log.info('nVidia toolkit detected')
    elif shutil.which('rocminfo') is not None or os.path.exists(
        '/opt/rocm/bin/rocminfo'
    ):
        log.info('AMD toolkit detected')
    else:
        log.info('Using CPU-only Torch')

    try:
        import torch

        log.info(f'Torch {torch.__version__}')

        # Check if CUDA is available
        if not torch.cuda.is_available():
            log.warning('Torch reports CUDA not available')
        else:
            if torch.version.cuda:
                # Log nVidia CUDA and cuDNN versions
                log.info(
                    f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
                )
            elif torch.version.hip:
                # Log AMD ROCm HIP version
                log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
            else:
                log.warning('Unknown Torch backend')

            # Log information about detected GPUs
            for device in [
                torch.cuda.device(i) for i in range(torch.cuda.device_count())
            ]:
                log.info(
                    f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
                )

        # Return the major version of the installed Torch (first character of the version string)
        return int(torch.__version__[0])
    except Exception as e:
        log.error(f'Could not load torch: {e}')
        sys.exit(1)
def main():
    setup_common.check_repo_version()

    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description='Validate that requirements are satisfied.'
    )
    parser.add_argument(
        '-r',
        '--requirements',
        type=str,
        help='Path to the requirements file.',
    )
    parser.add_argument('--debug', action='store_true', help='Debug on')
    args = parser.parse_args()

    torch_ver = check_torch()

    if args.requirements:
        setup_common.install_requirements(args.requirements, check_no_verify_flag=True)
    else:
        if torch_ver == 1:
            setup_common.install_requirements('requirements_windows_torch1.txt', check_no_verify_flag=True)
        else:
            setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=True)
if __name__ == '__main__':
    main()
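
# Example invocations (a sketch only; assumes the script is run from the kohya_ss repo root):
#   python setup/validate_requirements.py
#   python setup/validate_requirements.py --requirements requirements_windows_torch2.txt
# With no -r/--requirements argument, the requirements file is chosen from the detected Torch major version.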