"""Gradio image-to-image demo built on a Stable Diffusion XL single-file checkpoint.

On import this script downloads the checkpoint and its VAE (if they are not
already on disk), loads them into an SDXL img2img pipeline on CUDA, builds a
Gradio interface around :func:`generate`, and launches it.
"""

# NOTE(review): the relative order of the third-party imports below is kept
# exactly as in the original file (``spaces`` before ``torch``/``diffusers``);
# confirm whether that order matters before letting a tool re-sort them.
import shutil
from pathlib import Path

import gradio as gr
import requests
import spaces
import torch
from diffusers import AutoencoderKL, StableDiffusionXLImg2ImgPipeline
from loguru import logger
from PIL import Image
from tqdm import tqdm

# Size of each streamed chunk written to disk (1 MiB).
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
# (connect, read) timeouts in seconds so a stalled server cannot hang startup.
DOWNLOAD_TIMEOUT = (30, 120)
# Bounding box the init image is shrunk into before being fed to the pipeline.
MAX_IMAGE_SIZE = (1024, 1024)

MODEL_URL = "https://civitai.com/api/download/models/290640?type=Model&format=SafeTensor&size=pruned&fp=fp16"
VAE_URL = "https://civitai.com/api/download/models/290640?type=VAE&format=SafeTensor"


def download(file: str, url: str) -> None:
    """Download *url* to *file* unless *file* already exists.

    The body is streamed to a temporary file under ``/tmp`` and moved to its
    final location only after the transfer completed, so an interrupted
    download never leaves a truncated file at *file*.

    Args:
        file: Destination path; also used as the progress-bar label.
        url: HTTP(S) URL to fetch.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    file_path = Path(file)
    if file_path.exists():
        return

    temp_path = f"/tmp/{file_path.name}"

    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
        response.raise_for_status()

        # Servers may omit Content-Length (e.g. chunked transfer encoding);
        # tqdm accepts total=None and then simply shows a running byte count.
        total = int(response.headers.get("content-length", 0)) or None

        with tqdm(
            desc=file, total=total, unit="B", unit_scale=True
        ) as pbar, open(temp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                f.write(chunk)
                pbar.update(len(chunk))

    shutil.move(temp_path, file_path)


model_path = "pony-diffusion-v6-xl.safetensors"
download(model_path, MODEL_URL)

vae_path = "pony-diffusion-v6-xl.vae.safetensors"
download(vae_path, VAE_URL)

vae = AutoencoderKL.from_single_file(vae_path)
pipe = StableDiffusionXLImg2ImgPipeline.from_single_file(
    model_path,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    vae=vae,
)
pipe = pipe.to("cuda")


@logger.catch(reraise=True)
@spaces.GPU
def generate(
    prompt: str,
    init_image: Image.Image,
    strength: float,
    progress=gr.Progress(),
):
    """Run the img2img pipeline on *init_image* guided by *prompt*.

    Args:
        prompt: Text prompt passed to the pipeline.
        init_image: Source image; a copy is downscaled to fit within
            ``MAX_IMAGE_SIZE`` before being passed to the pipeline.
        strength: Passed through as the pipeline's ``strength`` argument.
        progress: Gradio progress tracker, updated after every pipeline step.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no init image was provided.
    """
    if init_image is None:
        # Without this guard the call below fails with an opaque
        # AttributeError on None; gr.Error carries a readable message.
        raise gr.Error("Please provide an init image.")

    logger.info(
        f"Starting image generation: {dict(prompt=prompt, image=init_image, strength=strength)}"
    )

    # Downscale a copy: Image.thumbnail() resizes in place, and the caller's
    # image object should not be modified behind its back.
    init_image = init_image.copy()
    init_image.thumbnail(MAX_IMAGE_SIZE)

    def progress_callback(pipeline, step_index, timestep, callback_kwargs):
        # The first argument is named `pipeline` so it does not shadow the
        # module-level `pipe`; it is only ever passed positionally here.
        logger.trace(
            f"Callback: {dict(num_timesteps=pipeline.num_timesteps, step_index=step_index, timestep=timestep)}"
        )
        progress((step_index + 1, pipeline.num_timesteps))
        return callback_kwargs

    images = pipe(
        prompt=prompt,
        image=init_image,
        callback_on_step_end=progress_callback,
        strength=strength,
    ).images
    return images[0]


demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(label="Prompt"),
        gr.Image(label="Init image", type="pil"),
        gr.Slider(label="Strength", minimum=0, maximum=1, value=0.3),
    ],
    outputs=[gr.Image(label="Output")],
)
demo.launch()