# POET / app.py
# Last commit 489c70c by xh365: "update personalization"
import gradio as gr
import numpy as np
import random
import spaces
import torch
import re
import transformers
import open_clip
from optim_utils import optimize_prompt
from utils import (
clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
get_refine_msg, clean_cache, get_personalize_message, get_personalized_simplified,
clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS
)
# =========================
# Constants / Defaults
# =========================
# Model identifiers.
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
default_t2i_model = "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# Generation limits.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
NUM_IMAGES = 4  # images rendered per "Generate" click (see generate_image)
MAX_ROUND = 5   # saved feedback rounds before the save button is disabled (see redesign)

# Half precision on GPU, full precision on CPU.
device, torch_dtype = (
    ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)

# Heavy model loading happens once, at import time.
clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device
)
llm_pipe = None
inverted_prompt = ""
torch.cuda.empty_cache()

# Session bookkeeping. These are module globals, so they live for the whole
# process. NOTE(review): likely shared by all concurrent Gradio users.
METHOD = "Experimental"
counter = 1
enable_submit = False
redesign_flag = False
responses_memory = {METHOD: {}}

# (prompt, images) pairs shown in the static "Examples" panel.
example_data = [
    [PROMPTS[scenario], IMAGES[scenario]["ours"]]
    for scenario in (
        "Tourist promotion",
        "Fictional character generation",
        "Interior Design",
    )
]
# =========================
# Image Generation Helpers
# =========================
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Render one image for *prompt* with the module-level ``selected_pipe``.

    Decorated with ``spaces.GPU(duration=65)``, so a GPU is requested for up to
    65 seconds per call.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Forwarded unchanged to the pipeline.
        seed: Seed used when ``randomize_seed`` is False.
        randomize_seed: When True, ``seed`` is ignored and a fresh seed in
            ``[0, MAX_SEED]`` is drawn.
        width, height: Output size in pixels.
        guidance_scale, num_inference_steps: Forwarded to the pipeline.
        progress: Gradio progress tracker (tqdm bars are mirrored to the UI).

    Returns:
        The first image from the pipeline output's ``images``.
    """
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(effective_seed)

    pipe_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=rng,
    )
    with torch.no_grad():
        first_image = selected_pipe(**pipe_kwargs).images[0]
    return first_image
def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask gpt-4o for ``num_prompts`` rewritten variants of *prompt*.

    A random request seed in ``[0, MAX_SEED]`` is drawn per call. The raw API
    output is passed through ``clean_response_gpt`` and that result is
    returned. ``generate_image`` indexes it as a sequence of prompt strings.
    """
    request_seed = random.randint(0, MAX_SEED)
    api_client = init_gpt_api()
    refine_messages = get_refine_msg(prompt, num_prompts)
    raw_response = call_gpt_api(
        refine_messages, api_client, "gpt-4o", request_seed, max_tokens, temperature, top_p
    )
    return clean_response_gpt(raw_response)
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Rewrite *prompt* with gpt-4o, guided by the user's liked/disliked images.

    Args:
        prompt: The user's current prompt.
        history: Prompts from earlier saved rounds. Currently unused; only the
            disabled ``get_personalize_message`` message builder consumed it.
        feedback: Satisfaction ratings from earlier saved rounds. Currently
            unused, for the same reason.
        like_image: The image the user marked as liked.
        dislike_image: The image the user marked as disliked.

    Returns:
        The raw return value of ``call_gpt_api``. Unlike
        ``call_gpt_refine_prompt``, no ``clean_*`` post-processing is applied.
    """
    request_seed = random.randint(0, MAX_SEED)
    api_client = init_gpt_api()
    # The simplified builder is used. The fuller variant,
    # get_personalize_message(prompt, history, feedback, like_image, dislike_image),
    # is intentionally disabled.
    personalize_messages = get_personalized_simplified(prompt, like_image, dislike_image)
    return call_gpt_api(
        personalize_messages,
        api_client,
        "gpt-4o",
        request_seed,
        max_tokens=2000,
        temperature=0.7,
        top_p=0.9,
    )
@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2):
    """Optimize a text prompt against *images* with CLIP via ``optimize_prompt``.

    The result is stored in the module-level ``inverted_prompt``; the function
    itself returns None. It is not referenced elsewhere in this file.

    Args:
        prompt: Passed as ``target_prompts``.
        images: Passed as ``target_images``.
        prompt_len: Length setting forwarded in the params dict.
        iter: Optimization iteration count. The name shadows the builtin but is
            part of the public signature.
        lr: Learning rate.
        batch_size: Batch size.
    """
    global inverted_prompt
    # Caller-tunable settings first, then fixed optimizer / CLIP settings.
    optimizer_settings = dict(
        iter=iter,
        lr=lr,
        batch_size=batch_size,
        prompt_len=prompt_len,
        weight_decay=0.1,
        prompt_bs=1,
        loss_weight=1.0,
        print_step=100,
        clip_model=CLIP_MODEL,
        clip_pretrain=PRETRAINED_CLIP,
    )
    inverted_prompt = optimize_prompt(
        clip_model,
        preprocess,
        optimizer_settings,
        device,
        target_images=images,
        target_prompts=prompt,
    )
# =========================
# UI Helper Functions
# =========================
# Images produced by the most recent generate_image() call, in gallery order.
# on_gallery_select() indexes into this list.
# NOTE(review): a module global, so presumably shared across concurrent sessions.
current_generated_images = list()
def reset_gallery():
    """Return a new, empty gallery value."""
    empty_gallery = list()
    return empty_gallery
def display_error_message(msg, duration=5):
    """Show *msg* as a Gradio warning toast for *duration* seconds."""
    gr.Warning(message=msg, duration=duration)
def display_info_message(msg, duration=5):
    """Show *msg* as a Gradio info toast for *duration* seconds."""
    gr.Info(message=msg, duration=duration)
def check_evaluation(sim_radio, like_image, dislike_image):
    """Return True when the rating and both image picks are all filled in.

    If any of them is empty/falsy, a warning toast is shown and False is
    returned.
    """
    if all((sim_radio, like_image, dislike_image)):
        return True
    display_error_message("❌ Please fill all evaluations before changing image or submitting.")
    return False
def generate_image(prompt, like_image, dislike_image):
    """Generate a new set of images for *prompt*, streaming them to the gallery.

    The prompt is first rewritten with ``personalize_prompt`` when all three of
    these hold:
      - at least one round of feedback has been saved (``redesign``),
      - a liked image is set,
      - a disliked image is set.
    The resulting prompt is then expanded into variants with
    ``call_gpt_refine_prompt``, and up to ``NUM_IMAGES`` of them are rendered
    with ``infer``.

    Args:
        prompt: The user's prompt text.
        like_image: Value of the "Satisfied Image" component (may be None).
        dislike_image: Value of the "Unsatisfied Image" component (may be None).

    Yields:
        The list of images generated so far, once after each new image. If no
        variants come back, the empty list is yielded once so the gallery clears.
    """
    global current_generated_images
    saved_rounds = list(responses_memory[METHOD].values())
    history_prompts = [entry["prompt"] for entry in saved_rounds]
    feedback = [entry["sim_radio"] for entry in saved_rounds]
    print(feedback, like_image, dislike_image)

    if like_image and dislike_image and feedback:
        personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    else:
        personalized = prompt

    gallery_images = []
    current_generated_images = []  # reset the images available for selection

    refined_prompts = call_gpt_refine_prompt(personalized) or []
    # The LLM is asked for several variants but may return fewer than
    # NUM_IMAGES. Render what came back instead of raising IndexError.
    variants = list(refined_prompts)[:NUM_IMAGES]
    if len(variants) < NUM_IMAGES:
        display_error_message(
            f"⚠️ Only {len(variants)} of {NUM_IMAGES} prompt variations were returned; "
            "showing fewer images."
        )
    if not variants:
        yield gallery_images
        return

    for variant in variants:
        img = infer(variant)
        gallery_images.append(img)
        current_generated_images.append(img)  # keep for on_gallery_select
        yield gallery_images
def on_gallery_select(evt: gr.SelectData):
    """Return the generated image the user clicked, or None if unavailable.

    ``evt.index`` is looked up in ``current_generated_images``. None is
    returned when that list is empty or the index is past its end.
    """
    images = current_generated_images
    if not images:
        return None
    if evt.index >= len(images):
        return None
    return images[evt.index]
def handle_like_drag(selected_image):
    """Pass the currently selected image through to the "liked" slot unchanged."""
    liked = selected_image
    return liked
def handle_dislike_drag(selected_image):
    """Pass the currently selected image through to the "disliked" slot unchanged."""
    disliked = selected_image
    return disliked
def redesign(prompt, sim_radio, current_images, history_images, like_image, dislike_image):
    """Save this round's feedback so the next generation can be personalized.

    When the rating and both image picks are present (``check_evaluation``),
    this function:
      - records the round in ``responses_memory[METHOD][counter]``,
      - moves the current images into the history gallery,
      - refreshes the prompt-history examples,
      - advances ``counter``.
    Saving is disabled once ``counter`` reaches ``MAX_ROUND``.

    Returns:
        A 7-tuple of values/updates for, in order: sim_radio, images_method,
        history_images, example.dataset, prompt, next_btn, redesign_btn. When
        validation fails, every entry is ``gr.skip()`` so nothing changes.
    """
    global counter, redesign_flag
    if not check_evaluation(sim_radio, like_image, dislike_image):
        # check_evaluation already warned the user; leave all outputs untouched.
        # Build distinct skip objects rather than repeating one instance.
        return tuple(gr.skip() for _ in range(7))

    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, liked image",
        "unsatisfied_img": f"round {counter}, disliked image",
    }
    # gr.Examples expects one row (a list) per sample.
    history_prompts = [[entry["prompt"]] for entry in responses_memory[METHOD].values()]

    # Append the just-rated images to the history gallery.
    history_images = list(history_images or []) + list(current_images or [])

    examples_state = gr.update(samples=history_prompts, visible=True)
    prompt_state = gr.update(interactive=True)
    next_state = gr.update(visible=True, interactive=True)
    # Disable further saves once MAX_ROUND rounds have been recorded.
    redesign_state = gr.update(interactive=counter < MAX_ROUND)

    saved_round = counter
    counter += 1
    redesign_flag = True
    display_info_message(f"✅ Round {saved_round} feedback saved! You can continue redesigning or restart.")
    # None clears the rating; [] clears the main gallery.
    return None, [], history_images, examples_state, prompt_state, next_state, redesign_state
def save_response(prompt, sim_radio, like_image, dislike_image):
    """Restart the session: clear saved feedback and reset every UI component.

    The four arguments are supplied by the "Restart Session" button wiring but
    are not used.

    Returns:
        A 9-tuple of updates for, in order: sim_radio, prompt, next_btn,
        redesign_btn, like_image, dislike_image, images_method, history_images,
        example.dataset.
    """
    global counter, redesign_flag, current_generated_images
    # Reset session globals.
    responses_memory[METHOD] = {}
    counter = 1
    redesign_flag = False
    current_generated_images = []

    display_info_message("🔄 Session restarted! You can begin with a new prompt.")
    return (
        gr.update(value=None),                      # sim_radio
        gr.update(value="", interactive=True),      # prompt
        gr.update(visible=True, interactive=True),  # next_btn
        gr.update(interactive=False),               # redesign_btn
        gr.update(value=None),                      # like_image
        gr.update(value=None),                      # dislike_image
        [],                                         # images_method
        [],                                         # history_images
        gr.update(samples=[['']], visible=True),    # example.dataset
    )
# =========================
# Interface (single tab, no participant/scenario/background)
# =========================
# Custom stylesheet passed to gr.Blocks(css=css) below. "#..." selectors match
# component elem_id values; ".xxx" selectors match elem_classes values or class
# attributes inside the embedded HTML snippets.
# NOTE(review): #compact-compact-row, .authors-title and .personalization-header
# do not appear to be referenced anywhere in this file.
css = """
#col-container {
margin: 0 auto;
max-width: 700px;
}
#col-container2 {
margin: 0 auto;
max-width: 1000px;
}
#col-container3 {
margin: 0 0 auto auto;
max-width: 300px;
}
#button-container {
display: flex;
justify-content: center;
gap: 10px;
}
#compact-compact-row {
width:100%;
max-width: 800px;
margin: 0px auto;
}
#compact-row {
width:100%;
max-width: 1000px;
margin: 0px auto;
}
.header-section {
text-align: center;
margin-bottom: 2rem;
}
.abstract-text {
text-align: justify;
line-height: 1.5;
margin: 0rem 0;
padding: 0 0.5rem;
background-color: rgba(0, 0, 0, 0.05);
border-radius: 8px;
border-left: 4px solid #3498db;
}
.paper-link {
display: inline-block;
margin: 0rem 0;
padding: 0rem 0rem;
background-color: #3498db;
color: white;
text-decoration: none;
border-radius: 5px;
font-weight: 500;
}
.paper-link:hover {
background-color: #2980b9;
text-decoration: none;
}
.authors-section {
text-align: center;
margin: 0 0;
font-style: italic;
color: #666;
}
.authors-title {
font-weight: bold;
margin-bottom: 0rem;
color: #333;
}
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 60px;
width: auto;
max-width: 150px;
display: inline-block;
}
.instruction-box {
background: linear-gradient(135deg, #e8f4fd 0%, #f0f8ff 100%);
border: 2px solid #3498db;
border-radius: 12px;
padding: 20px;
margin: 15px 0;
color: #2c3e50;
}
.instruction-title {
font-size: 1.2em;
font-weight: bold;
margin-bottom: 15px;
color: #2c3e50;
display: flex;
align-items: center;
gap: 8px;
}
.step-list {
list-style: none;
padding: 0;
margin: 0;
}
.step-item {
background: rgba(52, 152, 219, 0.1);
border-radius: 8px;
padding: 12px 16px;
margin: 8px 0;
border-left: 4px solid #3498db;
}
.step-number {
font-weight: bold;
color: #3498db;
margin-right: 8px;
}
.personalization-header {
background: linear-gradient(135deg, #ff6b6b, #ee5a24);
color: white;
padding: 15px;
border-radius: 10px 10px 0 0;
margin: -10px -10px 15px -10px;
text-align: center;
font-weight: bold;
font-size: 1.1em;
}
"""
# Build the single-page Gradio UI. Components are laid out first; event
# handlers are wired at the bottom, still inside the Blocks context.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    # Holds the image most recently clicked in the main gallery; the
    # Like/Dislike buttons copy it into their image slots.
    selected_image = gr.State(None)

    # ---- Header: logo, title, paper link, abstract, authors ----
    with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
        gr.HTML('<div class="logo-container"><img src="https://huggingface.co/spaces/PAI-GEN/POET/resolve/main/images/icon.png" alt="POET Logo"></div>')
        gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
        # Paper Link
        gr.HTML("""
<div style="text-align: center;">
<a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
📄 Read the Full Paper
</a>
</div>
""")
        gr.Markdown("""
<div class="abstract-text">
<strong>Abstract:</strong> Given that creative end-users often operate in diverse, context-specific ways that are often unpredictable, more variation and personalization are necessary. We introduce POET, a real-time interactive tool that (1) automatically discovers dimensions of homogeneity in text-to-image generative models, (2) expands these dimensions to diversify the output space of generated images, and (3) learns from user feedback to personalize expansions. Focusing on visual creativity, POET offers a first glimpse of how interaction techniques of future text-to-image generation tools may support and align with more pluralistic values and the needs of end-users during the ideation stages of their work.
</div>
""", elem_classes=["abstract-text"])
        gr.Markdown("""
<div class="authors-section">
<a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>,
<a href="https://www.aliceqian.com/">Alice Qian Zhang</a>,
<a href="https://haiyizhu.com/">Haiyi Zhu</a>,
<a href="https://www.andrew.cmu.edu/user/hongs/">Hong Shen</a>,
<a href="https://pliang279.github.io/">Paul Pu Liang</a>,
<a href="https://janeon.github.io/">Jane Hsieh</a>
</div>
""", elem_classes=["authors-section"])

    with gr.Tab(""):
        # ---- Prompt input + Generate button ----
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="🎨 Prompt",
                        max_lines=5,
                        placeholder="Enter your prompt",
                        visible=True,
                    )
            with gr.Column(elem_id="col-container3"):
                next_btn = gr.Button("Generate", variant="primary", scale=1)

        # ---- Generated images + liked/disliked picks ----
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                images_method = gr.Gallery(
                    label="Generated Images (Click to select, then set to Like/Dislike image)",
                    columns=[4],
                    rows=[1],
                    height=400,
                    interactive=False,
                    elem_id="gallery",
                    format="png"
                )
            with gr.Column(elem_id="col-container3"):
                like_btn = gr.Button("👍 Set as Liked (Optional for personalization)", size="sm", variant="secondary")
                like_image = gr.Image(
                    label="Satisfied Image",
                    width=150,
                    height=150,
                    interactive=False,
                    format="png",
                    type="filepath"
                )
                dislike_btn = gr.Button("👎 Set as Disliked (Optional for personalization)", size="sm", variant="secondary")
                dislike_image = gr.Image(
                    label="Unsatisfied Image",
                    width=150,
                    height=150,
                    interactive=False,
                    format="png",
                    type="filepath"
                )

        # ---- Personalization: rating, save/restart, history ----
        with gr.Accordion("🎯 Advanced: Personalized Image Redesign", open=False, elem_id="col-container2"):
            gr.HTML("""
<div class="instruction-box">
<div class="instruction-title">
📋 How to Use Personalized Redesign
</div>
<div class="step-list">
<div class="step-item">
<span class="step-number">1.</span>
<strong>Rate Your Satisfaction:</strong> Provide a satisfaction score for the current generated images
</div>
<div class="step-item">
<span class="step-number">2.</span>
<strong>Select Preferences:</strong> Choose your most liked and disliked images
</div>
<div class="step-item">
<span class="step-number">3.</span>
<strong>Save & Iterate:</strong> Click "Save Personalization Data" before redesigning your prompt and clicking "Generate"
</div>
<div class="step-item">
<span class="step-number">4.</span>
<strong>Restart Anytime:</strong> Use the "Restart Session" button to begin a fresh session
</div>
</div>
</div>
""")
            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### 📊 Rate Current Generation")
                with gr.Row():
                    sim_radio = gr.Radio(
                        OPTIONS,
                        label="How satisfied are you with the current generated images?",
                        type="value",
                        show_label=True,
                        container=True,
                        scale=1
                    )
                with gr.Row(elem_id="button-container"):
                    with gr.Column(scale=1):
                        redesign_btn = gr.Button("💾 Save Personalization Data", variant="primary", size="lg")
                    with gr.Column(scale=1):
                        submit_btn = gr.Button("🔄 Restart Session", variant="secondary", size="lg")
            with gr.Column(elem_id="col-container2"):
                example = gr.Examples([['']], prompt, label="📝 Prompt History", visible=True)
                history_images = gr.Gallery(
                    label="🗃️ Generation History",
                    columns=[4],
                    rows=[1],
                    # Distinct id: "gallery" is already used by images_method.
                    elem_id="history-gallery",
                    format="png",
                    interactive=False,
                )

        # ---- Static examples (hidden images feed the Examples table) ----
        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 🌟 Examples")
            ex1 = gr.Image(label="Image 1", width=200, height=200, format="png", type="filepath", visible=False)
            ex2 = gr.Image(label="Image 2", width=200, height=200, format="png", type="filepath", visible=False)
            ex3 = gr.Image(label="Image 3", width=200, height=200, format="png", type="filepath", visible=False)
            ex4 = gr.Image(label="Image 4", width=200, height=200, format="png", type="filepath", visible=False)
            gr.Examples(
                examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
                inputs=[prompt, ex1, ex2, ex3, ex4]
            )

    # =========================
    # Wiring
    # =========================
    # Handle gallery selection
    images_method.select(
        fn=on_gallery_select,
        inputs=[],
        outputs=[selected_image]
    )
    # Handle like/dislike button clicks
    like_btn.click(
        fn=handle_like_drag,
        inputs=[selected_image],
        outputs=[like_image]
    )
    dislike_btn.click(
        fn=handle_dislike_drag,
        inputs=[selected_image],
        outputs=[dislike_image]
    )
    # Generate images, then (on success) make sure the controls are enabled.
    next_btn.click(
        fn=generate_image,
        inputs=[prompt, like_image, dislike_image],
        outputs=[images_method]
    ).success(
        lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)],
        outputs=[next_btn, prompt, redesign_btn, submit_btn]
    )
    redesign_btn.click(
        fn=redesign,
        inputs=[prompt, sim_radio, images_method, history_images, like_image, dislike_image],
        outputs=[sim_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn]
    )
    submit_btn.click(
        fn=save_response,
        inputs=[prompt, sim_radio, like_image, dislike_image],
        outputs=[sim_radio, prompt, next_btn, redesign_btn, like_image, dislike_image, images_method, history_images, example.dataset]
    )
def _launch_demo():
    """Start the Gradio server for the POET demo."""
    demo.launch()


if __name__ == "__main__":
    _launch_demo()