import os

import torch
import gradio as gr
# from PIL import Image
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers import EulerDiscreteScheduler

root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Initialize global variables for models and pipeline
text_encoder = None
tokenizer = None
vae = None
scheduler = None
unet = None
pipe = None


def load_models():
    """Lazily load all model components on first use."""
    global text_encoder, tokenizer, vae, scheduler, unet, pipe
    if text_encoder is None:
        ckpt_dir = f'{root_dir}/weights/Kolors'
        # Load the text encoder on CPU (this speeds stuff up about 2x)
        text_encoder = ChatGLMModel.from_pretrained(
            f'{ckpt_dir}/text_encoder',
            torch_dtype=torch.float16).to('cpu').half()
        tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
        # Load the VAE and UNet on GPU
        vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to('cuda')
        scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
        unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to('cuda')
        # Prepare the pipeline
        pipe = StableDiffusionXLPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            force_zeros_for_empty_prompt=False)
        # Let offloading manage device placement; calling pipe.to("cuda") first
        # would move the text encoder back to the GPU and defeat the CPU
        # placement above.
        pipe.enable_model_cpu_offload()  # Enable offloading to balance CPU/GPU usage


def infer(prompt, use_random_seed, seed, height, width, num_inference_steps,
          guidance_scale, num_images_per_prompt):
    load_models()
    if use_random_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    # Gradio sliders can deliver floats; manual_seed requires an int
    generator = torch.Generator(pipe.device).manual_seed(int(seed))
    images = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator
    ).images
    saved_images = []
    output_dir = f'{root_dir}/scripts/outputs'
    os.makedirs(output_dir, exist_ok=True)
    for image in images:
        # Append a counter to the filename to avoid overwriting earlier outputs
        file_path = os.path.join(output_dir, 'sample_test.jpg')
        base_name, ext = os.path.splitext(file_path)
        counter = 1
        while os.path.exists(file_path):
            file_path = f"{base_name}_{counter}{ext}"
            counter += 1
        image.save(file_path)
        saved_images.append(file_path)
    return saved_images


def gradio_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Kolors: Diffusion Model Gradio Interface")
                prompt = gr.Textbox(label="Prompt")
                use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
                seed = gr.Slider(minimum=0, maximum=2**32 - 1, step=1, label="Seed",
                                 randomize=True, visible=False)
                # Show the seed slider only when the random-seed box is unchecked
                use_random_seed.change(lambda x: gr.update(visible=not x), use_random_seed, seed)
                height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=1024)
                width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=1024)
                num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Inference Steps", value=50)
                guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0)
                num_images_per_prompt = gr.Slider(minimum=1, maximum=10, step=1, label="Images per Prompt", value=1)
                btn = gr.Button("Generate Image")
            with gr.Column():
                output_images = gr.Gallery(label="Output Images", elem_id="output_gallery")
        btn.click(
            fn=infer,
            inputs=[prompt, use_random_seed, seed, height, width, num_inference_steps,
                    guidance_scale, num_images_per_prompt],
            outputs=output_images
        )
    return demo


if __name__ == '__main__':
    gradio_interface().launch()