| | import base64 |
| | from dataclasses import dataclass |
| | from io import BytesIO |
| | from pathlib import Path |
| | from typing import Literal, cast |
| |
|
| | import gradio as gr |
| | import jinja2 |
| | from openai import OpenAI |
| | from PIL import Image |
| | from pydantic import BaseModel |
| |
|
# Shared OpenAI client, constructed once at import time and used by
# process_prompt and recommend_style.
client = OpenAI()

# Prompt templates are loaded from the "templates" directory that sits next to
# this module.
TEMPLATES_DIR = Path(__file__).resolve().parent.joinpath("templates")
jinja_env = jinja2.Environment(
    loader=jinja2.FileSystemLoader(str(TEMPLATES_DIR)),
)

# System-role message sent by process_prompt with every generation request.
SYSTEM_PROMPT = "You are expert prompt engineer"

# Closed set of style names accepted throughout this module.
StyleName = Literal[
    "General",
    "Fashion",
    "Emotional Lifestyle",
    "Extreme Sports",
    "Captivating",
    "Image Replication",
    "Red Bar Lighting",
    "Teal Noir",
]
| |
|
| |
|
@dataclass(frozen=True)
class StyleDefinition:
    """Immutable description of one selectable prompt style."""

    # Style identifier; one of the StyleName literals.
    name: StyleName
    # File name of the Jinja template used to render this style's prompt.
    template_filename: str
    # Human-readable summary of the style, listed in the style guide text.
    info: str
| |
|
| |
|
# Every supported style, declared once. The order here is the order in which
# styles appear in the dropdown choices and in the style guide text.
_STYLE_CATALOG: tuple[StyleDefinition, ...] = (
    StyleDefinition(
        name="General",
        template_filename="general_prompt.jinja",
        info="Versatile, balanced storytelling with cinematic detail for most scenarios.",
    ),
    StyleDefinition(
        name="Fashion",
        template_filename="fashion_prompt.jinja",
        info="Editorial fashion aesthetic highlighting garments, styling, and runway polish.",
    ),
    StyleDefinition(
        name="Emotional Lifestyle",
        template_filename="emotional_lifestyle_prompt.jinja",
        info="Warm, candid lifestyle imagery that focuses on mood, relationships, and feelings.",
    ),
    StyleDefinition(
        name="Extreme Sports",
        template_filename="extreme_sports_prompt.jinja",
        info="High-adrenaline action shots that emphasize energy, motion, and athletic feats.",
    ),
    StyleDefinition(
        name="Captivating",
        template_filename="captivating_prompt.jinja",
        info="Visually striking compositions with dramatic flair and memorable storytelling.",
    ),
    StyleDefinition(
        name="Image Replication",
        template_filename="image_replication_prompt.jinja",
        info=(
            "Mimic the reference image's composition, lighting, and styling exactly"
            " while inserting the user or their face in place of the original subject."
            " Eg. If the reference image is a music album cover, the user's face will"
            " be embedded in the album cover."
        ),
    ),
    StyleDefinition(
        name="Red Bar Lighting",
        template_filename="red_bar_lighting_prompt.jinja",
        info="Red bar lighting style for image generation.",
    ),
    StyleDefinition(
        name="Teal Noir",
        template_filename="teal_noir_prompt.jinja",
        info="Teal noir style for image generation.",
    ),
)

# Lookup table keyed by style name; keys are taken from each definition so the
# key and the definition's own name cannot drift apart.
STYLE_DEFINITIONS: dict[StyleName, StyleDefinition] = {
    definition.name: definition for definition in _STYLE_CATALOG
}
| |
|
# Load every style's template once at import time, keyed by style name.
PROMPT_TEMPLATES = {
    style_name: jinja_env.get_template(definition.template_filename)
    for style_name, definition in STYLE_DEFINITIONS.items()
}

# Style pre-selected in the dropdown.
DEFAULT_STYLE: StyleName = "General"

# All style names, in declaration order.
STYLE_CHOICES: tuple[StyleName, ...] = tuple(STYLE_DEFINITIONS)

