toolkit / zsh /train_functions.zsh
k4d3's picture
update every stupid script
c2cc76d
raw
history blame
9.39 kB
#!/bin/zsh
# Functions for sd-scripts training scripts
# Executes a training script located at the specified path with the provided arguments.
#
# Parameters:
# - script_path: The path to the training script to be executed.
# - args_array: An array of arguments to be passed to the training script.
#
# Behavior:
# - Changes the current directory to the directory of the script.
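# - Always prints the working directory and the arguments.
# - If the DEBUG environment variable is set, performs a dry run: the script is not executed.
# - Otherwise executes the script with `python`, tees its output to $OUTPUT_DIR/$NAME/sdscripts.log, and captures the exit code.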
# - Returns to the original directory before exiting.
#
# Returns:
# - The exit code of the executed script.
run_training_script() {
  # Runs the python script at $1 from inside its own directory, forwarding the
  # remaining arguments, and tees stdout to $OUTPUT_DIR/$NAME/sdscripts.log.
  # With DEBUG set, nothing is executed (dry run) and 0 is returned.
  # Returns the exit status of python itself, or 1 if the script is missing
  # or its directory cannot be entered.
  local script_path="$1"
  local -a args_array
  args_array=("${@:2}") # every argument after the script path
  local original_dir script_dir script_name log_file arg exit_code

  original_dir=$(pwd)
  script_dir=$(dirname -- "$script_path")
  script_name=$(basename -- "$script_path")
  log_file="$OUTPUT_DIR/$NAME/sdscripts.log"

  # Validate BEFORE changing directory so that an early return can never
  # leave the caller stranded inside the script's directory.
  if [[ ! -f "$script_path" ]]; then
    echo "\e[31mERROR\e[0m: Script not found: $script_name"
    return 1
  fi
  cd "$script_dir" || return 1

  # tee cannot create missing parent directories; make sure the log can be written.
  mkdir -p -- "${log_file%/*}"

  echo "Working directory: $(pwd)\nRunning $script_name arguments:"
  for arg in "${args_array[@]}"; do
    # printf keeps backslashes in arguments literal when displaying them.
    printf ' %s\n' "$arg"
  done

  if [[ -n "$DEBUG" ]]; then
    echo "This was a dry run, exiting." | tee "$log_file"
    exit_code=0
  else
    # A pipeline normally reports the status of its LAST command (tee), which
    # would hide a failed training run. pipefail makes the pipeline report the
    # failing stage; the subshell keeps the option from leaking to the caller.
    (
      set -o pipefail
      python "$script_name" "${args_array[@]}" | tee "$log_file"
    )
    exit_code=$?
  fi

  # Return to original directory
  cd "$original_dir"
  return $exit_code
}
# Sets up default variables for training
#
# Parameters:
# - name: The name of the training run
#
# Returns:
# - Sets the following global variables:
# - DATASET_NAME: Base name without version/steps suffix
# - TRAINING_DIR: Directory containing training data
# - STEPS: Number of training steps
# - OUTPUT_DIR: Base output directory
setup_training_vars() {
  # Derives the run's global settings from the run name ($1), prints a summary
  # and validates the directories.
  # NOTE: validation failures call `exit 1`, i.e. they terminate the calling
  # script, not just this function.
  local name="$1"

  # Declare globals that will be used by the main script
  typeset -g DATASET_NAME="${name%-*}"                                         # strip the last "-suffix"
  typeset -g TRAINING_DIR="${TRAINING_DIR:-${HOME}/datasets/${DATASET_NAME}}" # caller may preset it
  typeset -g STEPS="${STEPS:-${name##*[^0-9]}}"                                # trailing digits of the name
  typeset -g OUTPUT_DIR="${HOME}/output_dir"

  # -m lets realpath resolve paths whose components do not exist yet: the
  # run's output directory must NOT exist at this point, and on a first run
  # $OUTPUT_DIR itself may be missing too. Without -m realpath prints an error
  # and yields an empty string in that case.
  local training_rel output_rel
  training_rel=$(realpath -m --relative-to=. -- "$TRAINING_DIR")
  output_rel=$(realpath -m --relative-to=. -- "$OUTPUT_DIR/$name")
  echo "\e[35moutput_name\e[0m: $name, \e[35msteps\e[0m: $STEPS, \e[35mtraining_dir\e[0m: $training_rel, \e[35moutput_dir\e[0m: $output_rel"
  echo "\e[35mconda_env\e[0m: $CONDA_PREFIX"

  # ===== Validation =====
  if [[ ! -d "$TRAINING_DIR" ]]; then
    echo "ERROR: Training directory not found"
    exit 1
  fi
  if [[ -d "$OUTPUT_DIR/$name" ]]; then
    # Refuse to overwrite the results of a previous run with the same name.
    echo "ERROR: Output directory already exists: $OUTPUT_DIR/$name"
    exit 1
  fi
}
# Sets up and activates a specified Conda environment.
#
# Parameters:
# - env_name: The name of the Conda environment to activate.
# - conda_path: (Optional) The path to the Conda installation. Defaults to $HOME/miniconda3.
#
# Behavior:
# - Checks if the environment name is provided and if the Conda installation exists.
# - Initializes Conda for the current shell session.
# - Activates the specified Conda environment and verifies its activation.
#
# Returns:
# - 0 on success, or 1 if any error occurs (e.g., missing environment name, Conda installation not found, activation failure).
setup_conda_env() {
  # $1: name of the environment to activate (mandatory)
  # $2: conda installation root, defaults to $HOME/miniconda3
  local env_name="$1"
  if [[ -z "$env_name" ]]; then
    echo "\e[31mERROR\e[0m: Environment name required"
    return 1
  fi

  local conda_path="${2:-$HOME/miniconda3}"
  if [[ ! -d "$conda_path" ]]; then
    echo "\e[31mERROR\e[0m: Conda installation not found at $conda_path"
    return 1
  fi

  # Ask conda for its zsh hook and evaluate it in this shell; that is what
  # defines the `conda` shell function used below.
  if ! { __conda_setup="$("$conda_path/bin/conda" 'shell.zsh' 'hook' 2>/dev/null)" && eval "$__conda_setup"; }; then
    echo "\e[31mERROR\e[0m: Failed to initialize conda environment"
    return 1
  fi
  unset __conda_setup

  # Switch to the requested environment
  if ! conda activate "$env_name"; then
    echo "\e[31mERROR\e[0m: Failed to activate conda environment: $env_name"
    return 1
  fi
  echo "Conda environment: $CONDA_PREFIX"

  # Cross-check: the name must show up at the start of a line in `conda env list`
  conda env list | grep -q "^${env_name} " && return 0
  echo "ERROR: Environment $env_name not found"
  return 1
}
# Stores the commit hashes of specified Git repositories and copies the script to an output directory.
#
# Parameters:
# - repo_path...: One or more paths to Git repositories (every positional argument is treated as one).
# - The output directory is not an argument: it is taken from the globals as $OUTPUT_DIR/$NAME.
#
# Behavior:
# - Creates the output directory if it does not exist.
# - Copies the current script to the output directory.
# - Iterates over each repository path, checking if it is a valid Git repository.
# - Retrieves the current commit SHA for each repository and writes it to an output file.
# - Generates a SHA-1 hash of the script and appends it to the output file.
#
# Returns:
# - 0 on success, or 1 if any error occurs during the process (e.g., Git command failure, not a Git repository).
store_commits_hashes() {
  # Records provenance for a training run in $OUTPUT_DIR/$NAME/repos.git:
  #   - branch and commit of every repository path given as an argument,
  #   - SHA-1 of the running script ($ZSH_SCRIPT), which is also copied,
  #   - SHA-1 of $TRAINING_DIR/config.toml and sample-prompts.txt, also copied.
  # Prints a human-readable summary. Returns 0 only if every step succeeded,
  # 1 otherwise (the remaining steps are still attempted).
  local output_dir="$OUTPUT_DIR/$NAME"
  local output_file="$output_dir/repos.git"
  local nl=$'\n'
  local summary=""
  local res=0
  # Declared up front: `local var=$(cmd)` returns the status of `local`, not of
  # cmd, which would make every failure check below always succeed.
  local repo_path repo_name commit_sha branch_name
  local script_path script_sha src label file_sha

  # Create the output directory if needed, then create/truncate the output file
  mkdir -p -- "$output_dir" || return 1
  : >"$output_file" || return 1

  for repo_path in "$@"; do
    repo_name=$(basename -- "$repo_path")
    # -e rather than -d: in linked worktrees and submodules .git is a file.
    if [[ ! -e "$repo_path/.git" ]]; then
      printf '%s: Not a git repository\n' "$repo_path" >>"$output_file"
      summary+="⚠️ $repo_name: Not a git repository $repo_path${nl}"
      res=1
    elif ! commit_sha=$(git -C "$repo_path" rev-parse HEAD 2>/dev/null); then
      printf '%s: Git command failed\n' "$repo_path" >>"$output_file"
      summary+="⚠️ $repo_name: $repo_path (Git command failed) ${nl}"
      res=1
    elif branch_name=$(git -C "$repo_path" rev-parse --abbrev-ref HEAD 2>/dev/null); then
      printf '%s: (%s) %s\n' "$repo_path" "$branch_name" "$commit_sha" >>"$output_file"
      summary+="✓ $repo_name: $repo_path ${commit_sha:0:8} ($branch_name)${nl}"
    else
      printf '%s: %s (Failed to get branch)\n' "$repo_path" "$commit_sha" >>"$output_file"
      summary+="⚠️ $repo_name: $repo_path ${commit_sha:0:8} (Failed to get branch)${nl}"
      res=1
    fi
  done

  # Copy the running script next to the results and record its hash.
  script_path=$(readlink -f -- "$ZSH_SCRIPT" 2>/dev/null)
  if [[ -f "$script_path" ]] \
    && cp -- "$script_path" "$output_dir/${script_path##*/}" \
    && script_sha=$(sha1sum -- "$script_path"); then
    script_sha=${script_sha%% *} # keep only the digest, drop the file name
    [[ -n "$DEBUG" ]] && echo "Copied $script_path to $output_dir"
    printf '%s: %s\n' "$script_path" "$script_sha" >>"$output_file"
    summary+="✓ Training script: $ZSH_SCRIPT ${script_sha:0:8}${nl}"
  else
    printf '%s: Failed to copy or hash\n' "${script_path:-$ZSH_SCRIPT}" >>"$output_file"
    summary+="⚠️ Training script: $ZSH_SCRIPT (Failed to copy or hash)${nl}"
    res=1
  fi

  # Hash and copy the training config and the sample prompts.
  for src in "$TRAINING_DIR/config.toml" "$TRAINING_DIR/sample-prompts.txt"; do
    if [[ "${src##*/}" == "config.toml" ]]; then
      label="Training config"
    else
      label="Training prompts"
    fi
    if [[ -f "$src" ]] \
      && file_sha=$(sha1sum -- "$src") \
      && cp -- "$src" "$output_dir/${src##*/}"; then
      file_sha=${file_sha%% *}
      printf '%s: %s\n' "$src" "$file_sha" >>"$output_file"
      summary+="✓ $label: $src ${file_sha:0:8}${nl}"
    else
      printf '%s: Failed to copy or hash\n' "$src" >>"$output_file"
      summary+="⚠️ $label: $src (Failed to copy or hash)${nl}"
      res=1
    fi
  done

  printf '%s\n' "$summary"
  return $res
}
# Prints the directory two levels above the file the `lycoris` Python package
# is loaded from (the parent of the package directory), as resolved by the
# `python` currently on PATH.
#
# Returns:
# - The exit status of the python interpreter.
get_lycoris_repo() {
  # Python snippet: locate the package spec without importing the package.
  local locate_pkg="
import importlib.util
import pathlib
spec = importlib.util.find_spec('lycoris')
print(pathlib.Path(spec.origin).parent.parent)
"
  python -c "$locate_pkg"
}
# Deletes the run's output directory ($OUTPUT_DIR/$NAME) when it holds nothing
# worth keeping, i.e. no *.png samples, no *.safetensors models and no .git
# directories anywhere below it.
#
# Globals read:
# - OUTPUT_DIR, NAME: together they name the directory to inspect.
# - DEBUG: when non-empty, prints diagnostic messages.
# - NO_CLEAN: when non-empty, the directory is never deleted.
#
# Returns:
# - 1 if OUTPUT_DIR or NAME is empty (nothing is touched), otherwise 0 on the
#   paths that return explicitly.
cleanup_empty_output() {
  local target="$OUTPUT_DIR/$NAME"
  [[ -n "$DEBUG" ]] && echo "\e[33mDEBUG\e[0m: Cleanup triggered for $target"

  # Safety net: with an empty OUTPUT_DIR or NAME the target collapses to "/",
  # "/<name>" or the whole output root, and `rm -rf` below would wipe it.
  if [[ -z "$OUTPUT_DIR" || -z "$NAME" ]]; then
    echo "\e[31mERROR\e[0m: OUTPUT_DIR and NAME must both be set, refusing to clean up '$target'"
    return 1
  fi

  # Check if directory exists first
  if [[ ! -d "$target" ]]; then
    [[ -n "$DEBUG" ]] && echo "\e[33mDEBUG\e[0m: Output directory doesn't exist, skipping cleanup"
    return 0
  fi

  # Recursive globs; the (N) qualifier makes a pattern without matches expand
  # to nothing instead of raising a "no matches found" error.
  local -a samples models git_repos
  samples=("$target"/**/*.png(N))
  models=("$target"/**/*.safetensors(N))
  git_repos=("$target"/**/.git(N))

  if [[ -n "$DEBUG" ]]; then
    echo "\e[33mDEBUG\e[0m: Found ${#git_repos[@]} git repositories"
    echo "\e[33mDEBUG\e[0m: Found ${#samples[@]} sample files"
    echo "\e[33mDEBUG\e[0m: Found ${#models[@]} model files"
  fi

  if [[ ${#samples[@]} -eq 0 && ${#models[@]} -eq 0 && ${#git_repos[@]} -eq 0 ]]; then
    if [[ -z "$NO_CLEAN" ]]; then
      echo "No samples or model files found, deleting empty output directory"
      rm -rf -- "$target"
    else
      echo "NO_CLEAN set, not deleting directory"
    fi
  else
    echo "Directory contains files, keeping it"
  fi
}