# flx-upscale / app.py
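"""Gradio Space that re-renders an uploaded image with the
jasperai/Flux.1-dev-Controlnet-Upscaler ControlNet on top of
black-forest-labs/FLUX.1-dev. The upload is used as the ControlNet
conditioning image and shown next to the result in an image slider."""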
import logging
import random
import warnings
import os
import gradio as gr
import numpy as np
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import gc
# Clear memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
css = """
#col-container {
    margin: 0 auto;
    max-width: 512px;
}
"""
# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32  # full precision; bfloat16 would roughly halve memory on supported GPUs
huggingface_token = os.getenv("HF_TOKEN")
# Modified model configuration
model_config = {
    "low_cpu_mem_usage": True,  # stream weights into place instead of materializing a full copy first
    "torch_dtype": dtype,
    "use_safetensors": True,  # must stay enabled: the *.bin weights are excluded from the download below
}
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes", "*.bin"],  # skip docs, git metadata, and pickle weights
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
# Load models with modified configuration
try:
    controlnet = FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        **model_config,
    )
    controlnet.to(device)

    pipe = FluxControlNetPipeline.from_pretrained(
        model_path,
        controlnet=controlnet,
        **model_config,
    )
    pipe.to(device)
except Exception as e:
    print(f"Error loading models: {str(e)}")
    raise
# Enable optimizations
pipe.enable_attention_slicing(1)
pipe.enable_vae_slicing()
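# Both slicing modes trade some inference speed for a lower peak-VRAM footprint.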
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 64 * 64  # inputs are shrunk so they contain at most 64x64 pixels
def process_input(input_image, upscale_factor):
    """Shrink the input to the pixel budget, snap its sides to multiples of 8,
    and scale the conditioning image to the target output size."""
    input_image = input_image.convert('RGB')
    w, h = input_image.size
    # Cap each side so the image stays within MAX_PIXEL_BUDGET
    max_size = int(np.sqrt(MAX_PIXEL_BUDGET))
    new_w = min(w, max_size)
    new_h = min(h, max_size)
    input_image = input_image.resize((new_w, new_h), Image.LANCZOS)
    # Round down to multiples of 8, as the pipeline expects
    w = new_w - new_w % 8
    h = new_h - new_h % 8
    # Apply the requested upscale factor (the UI currently locks it to 1x)
    w, h = w * upscale_factor, h * upscale_factor
    return input_image.resize((w, h)), w, h
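# For example, a 100x80 upload is capped to 64x64 by the pixel budget, which is
# already a multiple of 8; with upscale_factor=1 the conditioning image and the
# generated output both stay at 64x64.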
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    try:
        # Release memory held over from the previous request
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        input_image, w, h = process_input(input_image, upscale_factor)
        with torch.inference_mode():
            generator = torch.Generator(device=device).manual_seed(seed)
            image = pipe(
                prompt="",
                control_image=input_image,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                num_inference_steps=num_inference_steps,
                guidance_scale=1.5,
                height=h,
                width=w,
                generator=generator,
            ).images[0]
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # The ImageSlider output expects a (before, after) image pair
        return (input_image, image)
    except Exception as e:
        # A bare gr.Error(...) call does nothing; raising it surfaces the message in the UI
        raise gr.Error(f"Error: {str(e)}") from e
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        run_button = gr.Button(value="Run")
    with gr.Row():
        with gr.Column(scale=4):
            input_im = gr.Image(label="Input Image", type="pil")
        with gr.Column(scale=1):
            num_inference_steps = gr.Slider(
                label="Steps",
                minimum=1,
                maximum=10,
                step=1,
                value=5,
            )
            upscale_factor = gr.Slider(
                label="Scale",
                minimum=1,
                maximum=1,  # upscaling is locked to 1x in this demo
                step=1,
                value=1,
            )
            controlnet_conditioning_scale = gr.Slider(
                label="Control Scale",
                minimum=0.1,
                maximum=0.3,
                step=0.1,
                value=0.2,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Random Seed", value=True)
    with gr.Row():
        result = ImageSlider(label="Result", type="pil", interactive=True)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    examples = gr.Examples(
        examples=[
            [42, False, os.path.join(current_dir, "z1.webp"), 5, 1, 0.2],
            [42, False, os.path.join(current_dir, "z2.webp"), 5, 1, 0.2],
        ],
        inputs=[
            seed,
            randomize_seed,
            input_im,
            num_inference_steps,
            upscale_factor,
            controlnet_conditioning_scale,
        ],
        fn=infer,
        outputs=result,
        cache_examples=False,
    )
    gr.on(
        [run_button.click],
        fn=infer,
        inputs=[
            seed,
            randomize_seed,
            input_im,
            num_inference_steps,
            upscale_factor,
            controlnet_conditioning_scale,
        ],
        outputs=result,
        show_api=False,
    )
# Launch configuration
demo.queue(max_size=1).launch(
    share=False,
    debug=True,
    show_error=True,
    max_threads=1,
    quiet=True,
)