lichorosario's picture
Refactorizar función de refinamiento de imágenes en app.py
ea71148
raw
history blame
14.7 kB
import os
import uuid
import gradio as gr
import json
from gradio_client import Client, handle_file
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import InferenceClient
from loadimg import load_img
# Catalogue of selectable LoRAs. Other parts of this file read the "repo",
# "trigger_word", "title" and "image" keys of each entry.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's default text encoding.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Handle of the most recently submitted generation job (used by cancel_infer).
job = None

# Base URLs of the remote Gradio Spaces this app delegates to.
custom_model_url = "https://fffiloni-sd-xl-custom-model.hf.space"
tile_upscaler_url = "https://gokaygokay-tileupscalerv2.hf.space"

# Client slots for the Spaces above; they stay unset at import time and are
# filled in by the handlers that need them.
client_custom_model = None
client_tile_upscaler = None
# try:
# client_custom_model = Client(custom_model_url)
# print(f"Loaded custom model from {custom_model_url}")
# except ValueError as e:
# print(f"Failed to load custom model: {e}")
# try:
# client_tile_upscaler = Client(tile_upscaler_url)
# print(f"Loaded custom model from {tile_upscaler_url}")
# except ValueError as e:
# print(f"Failed to load custom model: {e}")
# Characters treated as glue between the prompt parts when de-duplicating.
_PROMPT_SEPARATORS = " .,"


def _strip_affix(text, affix, *, at_start):
    """Remove *affix* from one end of *text* if it stands there as its own phrase.

    The affix is only removed when it is not glued to a longer word (the
    neighbouring character, if any, must not be alphanumeric). Separator
    characters left dangling at that end are trimmed as well. When nothing
    matches, *text* is returned unchanged, so no user text is ever lost.
    """
    if not affix:
        return text
    if at_start:
        if not text.startswith(affix):
            return text
        rest = text[len(affix):]
        if rest[:1].isalnum():
            return text
        return rest.lstrip(_PROMPT_SEPARATORS)
    if not text.endswith(affix):
        return text
    rest = text[:-len(affix)]
    if rest[-1:].isalnum():
        return text
    return rest.rstrip(_PROMPT_SEPARATORS)


def _compose_prompt(trigger_word, prompt, style_prompt):
    """Join trigger word, prompt and style prompt with '. ', skipping empty parts.

    A trigger word already typed at the start of *prompt*, or a style prompt
    already typed at its end, is removed first so it is not duplicated.
    ``None`` is treated like an empty string for every part.
    """
    trigger_word = (trigger_word or "").strip()
    style_prompt = (style_prompt or "").strip()
    prompt = (prompt or "").strip()
    prompt = _strip_affix(prompt, trigger_word, at_start=True)
    prompt = _strip_affix(prompt, style_prompt, at_start=False)
    return '. '.join(part for part in (trigger_word, prompt, style_prompt) if part)


