# dgoot's picture
# Update app.py
# 2c20db4 verified
# raw
# history blame
# 2.41 kB
import gradio as gr
import requests
import shutil
import spaces
import torch
from diffusers import AutoencoderKL, StableDiffusionXLImg2ImgPipeline
from loguru import logger
from pathlib import Path
from PIL import Image
from tqdm import tqdm
def download(file: str, url: str, *, timeout=(10.0, 60.0)):
    """Download ``url`` to ``file`` unless ``file`` already exists.

    The payload is streamed in 1 MiB chunks into a temporary file under
    ``/tmp`` and only moved to ``file`` once the transfer has finished, so an
    interrupted download never leaves a truncated file at the destination.

    Args:
        file: Destination path; also used as the progress-bar label.
        url: URL to fetch.
        timeout: ``(connect, read)`` timeout in seconds handed to
            ``requests.get`` so a stalled server cannot hang start-up forever.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    file_path = Path(file)
    if file_path.exists():
        return

    temp_path = Path("/tmp") / file_path.name
    try:
        # The context manager releases the connection even if streaming fails.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()

            # Content-Length is absent for chunked responses; tqdm accepts
            # total=None and then shows a counter without a percentage.
            content_length = r.headers.get("content-length")
            total = int(content_length) if content_length is not None else None

            with tqdm(
                desc=file, total=total, unit="B", unit_scale=True
            ) as pbar, open(temp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
                    pbar.update(len(chunk))

        # Make sure the destination directory exists before moving into it.
        file_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(temp_path), file_path)
    finally:
        # Removes a partial download on failure; after a successful move the
        # temporary file is already gone and this is a no-op.
        temp_path.unlink(missing_ok=True)
# Local file names for the checkpoint and its matching VAE.
model_path = "pony-diffusion-v6-xl.safetensors"
vae_path = "pony-diffusion-v6-xl.vae.safetensors"

# Fetch both weight files (checkpoint first, then VAE); files that are
# already present on disk are skipped by download().
for _target, _source in (
    (
        model_path,
        "https://civitai.com/api/download/models/290640?type=Model&format=SafeTensor&size=pruned&fp=fp16",
    ),
    (
        vae_path,
        "https://civitai.com/api/download/models/290640?type=VAE&format=SafeTensor",
    ),
):
    download(_target, _source)

# Load the standalone VAE, plug it into the fp16 SDXL img2img pipeline and
# place the whole pipeline on the GPU.
vae = AutoencoderKL.from_single_file(vae_path)
pipe = StableDiffusionXLImg2ImgPipeline.from_single_file(
    model_path,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    vae=vae,
).to("cuda")
@logger.catch(reraise=True)
@spaces.GPU
def generate(
    prompt: str,
    init_image: Image.Image,
    strength: float,
    progress=gr.Progress(),
):
    """Run the img2img pipeline on ``init_image`` guided by ``prompt``.

    Args:
        prompt: Text prompt passed to the pipeline.
        init_image: Starting image; it is shrunk in place so that neither
            side exceeds 1024 pixels.
        strength: Passed through as the pipeline's ``strength`` argument.
        progress: Gradio progress tracker, updated after every pipeline step.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no init image was supplied.
    """
    logger.info(
        f"Starting image generation: {dict(prompt=prompt, image=init_image, strength=strength)}"
    )

    # The image input is None when the form is submitted without an upload;
    # report that to the user instead of failing on ``None.thumbnail``.
    if init_image is None:
        raise gr.Error("Please provide an init image.")

    # Downscale the image (in place, aspect ratio preserved, never upscaled)
    init_image.thumbnail((1024, 1024))

    def progress_callback(pipeline, step_index, timestep, callback_kwargs):
        """Forward the pipeline's step counter to the Gradio progress bar."""
        logger.trace(
            f"Callback: {dict(num_timesteps=pipeline.num_timesteps, step_index=step_index, timestep=timestep)}"
        )
        # step_index is zero-based, so +1 gives the number of finished steps.
        progress((step_index + 1, pipeline.num_timesteps))
        return callback_kwargs

    images = pipe(
        prompt=prompt,
        image=init_image,
        callback_on_step_end=progress_callback,
        strength=strength,
    ).images
    return images[0]
# Input widgets, created in the same order as generate()'s parameters.
prompt_input = gr.Text(label="Prompt")
image_input = gr.Image(label="Init image", type="pil")
strength_input = gr.Slider(label="Strength", minimum=0, maximum=1, value=0.3)

# Single image output showing the generated result.
result_output = gr.Image(label="Output")

# Wire the widgets to generate() and start the web UI.
demo = gr.Interface(
    fn=generate,
    inputs=[prompt_input, image_input, strength_input],
    outputs=[result_output],
)
demo.launch()