# HiCo_T2I / app.py
import gradio as gr
import numpy as np
from PIL import Image
import torch
from diffusers import ControlNetModel, UniPCMultistepScheduler
from hico_pipeline import StableDiffusionControlNetMultiLayoutPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 weights require a GPU; fall back to float32 when running on CPU
dtype = torch.float16 if device == "cuda" else torch.float32

# Initialize the HiCo ControlNet and the multi-layout Stable Diffusion pipeline
controlnet = ControlNetModel.from_pretrained("qihoo360/HiCo_T2I", torch_dtype=dtype)
pipe = StableDiffusionControlNetMultiLayoutPipeline.from_pretrained(
    "krnl/realisticVisionV51_v51VAE", controlnet=[controlnet], torch_dtype=dtype
)
pipe = pipe.to(device)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
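
# Optional: if the custom pipeline inherits the standard diffusers memory helpers
# (an assumption, not verified here), attention slicing can trim VRAM usage on
# smaller GPUs. Left disabled by default:
# pipe.enable_attention_slicing()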
MAX_SEED = np.iinfo(np.int32).max

# Default example layout: the global caption followed by sub-captions and their bounding boxes
object_classes_list = ["A photograph of a young woman wrapped in a towel wearing a pair of sunglasses", "a towel", "a young woman wrapped in a towel wearing a pair of sunglasses", "a pair of sunglasses"]
object_bboxes_list = ["0,0,512,512", "17,77,144,155", "16,28,157,155", "82,44,129,63"]
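# Bounding boxes are "x1,y1,x2,y2" strings in 512x512 pixel coordinates; at generation
# time each box is rasterized into a white-on-black mask that conditions its sub-caption.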

# Function to add or update the global prompt in the list
def submit_prompt(prompt):
    if object_classes_list:
        object_classes_list[0] = prompt  # Overwrite the first element if it exists
    else:
        object_classes_list.insert(0, prompt)  # Add to the beginning if the list is empty
    if not object_bboxes_list:
        object_bboxes_list.insert(0, "0,0,512,512")  # Add the default full-frame bounding box
    combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
    return combined_list, gr.update(interactive=False)  # Make the prompt input non-editable

# Function to add a new object (sub-caption + bounding box) with validation
def add_object(object_class, bbox):
    try:
        # Split and convert the bbox string into integers
        x1, y1, x2, y2 = map(int, bbox.split(","))
    except ValueError:
        raise gr.Error("Invalid input format. Use x1,y1,x2,y2.")
    # Validate the coordinates
    if x2 < x1 or y2 < y1:
        raise gr.Error("x2 cannot be less than x1, and y2 cannot be less than y1.")
    if x1 < 0 or y1 < 0 or x2 > 512 or y2 > 512:
        raise gr.Error("Coordinates must be between 0 and 512.")
    # If validation passes, add to the lists and refresh the table
    object_classes_list.append(object_class)
    object_bboxes_list.append(bbox)
    combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
    return combined_list

# Function to generate an image from the stored layout (captions + bounding boxes)
def generate_image(prompt, guidance_scale, num_inference_steps, randomize_seed, seed):
    img_width, img_height = 512, 512
    r_image = np.zeros((img_height, img_width, 3), dtype=np.uint8)
    # Build one binary mask per bounding box; each mask conditions its sub-caption
    list_cond_image = []
    for bbox in object_bboxes_list:
        x1, y1, x2, y2 = map(int, bbox.split(","))
        cond_image = np.zeros_like(r_image, dtype=np.uint8)
        cond_image[y1:y2, x1:x2] = 255
        list_cond_image.append(Image.fromarray(cond_image).convert('RGB'))
    if randomize_seed or seed is None:
        seed = np.random.randint(0, MAX_SEED)
    torch.manual_seed(seed)  # seeds the default RNG used for latent initialization
    image = pipe(
        prompt=prompt,
        layo_prompt=object_classes_list,
        guess_mode=False,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        image=list_cond_image,
        fuse_type="avg",
        width=512,
        height=512
    ).images[0]
    return image, seed
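
# Note (assumption, not verified against the custom pipeline): torch.manual_seed() above
# seeds PyTorch's default RNG, which diffusers falls back to when no generator is passed.
# If StableDiffusionControlNetMultiLayoutPipeline accepts the standard `generator` argument,
# passing generator=torch.Generator(device).manual_seed(int(seed)) would make each call
# reproducible without touching global state.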

# Function to clear all stored objects and reset the UI
def clear_arrays():
    object_classes_list.clear()
    object_bboxes_list.clear()
    return [], gr.update(value="", interactive=True)  # Clear the table and re-enable the prompt

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# HiCo_T2I 512px")
    gr.Markdown("You can click **Generate Image** directly, or customize the layout by first entering the global caption and then adding sub-captions with their corresponding coordinates.")

    # Put the prompt and the submit button in the same row
    with gr.Group():
        with gr.Row():
            # Prompt input and submit button
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt here",
                container=False,
            )
            submit_button = gr.Button("Submit Prompt", scale=0)

    # Always-visible table of captions and bounding boxes
    objects_display = gr.Dataframe(
        headers=["Caption", "Bounding Box"],
        value=[[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
    )

    with gr.Row():
        object_class_input = gr.Textbox(label="Sub-caption", placeholder="Enter a sub-caption (e.g., apple)")
        bbox_input = gr.Textbox(label="Bounding Box (x1,y1,x2,y2, each between 0 and 512)", placeholder="Enter bounding box coordinates")
        add_button = gr.Button("Add")

    # Advanced settings in a collapsible accordion
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                value=7.5
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=50
            )

    generate_button = gr.Button("Generate Image")
    result = gr.Image(label="Generated Image")

    # Refresh button to clear the stored objects and reset the inputs (placed below the result)
    refresh_button = gr.Button("Refresh")

    # Submit the prompt and update the display
    submit_button.click(
        fn=submit_prompt,
        inputs=prompt,
        outputs=[objects_display, prompt]
    )

    # Add an object and update the display
    add_button.click(
        fn=add_object,
        inputs=[object_class_input, bbox_input],
        outputs=[objects_display]
    )

    # Generate an image based on the added objects
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, guidance_scale, num_inference_steps, randomize_seed, seed],
        outputs=[result, seed]
    )

    # Clear the stored objects and reset the inputs
    refresh_button.click(
        fn=clear_arrays,
        inputs=None,
        outputs=[objects_display, prompt]
    )


if __name__ == "__main__":
    demo.launch()