import os
import subprocess

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from Helpers import name_formatter, weights_dir, capture_message


class Training:
    """Run the training script for one subject, then post-process its output.

    All work happens synchronously in ``__init__``:
    1. ``training.sh`` is run to completion.
    2. The sample images it left under ``OUTPUT_DIR`` are rendered into a grid.
    3. ``convert_du_to_sd.sh`` is run on the output.

    Paths and the model name are read from the environment variables
    ``APP_DIR``, ``MODEL_NAME``, ``INSTANCE_DIR``, ``OUTPUT_DIR`` and
    ``CLASS_DIR``.
    """

    def __init__(self, data):
        """Read configuration from the environment and run the full pipeline.

        :param data: mapping with at least a ``'name'`` key; the name is passed
            through ``name_formatter`` and embedded in the instance prompt.
        :raises subprocess.CalledProcessError: if the training or conversion
            script exits with a non-zero status.
        """
        self.app_path = os.getenv('APP_DIR')
        self.model_name = os.getenv('MODEL_NAME')
        self.instance_dir = os.getenv('INSTANCE_DIR')
        self.output_dir = os.getenv('OUTPUT_DIR')
        self.class_dir = os.getenv('CLASS_DIR')
        self.instance_prompt = f"photo of {name_formatter(data['name'])} person"
        self.class_prompt = "photo of a person"

        # Positional argument order is the interface expected by training.sh.
        self._run_script(
            'training.sh',
            self.model_name,
            self.instance_dir,
            self.output_dir,
            self.class_dir,
            self.instance_prompt,
            self.class_prompt,
        )
        self.generate_grid_image()
        self.convert_diffusers_to_original_stable_diffusion()

    def _run_script(self, script_name, *script_args):
        """Make ``APP_DIR/script_name`` executable and run it to completion.

        The script runs with ``APP_DIR`` as its working directory and receives
        ``script_args`` as positional arguments.

        :raises subprocess.CalledProcessError: if the script exits non-zero,
            so later pipeline steps never run against missing/partial output.
        """
        # os.path.join works whether or not APP_DIR ends with a separator.
        script_path = os.path.join(self.app_path, script_name)
        # A chmod failure is deliberately not fatal (the script may already be
        # executable, e.g. on a read-only mount); if it really is not
        # executable, the run below raises instead.
        subprocess.run(['chmod', '+x', script_path], cwd=self.app_path)
        subprocess.run([script_path, *script_args], cwd=self.app_path, check=True)

    def generate_grid_image(self):
        """Render the sample images under ``output_dir`` into one grid PNG.

        Rows and columns:
        - Each numerically named subfolder of ``output_dir`` (except ``"0"``)
          that contains a ``samples`` directory becomes one row, in ascending
          numeric order.
        - The files inside ``samples`` become the columns, in sorted filename
          order.

        The grid is written to ``grid.png`` in the current working directory
        at 72 dpi. If no sample images are found, nothing is written.
        """
        capture_message('Training: Generate grid image')

        # Skip non-numeric entries (stray files/dirs) instead of crashing in
        # int(). Folder "0" is excluded as before; presumably it holds the
        # initial/untrained snapshot -- TODO confirm.
        folders = sorted(
            (
                name
                for name in os.listdir(self.output_dir)
                if name.isdigit()
                and name != "0"
                and os.path.isdir(os.path.join(self.output_dir, name, "samples"))
            ),
            key=int,
        )
        # os.listdir order is arbitrary; sort so columns line up across rows.
        rows = [
            (folder, sorted(os.listdir(os.path.join(self.output_dir, folder, "samples"))))
            for folder in folders
        ]
        row = len(rows)
        # Size the grid for the widest row so no image index falls off the end.
        col = max((len(images) for _, images in rows), default=0)
        if row == 0 or col == 0:
            capture_message('Training: No sample images found, skipping grid image')
            return

        scale = 4
        # squeeze=False always yields a 2-D array of Axes, so axes[i, j] is
        # valid even for a single row and/or a single column.
        fig, axes = plt.subplots(
            row,
            col,
            figsize=(col * scale, row * scale),
            gridspec_kw={'hspace': 0, 'wspace': 0},
            squeeze=False,
        )
        try:
            # Hide every cell's axis up front; rows with fewer images stay blank.
            for ax in axes.flat:
                ax.axis('off')
            for i, (folder, images) in enumerate(rows):
                image_folder = os.path.join(self.output_dir, folder, "samples")
                for j, image in enumerate(images):
                    ax = axes[i, j]
                    if i == 0:
                        ax.set_title(f"Image {j}")
                    if j == 0:
                        # Row label drawn just left of the row's first cell.
                        ax.text(-0.1, 0.5, folder, rotation=0, va='center',
                                ha='center', transform=ax.transAxes)
                    img = mpimg.imread(os.path.join(image_folder, image))
                    ax.imshow(img, cmap='gray')
            fig.tight_layout()
            fig.savefig('grid.png', dpi=72)
        finally:
            # pyplot keeps figures alive until closed; release it explicitly.
            plt.close(fig)

    def convert_diffusers_to_original_stable_diffusion(self):
        """Run ``convert_du_to_sd.sh`` on the training output.

        The script receives two positional arguments:
        ``weights_dir()/model.ckpt`` (presumably the destination checkpoint
        path -- confirm against the script) and ``output_dir``.

        :raises subprocess.CalledProcessError: if the script exits non-zero.
        """
        capture_message('Training: Convert diffusers to original stable diffusion')
        self._run_script(
            'convert_du_to_sd.sh',
            os.path.join(weights_dir(), 'model.ckpt'),
            self.output_dir,
        )