kohya_ss / setup /setup_common.py
Ateras's picture
Upload folder using huggingface_hub
fe6327d
raw
history blame
No virus
16.9 kB
import subprocess
import os
import re
import sys
import filecmp
import logging
import shutil
import sysconfig
import datetime
import platform
import pkg_resources
# Module-wide counter of failed external commands; git() and pip() below
# increment it when a command exits non-zero and the failure is not ignored.
errors = 0
# Logger shared by every helper in this module (facility name 'sd').
log = logging.getLogger('sd')
# setup console and file logging
def setup_logging(clean=False):
    """Set up file logging plus a rich console handler for the ``sd`` logger.

    A timestamped log file is created under ``../logs/setup`` relative to this
    module and attached to the root logger via ``logging.basicConfig``. The
    ``sd`` logger itself is set to DEBUG and gets a single ``RichHandler``
    that shows INFO and above on the console. Rich pretty-printing and rich
    tracebacks are installed as a side effect.

    Args:
        clean: Accepted for interface compatibility; not used by this function.
    """
    #
    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    #
    from rich.console import Console
    from rich.logging import RichHandler
    from rich.pretty import install as pretty_install
    from rich.theme import Theme
    from rich.traceback import install as traceback_install

    console_level = logging.INFO

    # One log file per invocation, named after the start time.
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_file = os.path.join(
        os.path.dirname(__file__),
        f'../logs/setup/kohya_ss_gui_{timestamp}.log',
    )
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    border_theme = Theme(
        {
            'traceback.border': 'black',
            'traceback.border.syntax_error': 'black',
            'inspect.value.border': 'black',
        }
    )
    console = Console(
        log_time=True,
        log_time_format='%H:%M:%S-%f',
        theme=border_theme,
    )

    # force=True replaces any handlers already installed on the root logger.
    logging.basicConfig(
        level=logging.ERROR,
        format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s',
        filename=log_file,
        filemode='a',
        encoding='utf-8',
        force=True,
    )
    # log to file is always at level debug for facility `sd`
    log.setLevel(logging.DEBUG)

    pretty_install(console=console)
    traceback_install(
        console=console,
        extra_lines=1,
        width=console.width,
        word_wrap=False,
        indent_guides=False,
        suppress=[],
    )

    console_handler = RichHandler(
        show_time=True,
        omit_repeated_times=False,
        show_level=True,
        show_path=False,
        markup=False,
        rich_tracebacks=True,
        log_time_format='%H:%M:%S-%f',
        level=console_level,
        console=console,
    )
    console_handler.set_name(console_level)

    # Drop handlers left over from an earlier call so messages are not
    # printed twice, then attach the fresh console handler.
    for stale_handler in list(log.handlers):
        log.removeHandler(stale_handler)
    log.addHandler(console_handler)