def infer(selected_index, prompt, style_prompt, inf_steps, guidance_scale, width, height, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with the selected LoRA through the remote custom-model Space.

    Args:
        selected_index: Index into ``loras`` of the chosen LoRA; ``None`` means
            nothing has been selected yet.
        prompt: User prompt text.
        style_prompt: Optional style text appended to the prompt.
        inf_steps, guidance_scale, width, height, seed, lora_weight: Forwarded
            unchanged to the Space's ``/infer`` endpoint.
        progress: Not used in the body; it is only declared in the signature
            (presumably so Gradio tracks progress -- confirm against Gradio docs).

    Returns:
        ``(image_path, used_seed, used_prompt)`` on success. On any failure a
        ``gr.Warning`` is shown and the function falls through, returning ``None``.
    """
    global job, client_custom_model
    try:
        if selected_index is None:
            raise gr.Error("You must select a LoRA before proceeding.")

        selected_lora = loras[selected_index]
        custom_model = selected_lora["repo"]
        trigger_word = selected_lora["trigger_word"]

        # Connect to the Space lazily and keep the client for later calls.
        if client_custom_model is None:
            try:
                client_custom_model = Client(custom_model_url)
                print(f"Loaded custom model from {custom_model_url}")
            except ValueError as e:
                print(f"Failed to load custom model: {e}")
                client_custom_model = None
                raise gr.Error("Failed to load client for " + custom_model_url) from e

        # Ask the Space to load the LoRA repo; the third element of its result
        # carries the weight file name under the 'value' key.
        try:
            load_job = client_custom_model.submit(
                custom_model=custom_model,
                api_name="/load_model"
            )
        except ValueError as e:
            raise gr.Error(str(e)) from e
        weight_name = load_job.result()[2]['value']

        prompt = _compose_prompt(trigger_word, prompt, style_prompt)

        try:
            # Stored in the module-level ``job`` so cancel_infer can reach it.
            job = client_custom_model.submit(
                custom_model=custom_model,
                weight_name=weight_name,
                prompt=prompt,
                inf_steps=inf_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                seed=seed,
                lora_weight=lora_weight,
                api_name="/infer"
            )
            result = job.result()
        except ValueError as e:
            raise gr.Error(str(e)) from e

        # NOTE: these indices may need adjusting to the actual structure of
        # the value returned by the Space.
        generated_image_path = load_img(result[0], output_type="str")
        used_seed = result[1]
        # The prompt that was used is simply the processed prompt.
        used_prompt = prompt
        return generated_image_path, used_seed, used_prompt
    except Exception as e:
        # Best-effort handler: every failure (including the gr.Error raised
        # above) is surfaced as a warning instead of propagating.
        gr.Warning("Error: " + str(e))
def cancel_infer():
    """Cancel the job held in the module-level ``job`` handle, if there is one.

    Returns:
        A short status message saying whether a cancellation was issued.
    """
    global job
    if not job:
        return "No job to cancel"
    job.cancel()
    return "Job has been cancelled"
def update_selection(evt: gr.SelectData):
    """React to a gallery selection event.

    Looks up the LoRA at ``evt.index`` in ``loras`` and returns three values:
    an update setting a new placeholder text, a Markdown line linking to the
    LoRA's Hugging Face repo, and the selected index itself.
    """
    index = evt.index
    lora = loras[index]
    placeholder_text = f"Type a prompt for {lora['title']}"
    repo = lora["repo"]
    selection_markdown = f"### Selected: [{repo}](https://huggingface.co/{repo}) ✨"
    placeholder_update = gr.update(placeholder=placeholder_text)
    return placeholder_update, selection_markdown, index
def resize_image(image_path, reduction_factor):
    """Return a copy of the image at *image_path* shrunk by *reduction_factor*.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts.
        reduction_factor: Divisor applied to both width and height
            (floor division).

    Returns:
        A new, resized PIL image that no longer depends on the source file.
    """
    # The context manager closes the source file once the resized copy exists.
    with Image.open(image_path) as image:
        width, height = image.size
        # Clamp to 1 pixel: floor division yields 0 for images smaller than
        # the factor, and a 0-sized dimension is rejected by Pillow.
        # int() keeps the size integral should the factor arrive as a float.
        new_size = (
            max(1, int(width // reduction_factor)),
            max(1, int(height // reduction_factor)),
        )
        return image.resize(new_size)
def save_image(image):
    """Save *image* under a fresh, uuid-based PNG file name and return that name.

    The name is relative, so the file ends up in the current working directory.
    """
    token = uuid.uuid4().hex
    target_name = f"resized_image_{token}.png"
    image.save(target_name)
    return target_name
def upscale_image(image, resolution, num_inference_steps, strength, hdr, guidance_scale, controlnet_strength, scheduler_name, reduce_factor):
    """Upscale an image through the remote tile-upscaler Space.

    Args:
        image: Indexable value whose element ``[1]`` is the image to upscale.
        resolution, num_inference_steps, strength, hdr, guidance_scale,
        controlnet_strength, scheduler_name: Forwarded positionally
            (``param_1`` .. ``param_7``) to the Space's ``/wrapper`` endpoint.
        reduce_factor: When greater than 1, the image is first shrunk by this
            factor and written to a new file, which is what gets uploaded.

    Returns:
        ``[image, result]``: the image that was sent and the Space's result.

    Raises:
        gr.Error: If the client cannot be created or the submission is rejected.
    """
    global client_tile_upscaler
    # NOTE(review): the value is indexed with [1]; confirm that the component
    # wired to this handler really delivers a pair rather than a single image.
    image = image[1]

    # Reuse the cached client, mirroring infer(); connecting anew on every
    # click repeats the Space handshake for nothing.
    if client_tile_upscaler is None:
        try:
            client_tile_upscaler = Client(tile_upscaler_url)
            print(f"Loaded custom model from {tile_upscaler_url}")
        except ValueError as e:
            print(f"Failed to load custom model: {e}")
            client_tile_upscaler = None
            raise gr.Error("Failed to load client for " + tile_upscaler_url) from e

    if reduce_factor > 1:
        image = resize_image(image, reduce_factor)
        image = save_image(image)

    try:
        # Local handle on purpose: this job is not exposed to cancel_infer.
        job = client_tile_upscaler.submit(
            param_0=image,
            param_1=resolution,
            param_2=num_inference_steps,
            param_3=strength,
            param_4=hdr,
            param_5=guidance_scale,
            param_6=controlnet_strength,
            param_7=scheduler_name,
            api_name="/wrapper"
        )
        result = job.result()
    except ValueError as e:
        raise gr.Error(str(e)) from e
    return [image, result]
def _as_inference_payload(image):
    """Convert *image* into a form ``InferenceClient.image_to_image`` accepts.

    Strings (paths/URLs), raw bytes and readable file objects pass through;
    path-like objects become plain strings. Anything else is encoded to PNG
    bytes -- objects with a ``save`` method are taken to be PIL images, the
    rest is handed to ``Image.fromarray`` (assumes an array Pillow can read,
    e.g. uint8 -- TODO confirm against the component's value type).
    """
    if isinstance(image, (str, bytes)):
        return image
    if isinstance(image, os.PathLike):
        return os.fspath(image)
    if hasattr(image, "read"):
        return image
    import io  # local import: only needed on this encoding path

    pil_image = image if hasattr(image, "save") else Image.fromarray(image)
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    return buffer.getvalue()


def refine_image(apply_refiner, image, model ,prompt, negative_prompt, num_inference_steps, guidance_scale, seed, strength):
    """Optionally run *image* through a hosted image-to-image refiner model.

    Args:
        apply_refiner: When falsy, *image* is returned untouched.
        image: The image to refine; ``None`` means there is nothing to refine.
        model: Model id passed to ``InferenceClient.image_to_image``.
        prompt, negative_prompt, num_inference_steps, guidance_scale, seed,
        strength: Forwarded unchanged to ``image_to_image``.

    Returns:
        *image* itself when refinement is skipped, otherwise whatever
        ``image_to_image`` returns.
    """
    # Nothing to do when refinement is off or no image was produced upstream.
    if not apply_refiner or image is None:
        return image

    client = InferenceClient()
    # handle_file() builds a gradio_client file descriptor, which the
    # Hugging Face inference client does not understand; send the image
    # as a path/bytes payload instead.
    refined_image = client.image_to_image(
        _as_inference_payload(image),
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        model=model,
        strength=strength
    )
    return refined_image
# Custom CSS handed to gr.Blocks below; the string holds only a newline.
css="""
"""
# Build the UI. Everything created from here on belongs to the `demo` app.
with gr.Blocks(css=css) as demo:
gr.Markdown("# lichorosario LoRA Portfolio")
gr.Markdown(
"### This is my portfolio.\n"
"**Note**: Generation quality may vary. For best results, adjust the parameters.\n"
"Special thanks to [@artificialguybr](https://huggingface.co/artificialguybr) and [@fffiloni](https://huggingface.co/fffiloni).\n"
"Based on [https://huggingface.co/spaces/fffiloni/sd-xl-custom-model](https://huggingface.co/spaces/fffiloni/sd-xl-custom-model) and [https://huggingface.co/spaces/gokaygokay/TileUpscalerV2](https://huggingface.co/spaces/gokaygokay/TileUpscalerV2)"
)
with gr.Row():
with gr.Column(scale=2):
# Prompt inputs; `used_prompt` is filled by infer with the prompt it sent.
prompt_in = gr.Textbox(
label="Your Prompt",
info="Don't forget to include your trigger word if necessary"
)
style_prompt_in = gr.Textbox(
label="Your Style Prompt"
)
# Updated by update_selection with a link to the chosen LoRA repo.
selected_info = gr.Markdown("")
used_prompt = gr.Textbox(
label="Used prompt"
)
with gr.Column(elem_id="col-container"):
# Generation parameters passed to infer.
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
inf_steps = gr.Slider(
label="Inference steps",
minimum=3,
maximum=150,
step=1,
value=25
)
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=50.0,
step=0.1,
value=12
)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=3072,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=3072,
step=32,
value=512,
)
# Width/height presets: each row is [width, height] for the two sliders above.
examples = [
[1024,512],
[2048,512],
[3072, 512]
]
gr.Examples(
label="Presets",
examples=examples,
inputs=[width, height],
outputs=[]
)
with gr.Row():
seed = gr.Slider(
label="Seed",
info="-1 denotes a random seed",
minimum=-1,
maximum=423538377342,
step=1,
value=-1
)
# Filled by infer with the seed value it got back from the Space.
last_used_seed = gr.Number(
label="Last used seed",
info="the seed used in the last generation",
)
lora_weight = gr.Slider(
label="LoRa weight",
minimum=0.0,
maximum=1.0,
step=0.01,
value=1.0
)
# Refiner controls; their values are passed to refine_image.
with gr.Group():
apply_refiner = gr.Checkbox(label="Apply refiner", value=False)
with gr.Accordion("Refiner params", open=False) as refiner_params:
refiner_prompt = gr.Textbox(lines=3, label="Prompt")
refiner_negative_prompt = gr.Textbox(lines=3, label="Negative Prompt")
refiner_strength = gr.Slider(
label="Strength",
minimum=0,
maximum=300,
step=0.01,
value=1
)
refiner_num_inference_steps = gr.Slider(
label="Inference steps",
minimum=3,
maximum=300,
step=1,
value=25
)
refiner_guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=50.0,
step=0.1,
value=12
)
refiner_seed = gr.Slider(
label="Seed",
info="-1 denotes a random seed",
minimum=-1,
maximum=423538377342,
step=1,
value=-1
)
refiner_model = gr.Textbox(label="Model", value="stabilityai/stable-diffusion-xl-refiner-1.0")
# Toggle visibility of the refiner accordion together with the checkbox.
apply_refiner.change(
fn=lambda x: gr.update(visible=x),
inputs=apply_refiner,
outputs=refiner_params,
queue=False,
api_name=False,
)
with gr.Column(scale=1):
# One gallery tile per entry in `loras`, built from its "image" and "title".
gallery = gr.Gallery(
[(item["image"], item["title"]) for item in loras],
label="LoRA Gallery",
allow_preview=False,
columns=2,
height="100%"
)
submit_btn = gr.Button("Submit")
cancel_btn = gr.Button("Cancel")
with gr.Row():
with gr.Column():
# Receives the output of infer and, afterwards, of refine_image.
# NOTE(review): no `type=` is given here, yet upscale_image indexes this
# component's value with [1] and refine_image treats it as a file --
# confirm which value type the handlers actually receive.
generated_image = gr.Image(label="Image / Refined Image")
enhace_button = gr.Button("Enhance Image")
with gr.Column():
# Receives the [before, after] pair returned by upscale_image.
output_slider = ImageSlider(label="Before / After", show_download_button=False)
# Parameters passed to upscale_image.
with gr.Accordion("Enhacer params", open=False):
upscale_reduce_factor = gr.Slider(minimum=1, maximum=10, step=1, label="Reduce Factor", info="1/n")
upscale_resolution = gr.Slider(minimum=128, maximum=2048, value=1024, step=128, label="Resolution", info="Image width")
upscale_num_inference_steps = gr.Slider(minimum=1, maximum=150, value=50, step=1, label="Number of Inference Steps")
upscale_strength = gr.Slider(minimum=0, maximum=1, value=0.2, step=0.01, label="Strength", info="Higher values give more detail")
upscale_hdr = gr.Slider(minimum=0, maximum=1, value=0, step=0.1, label="HDR Effect")
upscale_guidance_scale = gr.Slider(minimum=0, maximum=20, value=12, step=0.5, label="Guidance Scale")
upscale_controlnet_strength = gr.Slider(minimum=0.0, maximum=2.0, value=0.75, step=0.05, label="ControlNet Strength")
upscale_scheduler_name = gr.Dropdown(
choices=["DDIM", "DPM++ 3M SDE Karras", "DPM++ 3M Karras"],
value="DDIM",
label="Scheduler"
)
# Index of the LoRA picked in the gallery; None until update_selection sets it.
selected_index = gr.State(None)
# Submit: run infer first, then refine_image, which reads the generated image
# and writes its result back into the same image component.
submit_btn.click(
fn=infer,
inputs=[selected_index, prompt_in, style_prompt_in, inf_steps, guidance_scale, width, height, seed, lora_weight],
outputs=[generated_image, last_used_seed, used_prompt]
).then(refine_image,
[apply_refiner, generated_image, refiner_model, refiner_prompt, refiner_negative_prompt, refiner_num_inference_steps, refiner_guidance_scale, refiner_seed, refiner_strength],
generated_image
)
# Cancel: calls cancel_infer. No output component is wired, so the status
# string that function returns is not routed to any component.
cancel_btn.click(
fn=cancel_infer,
outputs=[]
)
# Ignores its argument and always returns None; used below to blank the
# before/after slider before a new upscale starts.
def clear_output(image_slider):
return None
# Enhance: clear the slider, then run upscale_image and show its result there.
enhace_button.click(
fn=clear_output,
inputs=[output_slider],
outputs=[output_slider]
).then(
upscale_image,
[generated_image, upscale_resolution, upscale_num_inference_steps, upscale_strength, upscale_hdr, upscale_guidance_scale, upscale_controlnet_strength, upscale_scheduler_name, upscale_reduce_factor],
output_slider
)
# Selecting a gallery tile updates the prompt box, the info line and the
# stored LoRA index via update_selection.
gallery.select(update_selection, outputs=[prompt_in, selected_info, selected_index])
# Start the app; show_error=True is passed to launch().
demo.launch(show_error=True)