# Bulleted "- <name>: <info>" list, one line per style, embedded in the
# style-recommendation prompt.
_style_summary_lines = [
    f"- {style_name}: {definition.info}"
    for style_name, definition in STYLE_DEFINITIONS.items()
]
STYLE_INFORMATION_BLOCK = "\n".join(_style_summary_lines)
| |
|
| |
|
class StyleSelectionResponse(BaseModel):
    """Structured-output schema used when asking the model to pick a style."""

    # The single chosen style; must be one of the StyleName literals.
    style: StyleName
| | |
| |
|
def _encode_image_as_data_url(image) -> str | None:
    """Return *image* as a base64 JPEG ``data:`` URL, or ``None`` when no image is given.

    *image* is expected to expose PIL-style ``convert`` and ``save`` methods.
    """
    if image is None:
        return None
    buffer = BytesIO()
    # JPEG cannot store an alpha channel, so normalise to RGB before saving.
    image.convert("RGB").save(buffer, format="JPEG", quality=90)
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"


def process_prompt(
    user_image,
    reference_image,
    target_label: str,
    user_prompt: str,
    style: StyleName,
) -> str:
    """Generate an enhanced prompt for *user_prompt* in the requested *style*.

    The style's template is rendered with ``user_prompt`` and sent to the model
    together with whichever images were supplied (user photo first, then the
    reference image).

    Args:
        user_image: Optional PIL-style image of the user, or ``None``.
        reference_image: Optional PIL-style reference image, or ``None``.
        target_label: Label appended to the generated text; surrounding
            whitespace is stripped and a blank or ``None`` label is skipped.
        user_prompt: The user's free-text request.
        style: Name of the style whose template should be used.

    Returns:
        The model's output text, followed by a space and the label when the
        label is non-blank.

    Raises:
        ValueError: If *style* has no registered template.
    """
    # Validate the style first so a bad value fails before any image encoding
    # or network call is made.
    try:
        template = PROMPT_TEMPLATES[style]
    except KeyError as error:
        raise ValueError(f"Unsupported style: {style}") from error

    content = [
        {"type": "input_text", "text": template.render(user_prompt=user_prompt)},
    ]
    # Keep the original attachment order: user photo, then reference image.
    for image in (user_image, reference_image):
        image_url = _encode_image_as_data_url(image)
        if image_url is not None:
            content.append({"type": "input_image", "image_url": image_url})

    response = client.responses.create(
        model="gpt-5",
        reasoning={"effort": "minimal"},
        input=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": content},
        ],
    )

    # Avoid a dangling trailing space when no label was entered.
    label = (target_label or "").strip()
    if not label:
        return response.output_text
    return f"{response.output_text} {label}"
| |
|
| |
|
def recommend_style(user_prompt: str, reference_image: Image.Image | None) -> StyleName:
    """Ask the model which registered style best fits the user's request.

    Args:
        user_prompt: The user's free-text request.
        reference_image: Optional reference image attached to the request.

    Returns:
        The name of the chosen style.

    Raises:
        RuntimeError: If the response contains no parsed structured result.
    """
    guidance = "Consider the available styles and choose the single best option."
    if reference_image is not None:
        # Only tell the model a reference image exists when one is attached.
        guidance += " User has provided the reference image."

    # Built under a separate name so the caller's prompt is not overwritten.
    instruction = (
        "You are an art director who must pick the most fitting style name for a user's prompt.\n"
        f"{guidance}\n"
        "\n"
        "Style Guide:\n"
        f"{STYLE_INFORMATION_BLOCK}\n"
        "\n"
        "User Prompt:\n"
        f"{user_prompt}\n"
    )

    content = [{"type": "input_text", "text": instruction}]
    if reference_image is not None:
        buffer = BytesIO()
        # JPEG cannot store an alpha channel, so normalise to RGB before saving.
        reference_image.convert("RGB").save(buffer, format="JPEG", quality=90)
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        content.append(
            {"type": "input_image", "image_url": f"data:image/jpeg;base64,{encoded}"}
        )

    completion = client.responses.parse(
        model="gpt-5-mini",
        reasoning={"effort": "low"},
        input=[{"role": "user", "content": content}],
        text_format=StyleSelectionResponse,
    )

    selection = completion.output_parsed
    if selection is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError("Style recommendation returned no structured result")
    return selection.style
