ugly-holiday-card-generator / pages /1_Part_A_-_Experimentation_Station.py
Mikiko Bazeley
Wrapping up documentation, copy
7c1b343
raw
history blame
19.8 kB
import streamlit as st
import cv2
import requests
from io import BytesIO
from PIL import Image, ImageDraw
import numpy as np
import os
from dotenv import load_dotenv
import zipfile
# Read settings from the .env file that sits in ../env relative to this page;
# override=True is passed so the file's values are applied on every load.
dotenv_path = os.path.join(os.path.dirname(__file__), '../env/.env')
load_dotenv(dotenv_path, override=True)

# The key is sent as the bearer token on every API request, so without it the
# page reports the problem and stops the script run right here.
api_key = os.environ.get("FIREWORKS_API_KEY")
if not api_key:
    st.error("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
    st.stop()
# Aspect ratios this page may send in the request's `aspect_ratio` field,
# keyed by (width, height) and mapped to the "W:H" label that is sent.
VALID_ASPECT_RATIOS = {
    (1, 1): "1:1",
    (21, 9): "21:9",
    (16, 9): "16:9",
    (3, 2): "3:2",
    (5, 4): "5:4",
    (4, 5): "4:5",
    (2, 3): "2:3",
    (9, 16): "9:16",
    (9, 21): "9:21",
}


def get_closest_aspect_ratio(width, height):
    """Return the "W:H" label whose ratio is nearest to ``width / height``.

    Distance is the absolute difference between the two quotients. When two
    candidates are equally close, the one listed first in
    ``VALID_ASPECT_RATIOS`` is returned.
    """
    target = width / height
    best_pair = None
    best_gap = None
    for pair in VALID_ASPECT_RATIOS:
        gap = abs(pair[0] / pair[1] - target)
        # Strict comparison keeps the earliest candidate on ties.
        if best_gap is None or gap < best_gap:
            best_pair, best_gap = pair, gap
    return VALID_ASPECT_RATIOS[best_pair]
def process_image(uploaded_image):
    """Turn the uploaded picture into a Canny edge map.

    The picture is opened with PIL, reduced to single-channel grayscale,
    passed through ``cv2.Canny`` with thresholds 100 and 200, and expanded
    back to three channels.

    Returns:
        ``(jpeg_buffer, edge_image)``: a ``BytesIO`` holding the edge map
        encoded as JPEG and positioned at offset 0, and the same edge map as
        a PIL image.
    """
    grayscale = Image.open(uploaded_image).convert('L')
    edge_map = cv2.Canny(np.array(grayscale), 100, 200)
    edge_image = Image.fromarray(cv2.cvtColor(edge_map, cv2.COLOR_GRAY2RGB))

    jpeg_buffer = BytesIO()
    edge_image.save(jpeg_buffer, format='JPEG')
    jpeg_buffer.seek(0)  # rewind so the next reader starts at the first byte
    return jpeg_buffer, edge_image
def call_control_net_api(uploaded_image, prompt, control_mode=0, guidance_scale=3.5, num_inference_steps=30, seed=0, controlnet_conditioning_scale=1.0, timeout=(10, 300)):
    """Request a ControlNet-guided image for the uploaded picture.

    The upload is converted to an edge map by ``process_image`` and posted,
    together with the generation settings, to the Fireworks
    ``flux-1-dev-controlnet-union`` endpoint.

    Args:
        uploaded_image: File-like object holding the source picture. It is
            opened twice: once for the edge map and once to read its size.
        prompt: Text prompt forwarded to the model.
        control_mode: Forwarded unchanged as a form field.
        guidance_scale: Forwarded unchanged as a form field.
        num_inference_steps: Forwarded unchanged as a form field.
        seed: Forwarded unchanged as a form field.
        controlnet_conditioning_scale: Forwarded unchanged as a form field.
        timeout: Passed to ``requests.post``; either a number of seconds or
            a ``(connect, read)`` tuple. Without a timeout a stalled
            connection would block the script run indefinitely.

    Returns:
        ``(generated_image, processed_image, original_image)`` when the
        endpoint answers with HTTP 200. In every other case, including
        network failures and timeouts, the error is shown with ``st.error``
        and ``(None, None, None)`` is returned.
    """
    control_image, processed_image = process_image(uploaded_image)
    files = {'control_image': ('control_image.jpg', control_image, 'image/jpeg')}

    # The request carries the supported aspect ratio closest to the upload's.
    original_image = Image.open(uploaded_image)
    width, height = original_image.size
    aspect_ratio = get_closest_aspect_ratio(width, height)

    data = {
        'prompt': prompt,
        'control_mode': control_mode,
        'aspect_ratio': aspect_ratio,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'seed': seed,
        'controlnet_conditioning_scale': controlnet_conditioning_scale
    }
    headers = {
        'accept': 'image/jpeg',
        'authorization': f'Bearer {api_key}',
    }

    try:
        response = requests.post(
            'https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-controlnet-union/control_net',
            files=files, data=data, headers=headers, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        # Connection errors and timeouts take the same failure path as a
        # non-200 answer instead of surfacing as an unhandled traceback.
        st.error(f"Request failed before a response was received: {exc}")
        return None, None, None

    if response.status_code != 200:
        st.error(f"Request failed with status code: {response.status_code}, Response: {response.text}")
        return None, None, None

    return Image.open(BytesIO(response.content)), processed_image, original_image
def draw_crop_preview(image, x, y, width, height):
    """Outline the crop area on ``image`` with a 2-pixel red rectangle.

    The rectangle runs from ``(x, y)`` to ``(x + width, y + height)``. The
    drawing is done directly on the image passed in, and that same object is
    returned, so callers that want to keep the original should pass a copy.
    """
    top_left = (x, y)
    bottom_right = (x + width, y + height)
    ImageDraw.Draw(image).rectangle([top_left, bottom_right], outline="red", width=2)
    return image
# Branding: the Fireworks logo is the first element rendered on the page
logo_image = Image.open("img/fireworksai_logo.png")
st.image(logo_image)

st.title("🎨 Holiday Card Generator - Part A: Design & Ideation 🎨")
st.markdown(
    """Welcome to the first part of your holiday card creation journey! 🌟 Here, you’ll play around with different styles, prompts, and parameters to design the perfect card border before adding a personal message in Part B. Let your creativity flow! 🎉
### How it works:
1. **🖼️ Upload Your Image:** Choose the image that will be the center of your card.
2. **✂️ Crop It:** Adjust the crop to highlight the most important part of your image.
3. **💡 Choose Your Style:** Select from festive border themes or input your own custom prompt to design something unique.
4. **⚙️ Fine-Tune Parameters:** Experiment with guidance scales, seeds, inference steps, and more for the perfect aesthetic.
5. **👀 Preview & Download:** See your generated holiday cards, tweak them until they’re just right, and download the final designs and metadata in a neat ZIP file!
Once you’ve got the perfect look, head over to **Part B** to add your personal message and finalize your holiday card! 💌
"""
)

# Step 1 heading and instructions for the upload widget
st.divider()
st.subheader("🖼️ Step 1: Upload Your Picture!")
st.markdown("""
- Click on the **Upload Image** button to select the image you want to feature on your holiday card.
- Accepted formats: **JPG** or **PNG**.
- Once uploaded, your image will appear on the screen—easy peasy!
""")
st.write("Upload the image that will be used for generating your holiday card.")

# Seed the per-session storage only when a key is missing, so values written
# during an earlier run of this script are kept
if 'uploaded_file' not in st.session_state:
    st.session_state['uploaded_file'] = None
if 'card_params' not in st.session_state:
    st.session_state['card_params'] = [{} for _ in range(4)]  # one parameter dict per card
if 'generated_cards' not in st.session_state:
    st.session_state['generated_cards'] = [None] * 4  # one slot per card for image + metadata

# Upload widget; a new upload replaces the stored file, while an empty widget
# leaves whatever is already stored untouched
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
st.divider()
if uploaded_file is not None:
    st.session_state['uploaded_file'] = uploaded_file
# Main flow of the page: preview the stored upload, let the user pick a crop,
# collect generation parameters for four cards, call the API once per card on
# demand, composite the crop onto each result, and offer the results as a ZIP.
# Load the image from session state
if st.session_state.uploaded_file is not None:
original_image = Image.open(st.session_state.uploaded_file)
st.image(original_image, caption="Uploaded Image", use_column_width=True)
# The pixel dimensions define the ranges of the crop sliders below
img_width, img_height = original_image.size
# Add a divider and subheader for the crop section
st.divider()
st.subheader("✂️ Step 2: Crop It Like It's Hot!")
st.markdown("""
- Adjust the **crop** sliders to select the perfect area of your image.
- This cropped section will be the centerpiece of your festive card. 🎨
- A preview will show you exactly how your crop looks before moving to the next step!
""")
st.write("Select the area you want to crop from the original image. This cropped portion will be used in the final card.")
# Two columns: horizontal position/width on the left, vertical position/height on the right.
# NOTE(review): the position sliders run up to the full image dimension, while the
# size sliders use a minimum of 10 and a maximum of (dimension - position). With the
# position near its upper end that maximum drops below 10 (down to 0) and the default
# value can fall below the minimum — confirm how st.slider reacts and consider capping
# the position range.
col1, col2 = st.columns(2)
with col1:
x_pos = st.slider("X position (Left-Right)", 0, img_width, img_width // 4,
help="Move the slider to adjust the crop's left-right position.")
crop_width = st.slider("Width", 10, img_width - x_pos, min(img_width // 2, img_width - x_pos),
help="Adjust the width of the crop.")
with col2:
y_pos = st.slider("Y position (Up-Down)", 0, img_height, img_height // 4,
help="Move the slider to adjust the crop's up-down position.")
crop_height = st.slider("Height", 10, img_height - y_pos, min(img_height // 2, img_height - y_pos),
help="Adjust the height of the crop.")
# The outline is drawn on a copy so original_image stays unmarked for the later crop
preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
st.image(preview_image, caption="Crop Preview", use_column_width=True)
st.divider()
st.subheader("⚙️ Step 3: Set Your Festive Border Design with Flux + Fireworks!")
st.markdown("""
- Choose from a selection of holiday-themed borders like **snowflakes**, **Christmas lights**, or even **New Year's Eve fireworks**—all generated through **Flux models on Fireworks!** ✨
- Want to get creative? No problem—enter your own **custom prompt** to create a border that's as unique as you are, all powered by **Fireworks' inference API**.
""")
st.write("Customize the parameters for generating each holiday card. Each parameter will influence the final design and style of the card.")
# Define the list of suggested holiday prompts
holiday_prompts = [
"A border of Festive snowflakes and winter patterns for a holiday card border",
"A border of Joyful Christmas ornaments and lights decorating the edges",
"A border of Warm and cozy fireplace scene with stockings and garlands",
"A border of Colorful Hanukkah menorahs and dreidels along the border",
"A border of New Year's Eve fireworks with stars and confetti framing the image"
]
# Define input fields for each holiday card's parameters in expanders
for i in range(4):
with st.expander(f"Holiday Card {i + 1} Parameters"):
st.write(f"### Holiday Card {i + 1}")
# Get the card params from session_state if they exist
card_params = st.session_state.card_params[i]
# Fill in defaults only for keys that are not stored yet; the default seed
# differs per card (0, 100, 200, 300)
card_params.setdefault("prompt", f"Custom Prompt {i + 1}")
card_params.setdefault("guidance_scale", 3.5)
card_params.setdefault("num_inference_steps", 30)
card_params.setdefault("seed", i * 100)
card_params.setdefault("controlnet_conditioning_scale", 0.5)
card_params.setdefault("control_mode", 0)
# Dropdown to choose a suggested holiday prompt or enter custom prompt
selected_prompt = st.selectbox(f"Choose a holiday-themed prompt for Holiday Card {i + 1}", options=["Custom"] + holiday_prompts)
# The text box is only rendered for "Custom"; otherwise the chosen suggestion is used verbatim as the prompt
custom_prompt = st.text_input(f"Enter custom prompt for Holiday Card {i + 1}", value=card_params["prompt"]) if selected_prompt == "Custom" else selected_prompt
# Parameter sliders for each holiday card with tooltips
guidance_scale = st.slider(f"Guidance Scale for Holiday Card {i + 1}", min_value=0.0, max_value=20.0, value=card_params["guidance_scale"], step=0.1,
help="Adjusts how strongly the model follows the prompt. Higher values mean stronger adherence to the prompt.")
num_inference_steps = st.slider(f"Number of Inference Steps for Holiday Card {i + 1}", min_value=1, max_value=100, value=card_params["num_inference_steps"], step=1,
help="More inference steps can lead to a higher quality image but will take longer to generate.")
seed = st.slider(f"Random Seed for Holiday Card {i + 1}", min_value=0, max_value=1000, value=card_params["seed"],
help="The seed value allows you to recreate the same image each time, or explore variations by changing the seed.")
controlnet_conditioning_scale = st.slider(f"ControlNet Conditioning Scale for Holiday Card {i + 1}", min_value=0.0, max_value=1.0, value=card_params["controlnet_conditioning_scale"], step=0.1,
help="Controls how much the ControlNet input influences the output. A lower value reduces its influence.")
control_mode = st.slider(f"Control Mode for Holiday Card {i + 1}", min_value=0, max_value=2, value=card_params["control_mode"],
help="Choose how much ControlNet should influence the final image. 0: None, 1: Partial, 2: Full.")
# Update session state with the latest parameters
st.session_state.card_params[i] = {
"prompt": custom_prompt,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"seed": seed,
"controlnet_conditioning_scale": controlnet_conditioning_scale,
"control_mode": control_mode
}
st.divider() # Add a divider before the "Generate" button
# Generate the holiday cards
st.subheader("Just hit play!")
# Generation only happens on the run in which this button was clicked
if st.button("Generate Holiday Cards"):
with st.spinner("Processing..."):
# Create a column layout for displaying cards side by side
col1, col2 = st.columns(2)
col3, col4 = st.columns(2)
columns = [col1, col2, col3, col4] # To display images in a 2x2 grid
# Store image bytes and metadata in lists
image_files = []
metadata = []
# Loop through each card's parameters and generate the holiday card
for i, params in enumerate(st.session_state.card_params):
prompt = params['prompt']
guidance_scale = params['guidance_scale']
num_inference_steps = params['num_inference_steps']
seed = params['seed']
controlnet_conditioning_scale = params['controlnet_conditioning_scale']
control_mode = params['control_mode']
# Generate the holiday card using the current parameters
# (one API call per card; the third return value is not used here)
generated_image, processed_image, _ = call_control_net_api(
st.session_state.uploaded_file, prompt, control_mode=control_mode,
guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
seed=seed, controlnet_conditioning_scale=controlnet_conditioning_scale
)
# call_control_net_api returns None for the image when the request did not succeed
if generated_image is not None:
# Resize generated_image to match original_image size
generated_image = generated_image.resize(original_image.size)
# Create a copy of the generated image
final_image = generated_image.copy()
# Crop the selected portion of the original image
cropped_original = original_image.crop((x_pos, y_pos, x_pos + crop_width, y_pos + crop_height))
# Get the size of the cropped image
cropped_width, cropped_height = cropped_original.size
# Calculate the center of the generated image
# (center_x, center_y) is the top-left corner at which the crop ends up centered
center_x = (final_image.width - cropped_width) // 2
center_y = (final_image.height - cropped_height) // 2
# Paste the cropped portion of the original image onto the generated image at the calculated center
final_image.paste(cropped_original, (center_x, center_y))
# Save image to BytesIO
img_byte_arr = BytesIO()
final_image.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
image_files.append((f"holiday_card_{i + 1}.png", img_byte_arr))
# Store metadata
metadata.append({
"Card": f"Holiday Card {i + 1}",
"Prompt": prompt,
"Guidance Scale": guidance_scale,
"Inference Steps": num_inference_steps,
"Seed": seed,
"ControlNet Conditioning Scale": controlnet_conditioning_scale,
"Control Mode": control_mode
})
# Persist image and metadata in session state
# (stored under the card's index, replacing any earlier result for that slot)
st.session_state.generated_cards[i] = {
"image": final_image,
"metadata": metadata[-1]
}
# Display the final holiday card in one of the columns
columns[i].image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)
# Display the parameters used for this card
columns[i].write(f"**Prompt:** {prompt}")
columns[i].write(f"**Guidance Scale:** {guidance_scale}")
columns[i].write(f"**Inference Steps:** {num_inference_steps}")
columns[i].write(f"**Seed:** {seed}")
columns[i].write(f"**ControlNet Conditioning Scale:** {controlnet_conditioning_scale}")
columns[i].write(f"**Control Mode:** {control_mode}")
else:
st.error(f"Failed to generate holiday card {i + 1}. Please try again.")
# Create ZIP file for download
# (only when at least one card was generated in this run)
if image_files:
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "w") as zf:
# Add images to the ZIP file
for file_name, img_data in image_files:
zf.writestr(file_name, img_data.getvalue())
# Add metadata to the ZIP file
# (one plain-text section per card, sections separated by a blank line)
metadata_str = "\n\n".join([f"{m['Card']}:\nPrompt: {m['Prompt']}\nGuidance Scale: {m['Guidance Scale']}\nInference Steps: {m['Inference Steps']}\nSeed: {m['Seed']}\nControlNet Conditioning Scale: {m['ControlNet Conditioning Scale']}\nControl Mode: {m['Control Mode']}" for m in metadata])
zf.writestr("metadata.txt", metadata_str)
# Rewind so the download button serves the archive from its first byte
zip_buffer.seek(0)
st.subheader("Step 4: Download & Share Your Masterpiece! 📥")
st.markdown("""
- Once you're happy with your card, simply hit the download button to save your card and message as a **PNG** image.
- You can also view and download any **previously generated holiday cards** from this session!
""")
st.download_button(
label="Download all images and metadata as ZIP",
data=zip_buffer,
file_name="holiday_cards.zip",
mime="application/zip"
)
# Display previously generated cards if they exist and create a ZIP download button
st.divider()
st.subheader("Previously Generated Holiday Cards")
# True as soon as at least one of the four slots holds a stored card
if any(st.session_state.generated_cards):
col1, col2 = st.columns(2)
col3, col4 = st.columns(2)
columns = [col1, col2, col3, col4] # Display in 2x2 grid
image_files = []
metadata = [] # Ensure metadata is initialized as a list
# Loop through previously generated cards
for i, card in enumerate(st.session_state.generated_cards):
if card and "metadata" in card: # Ensure the card exists and has metadata
columns[i].image(card['image'], caption=f"Holiday Card {i + 1} (Saved)")
card_metadata = card['metadata'] # Retrieve the metadata for the current card
columns[i].write(f"**Prompt:** {card_metadata['Prompt']}")
columns[i].write(f"**Guidance Scale:** {card_metadata['Guidance Scale']}")
columns[i].write(f"**Inference Steps:** {card_metadata['Inference Steps']}")
columns[i].write(f"**Seed:** {card_metadata['Seed']}")
columns[i].write(f"**ControlNet Conditioning Scale:** {card_metadata['ControlNet Conditioning Scale']}")
columns[i].write(f"**Control Mode:** {card_metadata['Control Mode']}")
# Add each image and its metadata to prepare for ZIP download
img_byte_arr = BytesIO()
card['image'].save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
image_files.append((f"holiday_card_{i + 1}.png", img_byte_arr))
metadata.append(card_metadata) # Append card's metadata to the metadata list
# Provide a ZIP download button for previously generated cards
# NOTE(review): the archive is assembled here, but the only download button that
# would serve it is commented out below, so in the visible code this ZIP is built
# and then never offered — confirm whether the button should be restored or this
# ZIP-building step removed.
if image_files:
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "w") as zf:
# Add images to the ZIP file
for file_name, img_data in image_files:
zf.writestr(file_name, img_data.getvalue())
# Add metadata to the ZIP file
metadata_str = "\n\n".join([f"{m['Card']}:\nPrompt: {m['Prompt']}\nGuidance Scale: {m['Guidance Scale']}\nInference Steps: {m['Inference Steps']}\nSeed: {m['Seed']}\nControlNet Conditioning Scale: {m['ControlNet Conditioning Scale']}\nControl Mode: {m['Control Mode']}" for m in metadata])
zf.writestr("metadata.txt", metadata_str)
# """ zip_buffer.seek(0)
# st.divider()
# st.subheader("Save images & metadata")
# st.download_button(
# label="Download all previously generated images and metadata as ZIP",
# data=zip_buffer,
# file_name="holiday_cards.zip",
# mime="application/zip"
# )
# """