import os import torch import spaces import re import yaml import tempfile import subprocess from pathlib import Path from dataclasses import dataclass import gradio as gr from src.flux.xflux_pipeline import XFluxPipeline from huggingface_hub import login import multiprocessing # Set the multiprocessing start method globally multiprocessing.set_start_method('spawn', force=True) hf_token = os.getenv("HF_TOKEN") if hf_token: login(token=hf_token) else: print("No Hugging Face token found.") @dataclass class Config: name: str = "flux-dev" device: str = "cpu" offload: bool = False share: bool = False ckpt_dir: str = "." xflux_pipeline = XFluxPipeline(Config.name, Config.device, Config.offload) xflux_pipeline.to(device='cuda' if torch.cuda.is_available() else 'cpu') @spaces.GPU @torch.inference_mode() def generate(**kwargs): return xflux_pipeline.gradio_generate(**kwargs) def parse_args() -> Config: parser = argparse.ArgumentParser(description="Flux") parser.add_argument("--name", type=str, default="flux-dev", help="Model name") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use") parser.add_argument("--share", action="store_true", help="Create a public link to your demo") parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format") args = parser.parse_args() return Config(**vars(args)) def list_dirs(path): if path is None or path == "None" or path == "": return if not os.path.exists(path): path = os.path.dirname(path) if not os.path.exists(path): return if not os.path.isdir(path): path = os.path.dirname(path) def natural_sort_key(s, regex=re.compile("([0-9]+)")): return [ int(text) if text.isdigit() else text.lower() for text in regex.split(s) ] subdirs = [ (item, os.path.join(path, item)) for item in os.listdir(path) if os.path.isdir(os.path.join(path, item)) ] subdirs = [ filename for item, filename in subdirs if item[0] != "." and item not in ["__pycache__"] ] subdirs = sorted(subdirs, key=natural_sort_key) if os.path.dirname(path) != "": dirs = [os.path.dirname(path), path] + subdirs else: dirs = [path] + subdirs if os.sep == "\\": dirs = [d.replace("\\", "/") for d in dirs] for d in dirs: yield d def list_train_data_dirs(): current_train_data_dir = "." return list(list_dirs(current_train_data_dir)) def update_config(d, u): for k, v in u.items(): if isinstance(v, dict): d[k] = update_config(d.get(k, {}), v) else: # convert Gradio components to strings if hasattr(v, 'value'): d[k] = str(v.value) else: try: d[k] = int(v) except (TypeError, ValueError): d[k] = str(v) return d def start_lora_training( data_dir: str, output_dir: str, lr: float, steps: int, rank: int ): inputs = { "data_config": { "img_dir": data_dir, }, "output_dir": output_dir, "learning_rate": lr, "rank": rank, "max_train_steps": steps, } if not os.path.exists(output_dir): os.makedirs(output_dir) print(f"Creating folder {output_dir} for the output checkpoint file...") script_path = Path(__file__).resolve() config_path = script_path.parent / "train_configs" / "test_lora.yaml" with open(config_path, 'r') as file: config = yaml.safe_load(file) config = update_config(config, inputs) print("Config file is updated...", config) with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".yaml") as temp_file: yaml.dump(config, temp_file, default_flow_style=False) tmp_config_path = temp_file.name command = ["accelerate", "launch", "train_flux_lora_deepspeed.py", "--config", tmp_config_path] result = subprocess.run(command, check=True) # rRemove the temporary file after the command is run Path(tmp_config_path).unlink() return result def create_demo( model_type: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False, ckpt_dir: str = "", ): checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors")) with gr.Blocks() as demo: gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}") with gr.Tab("Inference"): with gr.Row(): with gr.Column(): prompt = gr.Textbox(label="Prompt", value="handsome woman in the city") with gr.Accordion("Generation Options", open=False): with gr.Row(): width = gr.Slider(512, 2048, 1024, step=16, label="Width") height = gr.Slider(512, 2048, 1024, step=16, label="Height") neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo") with gr.Row(): num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps") timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg") with gr.Row(): guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True) true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True) seed = gr.Textbox(-1, label="Seed (-1 for random)") with gr.Accordion("ControlNet Options", open=False): control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type") control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True) local_path = gr.Dropdown(checkpoints, label="Controlnet Checkpoint", info="Local Path to Controlnet weights (if no, it will be downloaded from HF)" ) controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True) with gr.Accordion("LoRA Options", open=False): lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True) lora_local_path = gr.Dropdown( checkpoints, label="LoRA Checkpoint", info="Local Path to Lora weights" ) with gr.Accordion("IP Adapter Options", open=False): image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True) ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale") neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True) neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale") ip_local_path = gr.Dropdown( checkpoints, label="IP Adapter Checkpoint", info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)" ) generate_btn = gr.Button("Generate") with gr.Column(): output_image = gr.Image(label="Generated Image") download_btn = gr.File(label="Download full-resolution") inputs = [prompt, image_prompt, controlnet_image, width, height, guidance, num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt, neg_image_prompt, timestep_to_start_cfg, control_type, control_weight, lora_weight, local_path, lora_local_path, ip_local_path ] generate_btn.click( fn=generate, inputs=inputs, outputs=[output_image, download_btn], ) with gr.Tab("LoRA Finetuning"): data_dir = gr.Dropdown(list_train_data_dirs(), label="Training images (directory containing the training images)" ) output_dir = gr.Textbox(label="Output Path", value="lora_checkpoint") with gr.Accordion("Training Options", open=True): lr = gr.Textbox(label="Learning Rate", value="1e-5") steps = gr.Slider(10000, 20000, 20000, step=100, label="Train Steps") rank = gr.Slider(1, 100, 16, step=1, label="LoRa Rank") training_btn = gr.Button("Start training") training_btn.click( fn=start_lora_training, inputs=[data_dir, output_dir, lr, steps, rank], outputs=[], ) return demo if __name__ == "__main__": config = Config() demo = create_demo(config.name, config.device, config.offload, config.ckpt_dir) demo.launch(share=True)