# Captured from a Hugging Face file page (kept as comments so the module parses):
# author: shaokun
# commit 0321bd1: "removed deprecation"
# file size: 4.53 kB; malware scan: no virus detected
import os
import pathlib
import shutil
import subprocess
import sys
import zipfile

import gradio as gr
import pandas as pd
import PIL.Image
import tensorflow as tf
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad ``image`` with black borders so that it becomes square.

    The original content is centred along the shorter axis. An image that is
    already square is returned unchanged (the same object, not a copy).

    Args:
        image: Source image; its mode is preserved in the result.

    Returns:
        A square image of side ``max(width, height)`` in the same mode as
        ``image``.
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    # Pass the colour by name so Pillow resolves "black" for the image's own
    # mode. A hard-coded (0, 0, 0) tuple raises TypeError for single-band
    # modes such as "L" or "1" and for two-band "LA"; for RGB/RGBA/P the
    # resolved colour is the same opaque black as before.
    padded = PIL.Image.new(image.mode, (side, side), 'black')
    # Exactly one of these offsets is non-zero: it centres the short axis.
    padded.paste(image, ((side - width) // 2, (side - height) // 2))
    return padded
class ModelTrainer:
def __init__(self):
self.training_pictures = []
self.training_model = None
def unzip_file(self, zip_file_path):
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
extracted_path = zip_file_path.replace('.zip', '')
zip_ref.extractall(extracted_path)
file_names = zip_ref.namelist()
for file_name in file_names:
if file_name.endswith(('.jpeg', '.jpg', '.png')):
self.training_pictures.append(f'{extracted_path}/{file_name}')
def train(self, pretrained_model_name_or_path: str, instance_images: list | None):
output_model_name = 'a-xyz-model'
resolution = 512
repo_dir = pathlib.Path(__file__).parent
subdirs = ['train-instance', 'train-class', 'experiments']
dir_paths = []
for subdir in subdirs:
dir_path = repo_dir / subdir / output_model_name
dir_paths.append(dir_path)
shutil.rmtree(dir_path, ignore_errors=True)
os.makedirs(dir_path, exist_ok=True)
instance_data_dir, class_data_dir, output_dir = dir_paths
for i, temp_path in enumerate(instance_images):
image = PIL.Image.open(temp_path.name)
image = pad_image(image)
image = image.resize((resolution, resolution))
image = image.convert('RGB')
out_path = instance_data_dir / f'{i:03d}.jpg'
image.save(out_path, format='JPEG', quality=100)
command = [
'python', '-u',
'train_dreambooth_cloneofsimo_lora.py',
'--pretrained_model_name_or_path', pretrained_model_name_or_path,
'--instance_data_dir', instance_data_dir,
'--class_data_dir', class_data_dir,
'--resolution', '768',
'--output_dir', output_dir,
'--instance_prompt', 'a photo of a pwsm dog',
'--with_prior_preservation',
'--class_prompt', 'a dog',
'--prior_loss_weight', '1.0',
'--num_class_images', '100',
'--learning_rate', '0.0004',
'--train_batch_size', '1',
'--sample_batch_size', '1',
'--max_train_steps', '400',
'--gradient_accumulation_steps', '1',
'--gradient_checkpointing',
'--train_text_encoder',
'--learning_rate_text', '5e-6',
'--save_steps', '100',
'--seed', '1337',
'--lr_scheduler', 'constant',
'--lr_warmup_steps', '0'
]
result = subprocess.run(command)
return result
def generate_picture(self, row):
num_of_training_steps, learning_rate, checkpoint_steps, abc = row
return f'Picture generated for num_of_training_steps: {num_of_training_steps}, learning_rate: {learning_rate}, checkpoint_steps: {checkpoint_steps}'
def generate_pictures(self, csv_input):
csv = pd.read_csv(csv_input.name)
result = []
for index, row in csv.iterrows():
result.append(self.generate_picture(row))
return "\n".join(str(item) for item in result)
loader = ModelTrainer()

# gr.Group is used for the panels: gr.Box was removed in Gradio 4.0, while
# gr.Group exists in both the 3.x and 4.x releases.
with gr.Blocks() as demo:
    # Training panel: upload instance images, choose a base model, start training.
    with gr.Group():
        instance_images = gr.Files(label='Instance images')
        pretrained_model_name_or_path = gr.Textbox(lines=1, label='pretrained_model_name_or_path', placeholder='stabilityai/stable-diffusion-2-1')
        output_message = gr.Markdown()
        train_button = gr.Button('Train')
        train_button.click(fn=loader.train, inputs=[pretrained_model_name_or_path, instance_images], outputs=[output_message])
    # Batch panel: one summary line per row of an uploaded CSV file.
    with gr.Group():
        csv_input = gr.File(label='CSV File')
        output_message2 = gr.Markdown()
        generate_button = gr.Button('Generate Pictures from CSV')
        generate_button.click(fn=loader.generate_pictures, inputs=[csv_input], outputs=[output_message2])

# Launch only when run as a script, so importing this module (e.g. Gradio's
# reload mode or tests) does not start a server as a side effect.
if __name__ == '__main__':
    demo.launch()