| |
|
| |
|
def handle_auto_style_toggle(auto_enabled: bool) -> dict[str, object]:
    """Build the component update applied when the auto-style checkbox changes.

    The updated component is made interactive exactly when auto-selection is
    switched off.
    """
    dropdown_editable = not auto_enabled
    return gr.update(interactive=dropdown_editable)
| |
|
| |
|
def generate_prompt_handler(
    user_image,
    reference_image,
    target_label: str,
    user_prompt: str,
    current_style: str | None,
    auto_style_enabled: bool,
) -> tuple[str, dict[str, object]]:
    """Handle a click on the generate button.

    Args:
        user_image: Optional user photo passed through to ``process_prompt``.
        reference_image: Optional reference image, used both for style
            recommendation and for prompt generation.
        target_label: Label appended to the generated prompt.
        user_prompt: The user's free-text request.
        current_style: Style currently selected in the dropdown; ignored when
            auto-selection is enabled.
        auto_style_enabled: Whether the style should be chosen by
            ``recommend_style`` instead of the dropdown.

    Returns:
        A ``(display_text, dropdown_update)`` pair: the text to show, and an
        update that sets the dropdown to the style that was used.

    Raises:
        gr.Error: If auto-selection is off and no style is selected.
    """
    if auto_style_enabled:
        current_style = recommend_style(user_prompt, reference_image)
    elif not current_style:
        # Show an actionable message instead of an opaque "Unsupported style: None".
        raise gr.Error("Select a style or enable auto-select before generating.")

    prompt_text = process_prompt(
        user_image=user_image,
        reference_image=reference_image,
        target_label=target_label,
        user_prompt=user_prompt,
        style=current_style,
    )
    display_text = f"Selected style: {current_style}\n\n{prompt_text}"

    # Lock the dropdown only while auto-selection is on; a manual choice must
    # stay editable after generating.
    dropdown_update = gr.update(
        value=current_style,
        interactive=not auto_style_enabled,
    )
    return display_text, dropdown_update
| |
|
| |
|
# UI definition: one row with two columns. The first column holds the inputs
# and controls, the second shows the generated prompt.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            user_image = gr.Image(type="pil", label="Upload user photo")
            reference_image = gr.Image(
                type="pil",
                label="Optional: Upload reference image (Eg. movie poster, music album cover, etc.)",
            )
            target_label = gr.Textbox(placeholder="SMRA", label="Enter target label")
            user_prompt = gr.Textbox(
                lines=4,
                label="Enter your prompt",
                placeholder="picture of me while sitting in a chair in the ocean",
            )
            # Starts non-interactive because the auto-select checkbox below
            # defaults to checked.
            style_dropdown = gr.Dropdown(
                label="Style Selection",
                info="Choose the visual style for your enhanced prompt",
                choices=list(STYLE_CHOICES),
                value=DEFAULT_STYLE,
                interactive=False,
            )
            auto_style_checkbox = gr.Checkbox(value=True, label="Auto-select best style")
            generate_button = gr.Button("Generate Prompt")
        with gr.Column():
            prompt_output = gr.Textbox(lines=20, label="Style Prompt")

    # Generate: reads every input and writes the prompt text plus the style
    # that was actually used back into the dropdown.
    generate_button.click(
        fn=generate_prompt_handler,
        inputs=[
            user_image,
            reference_image,
            target_label,
            user_prompt,
            style_dropdown,
            auto_style_checkbox,
        ],
        outputs=[prompt_output, style_dropdown],
    )

    # Toggling auto-select switches the dropdown's interactivity.
    auto_style_checkbox.change(
        fn=handle_auto_style_toggle,
        inputs=[auto_style_checkbox],
        outputs=[style_dropdown],
    )

# NOTE(review): launch() runs whenever this module is executed or imported;
# consider an `if __name__ == "__main__":` guard if the module is ever imported.
demo.launch()
| |
|