#!/usr/bin/env python from __future__ import annotations import os import gradio as gr from constants import UploadTarget from inference import InferencePipeline from trainer import Trainer def create_training_demo( trainer: Trainer, pipe: InferencePipeline | None = None, disable_run_button: bool = False ) -> gr.Blocks: def read_log() -> str: with open(trainer.log_file) as f: lines = f.readlines() return "".join(lines[-10:]) with gr.Blocks() as demo: with gr.Row(): with gr.Column(): with gr.Box(): gr.Markdown("Training Data") training_video = gr.File(label="Training video") training_prompt = gr.Textbox(label="Training prompt", max_lines=1, placeholder="A man is surfing") gr.Markdown( """ - Upload a video and write a `Training Prompt` that describes the video. """ ) with gr.Column(): with gr.Box(): gr.Markdown("Training Parameters") with gr.Row(): base_model = gr.Text(label="Base Model", value="CompVis/stable-diffusion-v1-4", max_lines=1) resolution = gr.Dropdown( choices=["512", "768"], value="512", label="Resolution", visible=False ) hf_token = gr.Text( label="Hugging Face Write Token", type="password", visible=os.getenv("HF_TOKEN") is None ) with gr.Accordion(label="Advanced options", open=False): num_training_steps = gr.Number(label="Number of Training Steps", value=300, precision=0) learning_rate = gr.Number(label="Learning Rate", value=0.000035) gradient_accumulation = gr.Number( label="Number of Gradient Accumulation", value=1, precision=0 ) seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, randomize=True, value=0) fp16 = gr.Checkbox(label="FP16", value=True) use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=False) checkpointing_steps = gr.Number(label="Checkpointing Steps", value=1000, precision=0) validation_epochs = gr.Number(label="Validation Epochs", value=100, precision=0) gr.Markdown( """ - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library. - Expected time to train a model for 300 steps: ~20 minutes with T4 - You can check the training status by pressing the "Open logs" button if you are running this on your Space. """ ) with gr.Row(): with gr.Column(): gr.Markdown("Output Model") output_model_name = gr.Text(label="Name of your model", placeholder="The surfer man", max_lines=1) validation_prompt = gr.Text( label="Validation Prompt", placeholder="prompt to test the model, e.g: a dog is surfing" ) with gr.Column(): gr.Markdown("Upload Settings") with gr.Row(): upload_to_hub = gr.Checkbox(label="Upload model to Hub", value=True) use_private_repo = gr.Checkbox(label="Private", value=True) delete_existing_repo = gr.Checkbox(label="Delete existing repo of the same name", value=False) upload_to = gr.Radio( label="Upload to", choices=[_.value for _ in UploadTarget], value=UploadTarget.MODEL_LIBRARY.value, ) pause_space_after_training = gr.Checkbox( label="Pause this Space after training", value=False, interactive=bool(os.getenv("SPACE_ID")), visible=False, ) run_button = gr.Button("Start Training", interactive=not disable_run_button) with gr.Box(): gr.Text(label="Log", value=read_log, lines=10, max_lines=10, every=1) if pipe is not None: run_button.click(fn=pipe.clear) run_button.click( fn=trainer.run, inputs=[ training_video, training_prompt, output_model_name, delete_existing_repo, validation_prompt, base_model, resolution, num_training_steps, learning_rate, gradient_accumulation, seed, fp16, use_8bit_adam, checkpointing_steps, validation_epochs, upload_to_hub, use_private_repo, delete_existing_repo, upload_to, pause_space_after_training, hf_token, ], ) return demo if __name__ == "__main__": trainer = Trainer() demo = create_training_demo(trainer) demo.queue(api_open=False, max_size=1).launch()