def configure_accelerate(run_accelerate=False):
    """Install the bundled accelerate configuration file for the current user.

    The bundled ``config_files/accelerate/default_config.yaml`` is copied to
    the first location that applies, chosen from the ``HF_HOME``,
    ``LOCALAPPDATA`` and ``USERPROFILE`` environment variables (in that
    order). If the copy cannot be made -- the bundled file is missing, no
    target location can be determined, or a config file already exists at
    the target -- the function either runs ``accelerate config``
    interactively or logs a warning.

    Args:
        run_accelerate: When True, run ``accelerate config`` as the fallback
            instead of only logging a warning.
    """
    #
    # This function was taken and adapted from code written by jstayco
    #
    from pathlib import Path

    def env_var_exists(var_name):
        # An empty value is treated the same as an unset variable.
        return bool(os.environ.get(var_name))

    def fall_back_to_manual_config():
        # Shared fallback for every case where the file cannot be copied.
        if run_accelerate:
            run_cmd('accelerate config')
        else:
            log.warning(
                'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.'
            )

    log.info('Configuring accelerate...')

    source_accelerate_config_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        '..',
        'config_files',
        'accelerate',
        'default_config.yaml',
    )
    log.debug(
        f'Source accelerate config location: {source_accelerate_config_file}'
    )

    if not os.path.exists(source_accelerate_config_file):
        if run_accelerate:
            run_cmd('accelerate config')
        else:
            log.warning(
                f'Could not find the accelerate configuration file in {source_accelerate_config_file}. Please configure accelerate manually by running the option in the menu.'
            )
        # Nothing to copy: stop here rather than letting shutil.copyfile
        # fail on the missing source file further down.
        return

    log.debug(
        f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, "
        f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, "
        f"USERPROFILE: {os.environ.get('USERPROFILE')}"
    )

    # (environment variable, sub-directories below it), in priority order.
    candidate_roots = (
        ('HF_HOME', ('accelerate',)),
        ('LOCALAPPDATA', ('huggingface', 'accelerate')),
        ('USERPROFILE', ('.cache', 'huggingface', 'accelerate')),
    )
    target_config_location = None
    for var_name, sub_dirs in candidate_roots:
        if env_var_exists(var_name):
            target_config_location = Path(
                os.environ[var_name], *sub_dirs, 'default_config.yaml'
            )
            break

    log.debug(f'Target config location: {target_config_location}')

    if target_config_location and not target_config_location.is_file():
        target_config_location.parent.mkdir(parents=True, exist_ok=True)
        log.debug(
            f'Target accelerate config location: {target_config_location}'
        )
        shutil.copyfile(source_accelerate_config_file, target_config_location)
        log.info(f'Copied accelerate config file to: {target_config_location}')
    else:
        # Either no target could be determined or a config file is already
        # present there; an existing file is never overwritten.
        fall_back_to_manual_config()
