# Spaces: Runtime error
import os
import pathlib
import shutil
import subprocess
import sys
import zipfile

import gradio as gr
import pandas as pd
import PIL.Image
import tensorflow as tf
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad ``image`` with black borders so that it becomes square.

    The longer side is kept as-is and the shorter side is padded equally on
    both ends (when the difference is odd, the extra pixel goes to the
    bottom/right), so the original content stays centred.

    Args:
        image: Any Pillow image.

    Returns:
        ``image`` itself if it is already square, otherwise a new square image
        of the same mode containing ``image`` centred on a black background.
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    # The colour name "black" lets Pillow choose a fill value that matches the
    # image mode (an int for single-band modes such as "L"/"1"/"I"/"F", an
    # opaque-black tuple for RGB/RGBA).  A literal (0, 0, 0) is rejected by
    # Pillow for single-band modes, which crashed on grayscale inputs.
    padded = PIL.Image.new(image.mode, (side, side), 'black')
    padded.paste(image, ((side - width) // 2, (side - height) // 2))
    return padded
class ModelTrainer: | |
def __init__(self): | |
self.training_pictures = [] | |
self.training_model = None | |
def unzip_file(self, zip_file_path): | |
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: | |
extracted_path = zip_file_path.replace('.zip', '') | |
zip_ref.extractall(extracted_path) | |
file_names = zip_ref.namelist() | |
for file_name in file_names: | |
if file_name.endswith(('.jpeg', '.jpg', '.png')): | |
self.training_pictures.append(f'{extracted_path}/{file_name}') | |
def train(self, pretrained_model_name_or_path: str, instance_images: list | None): | |
output_model_name = 'a-xyz-model' | |
resolution = 512 | |
repo_dir = pathlib.Path(__file__).parent | |
subdirs = ['train-instance', 'train-class', 'experiments'] | |
dir_paths = [] | |
for subdir in subdirs: | |
dir_path = repo_dir / subdir / output_model_name | |
dir_paths.append(dir_path) | |
shutil.rmtree(dir_path, ignore_errors=True) | |
os.makedirs(dir_path, exist_ok=True) | |
instance_data_dir, class_data_dir, output_dir = dir_paths | |
for i, temp_path in enumerate(instance_images): | |
image = PIL.Image.open(temp_path.name) | |
image = pad_image(image) | |
image = image.resize((resolution, resolution)) | |
image = image.convert('RGB') | |
out_path = instance_data_dir / f'{i:03d}.jpg' | |
image.save(out_path, format='JPEG', quality=100) | |
command = [ | |
'python', '-u', | |
'train_dreambooth_cloneofsimo_lora.py', | |
'--pretrained_model_name_or_path', pretrained_model_name_or_path, | |
'--instance_data_dir', instance_data_dir, | |
'--class_data_dir', class_data_dir, | |
'--resolution', '768', | |
'--output_dir', output_dir, | |
'--instance_prompt', 'a photo of a pwsm dog', | |
'--with_prior_preservation', | |
'--class_prompt', 'a dog', | |
'--prior_loss_weight', '1.0', | |
'--num_class_images', '100', | |
'--learning_rate', '0.0004', | |
'--train_batch_size', '1', | |
'--sample_batch_size', '1', | |
'--max_train_steps', '400', | |
'--gradient_accumulation_steps', '1', | |
'--gradient_checkpointing', | |
'--train_text_encoder', | |
'--learning_rate_text', '5e-6', | |
'--save_steps', '100', | |
'--seed', '1337', | |
'--lr_scheduler', 'constant', | |
'--lr_warmup_steps', '0' | |
] | |
result = subprocess.run(command) | |
return result | |
def generate_picture(self, row): | |
num_of_training_steps, learning_rate, checkpoint_steps, abc = row | |
return f'Picture generated for num_of_training_steps: {num_of_training_steps}, learning_rate: {learning_rate}, checkpoint_steps: {checkpoint_steps}' | |
def generate_pictures(self, csv_input): | |
csv = pd.read_csv(csv_input.name) | |
result = [] | |
for index, row in csv.iterrows(): | |
result.append(self.generate_picture(row)) | |
return "\n".join(str(item) for item in result) | |
loader = ModelTrainer()


def _train_from_ui(pretrained_model_name_or_path, instance_images):
    """Gradio handler for the Train button; returns Markdown status text.

    ``loader.train`` returns a ``subprocess.CompletedProcess``.  A
    ``gr.Markdown`` output cannot render that object, so the outcome is
    summarised as text here.
    """
    if not instance_images:
        return 'Please upload at least one instance image before training.'
    result = loader.train(pretrained_model_name_or_path, instance_images)
    if result.returncode == 0:
        return 'Training finished successfully.'
    return f'Training failed (exit code {result.returncode}); see the logs for details.'


def _generate_from_ui(csv_input):
    """Gradio handler for the Generate button; guards against a missing upload."""
    if csv_input is None:
        return 'Please upload a CSV file first.'
    return loader.generate_pictures(csv_input)


with gr.Blocks() as demo:
    # gr.Group / gr.Textbox(value=) / gr.File replace gr.Box / gr.inputs.*,
    # which are deprecated in Gradio 3 and removed in Gradio 4.
    with gr.Group():
        instance_images = gr.Files(label='Instance images')
        pretrained_model_name_or_path = gr.Textbox(
            lines=1,
            label='pretrained_model_name_or_path',
            value='stabilityai/stable-diffusion-2-1',
        )
        output_message = gr.Markdown()
        train_button = gr.Button('Train')
        train_button.click(
            fn=_train_from_ui,
            inputs=[pretrained_model_name_or_path, instance_images],
            outputs=[output_message],
        )
    with gr.Group():
        csv_input = gr.File(label='CSV File')
        output_message2 = gr.Markdown()
        generate_button = gr.Button('Generate Pictures from CSV')
        generate_button.click(
            fn=_generate_from_ui,
            inputs=[csv_input],
            outputs=[output_message2],
        )

# Queue events so the long-running training request is not cut off by the
# HTTP request timeout.
demo.queue().launch()