# anycoder-acd412b4 / utils.py
# Source: Kaiden423 — "Update utils.py from anycoder" (commit f2c43cd, verified)
"""
Utility functions for the Z-Image Turbo LoRA Generator
"""
import json
import os
import shutil
import time
from datetime import datetime
from pathlib import Path
def get_output_dir():
    """Return the directory for generated LoRAs, creating it if absent."""
    path = Path("output_loras")
    # exist_ok makes repeated calls idempotent.
    path.mkdir(exist_ok=True)
    return path
def validate_images(images):
    """
    Validate that images are provided and exist on disk.

    Args:
        images: Single image path or list of image paths.

    Returns:
        tuple: (is_valid, image_list, error_message); ``error_message``
        is None on success, and ``image_list`` is always a list.
    """
    if not images:
        return False, [], "No images provided"
    # Normalize a single path to a one-element list so callers can pass either.
    image_list = images if isinstance(images, list) else [images]
    # NOTE: the previous `len(image_list) == 0` check was unreachable — an
    # empty list is already rejected by the falsiness test above, and a
    # truthy non-list input always wraps to a one-element list — so it was
    # removed as dead code.
    for img_path in image_list:
        # str() tolerates Path objects and other path-like inputs.
        if not os.path.exists(str(img_path)):
            return False, [], f"Image not found: {img_path}"
    return True, image_list, None
def generate_training_metadata(
    project_name,
    trigger_word,
    num_images,
    training_steps,
    batch_size,
    learning_rate,
    resolution,
    rank,
    alpha,
):
    """
    Build the metadata dictionary describing a LoRA training run.

    Args:
        project_name: Name of the LoRA project.
        trigger_word: Word to activate the LoRA.
        num_images: Number of training images.
        training_steps: Total training steps.
        batch_size: Batch size for training.
        learning_rate: Learning rate.
        resolution: Image resolution.
        rank: LoRA rank dimension.
        alpha: LoRA alpha value.

    Returns:
        dict: Metadata dictionary.
    """
    # Hyperparameters grouped under a single key for downstream consumers.
    training_config = {
        "steps": training_steps,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "resolution": resolution,
        "rank": rank,
        "alpha": alpha,
    }
    # Fixed descriptors of the artifact this pipeline produces.
    model_info = {
        "type": "LoRA",
        "format": "safetensors",
        "compatibility": "Z-Image Turbo",
    }
    return {
        "project_name": project_name,
        "trigger_word": trigger_word,
        "num_images": num_images,
        "training_config": training_config,
        "model_info": model_info,
        "created_at": datetime.now().isoformat(),
    }
def format_log_message(step, message):
    """
    Prefix *message* with a wall-clock timestamp and the training step.

    Args:
        step: Current training step.
        message: Log message.

    Returns:
        str: Formatted message, e.g. "[12:34:56] Step 10: saving".
    """
    clock = datetime.now().strftime("%H:%M:%S")
    return "[{}] Step {}: {}".format(clock, step, message)
def cleanup_old_outputs(max_age_hours=24):
    """
    Delete output files and directories older than *max_age_hours*.

    Cleanup is best-effort: an entry that disappears or becomes
    inaccessible mid-sweep is skipped instead of aborting the whole pass.

    Args:
        max_age_hours: Maximum age in hours for entries to keep.
    """
    # Fix: the original buried `import time` at function level and
    # `import shutil` inside the loop body; both are now module-level.
    output_dir = get_output_dir()
    cutoff = time.time() - max_age_hours * 3600
    for item in output_dir.iterdir():
        try:
            # mtime applies to both files and directories.
            if item.stat().st_mtime >= cutoff:
                continue
            if item.is_file():
                item.unlink()
            elif item.is_dir():
                shutil.rmtree(item)
        except OSError:
            # Race with another process (entry removed/locked between
            # iterdir() and stat()/unlink()); skip and keep sweeping.
            continue
# Example utility for real implementation (not used in demo)
def create_training_command(
    images_dir,
    output_dir,
    trigger_word,
    rank=16,
    alpha=16,
    learning_rate=1e-4,
    steps=500,
    batch_size=1,
    resolution=512,
):
    """
    Create a Kohya LoRA training command (for reference).

    This would be used in a real implementation with actual LoRA training.
    """
    # Flag/value pairs, in the order train_network.py expects them.
    valued_options = (
        ("--pretrained_model", "v1-5-pruned.safetensors"),
        ("--train_data_dir", str(images_dir)),
        ("--output_dir", str(output_dir)),
        ("--output_name", "lora"),
        ("--network_module", "networks.lora"),
        ("--network_dim", str(rank)),
        ("--network_alpha", str(alpha)),
        ("--train_batch_size", str(batch_size)),
        ("--learning_rate", str(learning_rate)),
        ("--max_train_steps", str(steps)),
        ("--resolution", f"{resolution},{resolution}"),
        ("--clip_skip", "2"),
    )
    command = ["python", "train_network.py"]
    for flag, value in valued_options:
        command.extend((flag, value))
    # Bucketing is a bare switch; the caption options follow it.
    command.append("--enable_bucket")
    command.extend(("--caption_column", "text"))
    command.extend(("--shuffle_caption", "--weighted_captions"))
    return command
def _demo():
    """Run a quick smoke test of the utilities from the command line."""
    print("Testing utilities...")
    print(f"Output directory: {get_output_dir()}")
    sample_metadata = generate_training_metadata(
        project_name="test_lora",
        trigger_word="test_style",
        num_images=10,
        training_steps=500,
        batch_size=1,
        learning_rate=1e-4,
        resolution=512,
        rank=16,
        alpha=16,
    )
    print(f"Metadata: {json.dumps(sample_metadata, indent=2)}")
    print("Utilities test complete!")


if __name__ == "__main__":
    _demo()