def check_torch():
    """Log the detected GPU toolkit and Torch details; return Torch's major version.

    The toolkit check only looks for the ``nvidia-smi`` / ``rocminfo``
    executables and logs the outcome. Torch is then imported and its
    version, backend and visible GPUs are logged.

    Returns:
        The major version of the installed Torch as an int, or 0 when Torch
        cannot be imported or inspecting it raises an exception.
    """
    #
    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    #
    # Check for nVidia toolkit or AMD toolkit
    windows_nvidia_smi = os.path.join(
        os.environ.get('SystemRoot') or r'C:\Windows',
        'System32',
        'nvidia-smi.exe',
    )
    if shutil.which('nvidia-smi') is not None or os.path.exists(
        windows_nvidia_smi
    ):
        log.info('nVidia toolkit detected')
    elif shutil.which('rocminfo') is not None or os.path.exists(
        '/opt/rocm/bin/rocminfo'
    ):
        log.info('AMD toolkit detected')
    else:
        log.info('Using CPU-only Torch')

    try:
        import torch

        log.info(f'Torch {torch.__version__}')

        # Check if CUDA is available
        if not torch.cuda.is_available():
            log.warning('Torch reports CUDA not available')
        else:
            if torch.version.cuda:
                # Log nVidia CUDA and cuDNN versions
                log.info(
                    f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
                )
            elif torch.version.hip:
                # Log AMD ROCm HIP version
                log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
            else:
                log.warning('Unknown Torch backend')

            # Log information about detected GPUs
            for device in range(torch.cuda.device_count()):
                properties = torch.cuda.get_device_properties(device)
                log.info(
                    f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(properties.total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {properties.multi_processor_count}'
                )

        # Take everything before the first dot so that a multi-digit major
        # version is not truncated to its first character.
        return int(torch.__version__.split('.')[0])
    except Exception as e:  # best effort: Torch may simply not be installed yet
        log.debug(f'Could not load torch: {e}')
        return 0
# report current version of code
def check_repo_version():  # pylint: disable=unused-argument
    """Log the contents of the ``.release`` file in the current directory.

    When the file does not exist, a debug message is logged instead.
    """
    if not os.path.exists('.release'):
        log.debug('Could not read release...')
        return

    with open('./.release', 'r', encoding='utf8') as release_file:
        release = release_file.read()
    log.info(f'Version: {release}')
# execute git command
def git(arg: str, folder: str = None, ignore: bool = False) -> str:
    """Run a git command through the shell and return its combined output.

    The executable is taken from the ``GIT`` environment variable and
    defaults to ``git``.

    Args:
        arg: Command-line arguments appended after the git executable.
        folder: Working directory for the command; the current directory is
            used when this is None or empty.
        ignore: When True, a non-zero exit status is neither logged nor
            counted in the module-level ``errors`` counter.

    Returns:
        stdout followed by stderr (separated by a newline when both are
        non-empty), decoded as UTF-8 and stripped of surrounding whitespace.
    """
    #
    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    #
    global errors  # pylint: disable=global-statement

    git_cmd = os.environ.get('GIT', "git")
    result = subprocess.run(
        f'"{git_cmd}" {arg}',
        check=False,
        shell=True,
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=folder or '.',
    )

    txt = result.stdout.decode(encoding="utf8", errors="ignore")
    if result.stderr:
        txt += ('\n' if txt else '') + result.stderr.decode(
            encoding="utf8", errors="ignore"
        )
    txt = txt.strip()

    if result.returncode != 0 and not ignore:
        errors += 1
        log.error(f'Error running git: {folder} / {arg}')
        if 'or stash them' in txt:
            log.error('Local changes detected: check log for details...')
        log.debug(f'Git output: {txt}')

    # Callers such as check_python() read the output (e.g. `git --version`),
    # so it must be handed back rather than dropped.
    return txt
def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
    """Run ``<python> -m pip <arg>`` with the current interpreter.

    Args:
        arg: Arguments passed to pip, e.g. ``'install --upgrade rich'``.
        ignore: When True, a non-zero exit status is neither logged nor
            counted in the module-level ``errors`` counter.
        quiet: When True, the "Installing package" info message is skipped.
        show_stdout: When True, pip's output goes straight to the console
            instead of being captured.

    Returns:
        The captured stdout and stderr (separated by a newline when both are
        non-empty), stripped of surrounding whitespace. An empty string when
        ``show_stdout`` is True, because nothing is captured in that mode.
    """
    global errors  # pylint: disable=global-statement

    if not quiet:
        # Strip the pip verbs/flags from the message, then collapse the
        # runs of spaces they leave behind.
        shown = arg
        for token in ('install', '--upgrade', '--no-deps', '--force'):
            shown = shown.replace(token, '')
        shown = ' '.join(shown.split())
        log.info(f'Installing package: {shown}')
    log.debug(f"Running pip: {arg}")

    command = f'"{sys.executable}" -m pip {arg}'
    if show_stdout:
        result = subprocess.run(command, shell=True, check=False, env=os.environ)
        txt = ''
    else:
        result = subprocess.run(
            command,
            shell=True,
            check=False,
            env=os.environ,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        txt = result.stdout.decode(encoding="utf8", errors="ignore")
        if result.stderr:
            txt += ('\n' if txt else '') + result.stderr.decode(
                encoding="utf8", errors="ignore"
            )
        txt = txt.strip()

    # Failures are reported in both modes so that a streamed install cannot
    # fail unnoticed.
    if result.returncode != 0 and not ignore:
        errors += 1
        log.error(f'Error running pip: {arg}')
        if txt:
            log.debug(f'Pip output: {txt}')
    return txt
def _version_satisfies(found: str, required: str, at_least: bool) -> bool:
    """Return True when version *found* meets *required* (``>=`` or ``==``)."""
    try:
        found_parsed = pkg_resources.parse_version(found)
        required_parsed = pkg_resources.parse_version(required)
    except (ValueError, TypeError):
        # Unparsable version text: fall back to plain string comparison.
        return found >= required if at_least else found == required
    if at_least:
        return found_parsed >= required_parsed
    return found_parsed == required_parsed


def installed(package, friendly: str = None):
    """Check whether the given requirement(s) are installed.

    Args:
        package: A pip argument string such as ``'diffusers[torch]==0.10.2'``.
            Extras in square brackets are removed. Unless ``friendly`` is
            given, tokens starting with ``-`` or ``=`` are skipped and only
            the part after the last ``/`` of each token is used.
        friendly: Optional whitespace-separated list of requirement tokens to
            check instead of the ones derived from ``package``.

    Returns:
        True when every token names an installed distribution whose version
        meets its ``>=`` / ``==`` constraint (if any); False otherwise.
    """
    #
    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    #
    # Remove brackets and their contents from the line using regular expressions
    # e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
    package = re.sub(r'\[.*?\]', '', package)

    if friendly:
        pkgs = friendly.split()
    else:
        pkgs = [p for p in package.split() if not p.startswith(('-', '='))]
        # get only package name if installing from URL
        pkgs = [p.split('/')[-1] for p in pkgs]

    try:
        for pkg in pkgs:
            at_least = '>=' in pkg
            if at_least:
                pkg_name, _, pkg_version = pkg.partition('>=')
            elif '==' in pkg:
                pkg_name, _, pkg_version = pkg.partition('==')
            else:
                pkg_name, pkg_version = pkg, None
            pkg_name = pkg_name.strip()
            if pkg_version is not None:
                pkg_version = pkg_version.strip()

            # Try the name as written, then progressively normalized forms.
            spec = None
            for key in (
                pkg_name,
                pkg_name.lower(),
                pkg_name.replace('_', '-'),
                pkg_name.lower().replace('_', '-'),
            ):
                spec = pkg_resources.working_set.by_key.get(key, None)
                if spec is not None:
                    break

            if spec is None:
                log.debug(f'Package version not found: {pkg_name}')
                return False

            # Read the version from the distribution that was actually found.
            version = spec.version
            log.debug(f'Package version found: {pkg_name} {version}')
            if pkg_version is not None and not _version_satisfies(
                version, pkg_version, at_least
            ):
                log.warning(f'Package wrong version: {pkg_name} {version} required {pkg_version}')
                return False
        return True
    except ModuleNotFoundError:
        log.debug(f'Package not installed: {pkgs}')
        return False
# install package using pip if not already installed
def install(
    package,
    friendly: str = None,
    ignore: bool = False,
    reinstall: bool = False,
    show_stdout: bool = False,
):
    """Install ``package`` with pip unless it is already installed.

    Args:
        package: Requirement text handed to ``pip install --upgrade``;
            anything after a ``#`` is dropped first.
        friendly: Optional requirement tokens passed to ``installed()`` for
            the presence check.
        ignore: Forwarded to ``pip()``.
        reinstall: When True, skip the presence check and always run pip.
        show_stdout: Forwarded to ``pip()``.
    """
    #
    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    #
    global quick_allowed  # pylint: disable=global-statement

    # Remove anything after '#' in the package variable
    requirement = package.split('#')[0].strip()

    if reinstall:
        quick_allowed = False

    if not reinstall and installed(requirement, friendly):
        return
    pip(f'install --upgrade {requirement}', ignore=ignore, show_stdout=show_stdout)
def process_requirements_line(line, show_stdout: bool = False):
    """Install the requirement described by one requirements-file line.

    Args:
        line: A single requirement line, possibly carrying extras in square
            brackets and a trailing ``# comment``.
        show_stdout: Forwarded to ``install()``.
    """
    # Drop a trailing comment before deriving the name used for the
    # "already installed?" check; otherwise the comment words are looked up
    # as if they were package names.
    requirement = line.split('#')[0].strip()
    # Remove brackets and their contents from the line using regular expressions
    # e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
    package_name = re.sub(r'\[.*?\]', '', requirement)
    install(line, package_name, show_stdout=show_stdout)
def install_requirements(requirements_file, check_no_verify_flag=False, show_stdout: bool = False):
    """Install every requirement listed in a requirements file.

    Blank lines and comment lines are skipped. A line starting with ``-r``
    names another requirements file, which is processed recursively.

    Args:
        requirements_file: Path of the requirements file to read.
        check_no_verify_flag: When True, lines containing ``no_verify`` are
            skipped as well.
        show_stdout: Forwarded to ``process_requirements_line()``.
    """
    if check_no_verify_flag:
        log.info(f'Verifying modules installation status from {requirements_file}...')
    else:
        log.info(f'Installing modules from {requirements_file}...')

    with open(requirements_file, 'r', encoding='utf8') as f:
        lines = [raw_line.strip() for raw_line in f]

    base_dir = os.path.dirname(os.path.abspath(requirements_file))

    # Iterate over each line and install the requirements
    for line in lines:
        # Skip empty lines and comments (checked after stripping, so indented
        # comments are skipped too).
        if not line or line.startswith('#'):
            continue
        if check_no_verify_flag and 'no_verify' in line:
            continue

        # Check if the line starts with '-r' to include another requirements file
        if line.startswith('-r'):
            included_file = line[2:].strip()
            # Prefer the path as written (relative to the current directory);
            # if that does not exist, try it relative to the including file.
            if not os.path.exists(included_file):
                relative_to_parent = os.path.join(base_dir, included_file)
                if os.path.exists(relative_to_parent):
                    included_file = relative_to_parent
            # Expand the included requirements file recursively
            install_requirements(
                included_file,
                check_no_verify_flag=check_no_verify_flag,
                show_stdout=show_stdout,
            )
        else:
            process_requirements_line(line, show_stdout=show_stdout)
def ensure_base_requirements():
    """Install ``rich`` when it cannot be imported."""
    try:
        __import__('rich')
    except ImportError:
        install('--upgrade rich', 'rich')
def run_cmd(run_cmd):
    """Run a shell command, printing a message if it exits with an error.

    The command inherits the current environment and console. A non-zero
    exit status is reported on stdout and is not re-raised.

    Args:
        run_cmd: The command line to execute through the shell.
    """
    try:
        # check=True makes a non-zero exit status raise, so the handler
        # below is actually reached.
        subprocess.run(run_cmd, shell=True, check=True, env=os.environ)
    except subprocess.CalledProcessError as e:
        print(f'Error occurred while running command: {run_cmd}')
        print(f'Error: {e}')
# check python version
def check_python(ignore=True, skip_git=False):
    """Verify the Python version and, unless skipped, that git is available.

    Args:
        ignore: When False, an unsupported Python version or a missing git
            executable terminates the process with exit status 1. When True,
            the problem is only logged.
        skip_git: When True, the git availability check is skipped.
    """
    #
    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    #
    supported_minors = [9, 10]
    log.info(f'Python {platform.python_version()} on {platform.system()}')

    version_supported = (
        sys.version_info.major == 3
        and sys.version_info.minor in supported_minors
    )
    if not version_supported:
        log.error(
            f'Incompatible Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} required 3.{supported_minors}'
        )
        if not ignore:
            sys.exit(1)

    if skip_git:
        return

    git_cmd = os.environ.get('GIT', 'git')
    if shutil.which(git_cmd) is None:
        log.error('Git not found')
        if not ignore:
            sys.exit(1)
        return

    # git() may hand back None; treat that as empty output instead of
    # failing on the .replace() call below.
    git_version = git('--version', folder=None, ignore=False) or ''
    log.debug(f'Git {git_version.replace("git version", "").strip()}')
def delete_file(file_path):
    """Remove ``file_path``; do nothing when it does not exist.

    Args:
        file_path: Path of the file to remove.
    """
    # Attempt the removal directly rather than checking first, so a file
    # that disappears between the check and the removal cannot cause an error.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass
def write_to_file(file_path, content):
    """Write ``content`` to ``file_path``, replacing any existing contents.

    An I/O failure is reported on stdout and is not re-raised.

    Args:
        file_path: Destination path, opened in text write mode.
        content: Text to write.
    """
    try:
        with open(file_path, 'w') as handle:
            handle.write(content)
    except OSError as err:  # IOError is an alias of OSError
        print(f'Error occurred while writing to file: {file_path}')
        print(f'Error: {err}')
def clear_screen():
    """Clear the terminal using the command for the current operating system."""
    # 'cls' on Windows (os.name == 'nt'); 'clear' everywhere else.
    clear_command = 'cls' if os.name == 'nt' else 'clear'
    os.system(clear_command)