import streamlit as st |
import cv2 |
import requests |
from io import BytesIO |
from PIL import Image, ImageDraw |
import numpy as np |
import os |
from dotenv import load_dotenv |
import zipfile |
# Locate the .env file that lives in ../env relative to this script.
_script_dir = os.path.dirname(__file__)
dotenv_path = os.path.join(_script_dir, '../env/.env')
# override=True: values from the .env file replace variables that are already
# present in the process environment.
load_dotenv(dotenv_path, override=True)

api_key = os.environ.get("FIREWORKS_API_KEY")
if not api_key:
    # Every API call below needs this key, so halt the script run right away.
    st.error("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
    st.stop()
# Aspect ratios this app may request, keyed by (width, height) and mapped to the
# label string sent as the request's ``aspect_ratio`` field. Presumably these
# are the ratios the Fireworks endpoint accepts -- confirm against its API docs.
VALID_ASPECT_RATIOS = {
    (1, 1): "1:1", (21, 9): "21:9", (16, 9): "16:9", (3, 2): "3:2", (5, 4): "5:4",
    (4, 5): "4:5", (2, 3): "2:3", (9, 16): "9:16", (9, 21): "9:21",
}
def get_closest_aspect_ratio(width, height, ratios=None):
    """Return the supported aspect-ratio label closest to ``width / height``.

    Args:
        width: Image width in pixels; must be positive.
        height: Image height in pixels; must be positive.
        ratios: Optional mapping of ``(w, h)`` tuples to labels such as
            ``"16:9"``. Defaults to ``VALID_ASPECT_RATIOS``.

    Returns:
        The label whose ratio ``w / h`` has the smallest absolute difference
        from the image's ratio. On ties, the first entry in mapping order wins.

    Raises:
        ValueError: If either dimension is not positive, or ``ratios`` is empty.
    """
    if ratios is None:
        ratios = VALID_ASPECT_RATIOS
    # Guard explicitly instead of failing with ZeroDivisionError on height == 0.
    if width <= 0 or height <= 0:
        raise ValueError(f"Image dimensions must be positive, got {width}x{height}")
    if not ratios:
        raise ValueError("No aspect ratios to choose from")
    target = width / height
    closest = min(ratios, key=lambda wh: abs(wh[0] / wh[1] - target))
    return ratios[closest]
def process_image(uploaded_image):
    """Build a Canny edge map of *uploaded_image* for use as a control image.

    The upload is decoded with PIL and converted to 8-bit grayscale. It is then
    run through ``cv2.Canny`` with thresholds 100/200 and expanded back to three
    RGB channels.

    Args:
        uploaded_image: A path or file-like object accepted by ``PIL.Image.open``.

    Returns:
        A tuple ``(jpeg_buffer, edge_image)``. ``jpeg_buffer`` is a ``BytesIO``,
        rewound to offset 0, holding the JPEG-encoded edge map. ``edge_image``
        is the same edge map as a PIL image.
    """
    grayscale = Image.open(uploaded_image).convert('L')
    edge_map = cv2.Canny(np.array(grayscale), 100, 200)
    edge_image = Image.fromarray(cv2.cvtColor(edge_map, cv2.COLOR_GRAY2RGB))

    jpeg_buffer = BytesIO()
    edge_image.save(jpeg_buffer, format='JPEG')
    # Rewind so whoever reads the buffer gets the encoded bytes from the start.
    jpeg_buffer.seek(0)
    return jpeg_buffer, edge_image
def call_control_net_api(uploaded_image, prompt, control_mode=0, guidance_scale=3.5,
                         num_inference_steps=30, seed=0, controlnet_conditioning_scale=1.0,
                         timeout=120):
    """Generate an image with the Fireworks FLUX ControlNet endpoint.

    The uploaded image is reduced to a Canny edge map (via ``process_image``),
    which is posted as the control image together with *prompt* and the
    sampling parameters. The requested aspect ratio is the supported ratio
    closest to the uploaded image's dimensions.

    Args:
        uploaded_image: Path or file-like object accepted by ``PIL.Image.open``.
        prompt: Text prompt describing the image to generate.
        control_mode: Sent unchanged as the ``control_mode`` form field.
        guidance_scale: Sent as the ``guidance_scale`` form field.
        num_inference_steps: Sent as the ``num_inference_steps`` form field.
        seed: Sent as the ``seed`` form field.
        controlnet_conditioning_scale: Sent as the
            ``controlnet_conditioning_scale`` form field.
        timeout: Seconds to wait for the HTTP request before giving up. This is
            a generous default for image generation; adjust as needed.

    Returns:
        ``(generated_image, edge_image, original_image)`` when the server
        answers HTTP 200. Otherwise the failure is reported with ``st.error``
        and ``(None, None, None)`` is returned.
    """
    control_image, processed_image = process_image(uploaded_image)
    files = {'control_image': ('control_image.jpg', control_image, 'image/jpeg')}

    original_image = Image.open(uploaded_image)
    width, height = original_image.size
    aspect_ratio = get_closest_aspect_ratio(width, height)

    data = {
        'prompt': prompt,
        'control_mode': control_mode,
        'aspect_ratio': aspect_ratio,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'seed': seed,
        'controlnet_conditioning_scale': controlnet_conditioning_scale
    }
    headers = {
        'accept': 'image/jpeg',
        'authorization': f'Bearer {api_key}',
    }
    url = ('https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/'
           'models/flux-1-dev-controlnet-union/control_net')

    try:
        # Without a timeout a stalled connection would block this script run forever.
        response = requests.post(url, files=files, data=data, headers=headers, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        # Connection errors and timeouts would otherwise propagate and abort
        # the whole script run; report them like an HTTP failure instead.
        st.error(f"Request failed: {exc}")
        return None, None, None

    if response.status_code != 200:
        st.error(f"Request failed with status code: {response.status_code}, Response: {response.text}")
        return None, None, None
    return Image.open(BytesIO(response.content)), processed_image, original_image
def draw_crop_preview(image, x, y, width, height):
    """Outline a crop region on *image* and return that same image.

    A 2-pixel red rectangle is drawn with corners ``(x, y)`` and
    ``(x + width, y + height)``. The image is modified in place, so pass a copy
    if the original must stay untouched.

    Args:
        image: PIL image to draw on.
        x, y: Top-left corner of the crop, in pixels.
        width, height: Size of the crop, in pixels.

    Returns:
        The same ``image`` object, now carrying the outline.
    """
    box = (x, y, x + width, y + height)
    canvas = ImageDraw.Draw(image)
    canvas.rectangle(box, outline="red", width=2)
    return image
# Page header: Fireworks logo, title, and an overview of the Part A workflow.
logo_image = Image.open("img/fireworksai_logo.png")
st.image(logo_image)
st.title("🎨 Holiday Card Generator - Part A: Design & Ideation 🎨")

intro_markdown = """Welcome to the first part of your holiday card creation journey! 🌟 Here, you’ll play around with different styles, prompts, and parameters to design the perfect card border before adding a personal message in Part B. Let your creativity flow! 🎉
### How it works:
1. **🖼️ Upload Your Image:** Choose the image that will be the center of your card.
2. **✂️ Crop It:** Adjust the crop to highlight the most important part of your image.
3. **💡 Choose Your Style:** Select from festive border themes or input your own custom prompt to design something unique.
4. **⚙️ Fine-Tune Parameters:** Experiment with guidance scales, seeds, inference steps, and more for the perfect aesthetic.
5. **👀 Preview & Download:** See your generated holiday cards, tweak them until they’re just right, and download the final designs and metadata in a neat ZIP file!
Once you’ve got the perfect look, head over to **Part B** to add your personal message and finalize your holiday card! 💌
"""
st.markdown(intro_markdown)
# --- Step 1: upload the picture that becomes the card's centerpiece ----------
st.divider()
st.subheader("🖼️ Step 1: Upload Your Picture!")
st.markdown("""
- Click on the **Upload Image** button to select the image you want to feature on your holiday card.
- Accepted formats: **JPG** or **PNG**.
- Once uploaded, your image will appear on the screen—easy peasy!
""")
st.write("Upload the image that will be used for generating your holiday card.")

# Seed per-session state on the first script run only; later reruns keep the
# stored values. Factories ensure each session gets fresh list objects.
_session_defaults = (
    ("uploaded_file", lambda: None),
    ("card_params", lambda: [{} for _ in range(4)]),
    ("generated_cards", lambda: [None] * 4),
)
for _state_key, _make_default in _session_defaults:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
st.divider()
# Keep the most recent upload in session state; a rerun where the uploader
# returns None leaves the previously stored file in place.
if uploaded_file is not None:
    st.session_state.uploaded_file = uploaded_file
# Steps 2-4 need an image. The upload is read from session state so it
# survives Streamlit reruns (e.g. when a slider moves).
if st.session_state.uploaded_file is not None:
    original_image = Image.open(st.session_state.uploaded_file)
    st.image(original_image, caption="Uploaded Image", use_column_width=True)
    img_width, img_height = original_image.size
    st.divider()

    # --- Step 2: choose the crop that becomes the card's centerpiece ---------
    st.subheader("✂️ Step 2: Crop It Like It's Hot!")
    st.markdown("""
- Adjust the **crop** sliders to select the perfect area of your image.
- This cropped section will be the centerpiece of your festive card. 🎨
- A preview will show you exactly how your crop looks before moving to the next step!
""")
    st.write("Select the area you want to crop from the original image. This cropped portion will be used in the final card.")

    # Smallest crop edge, in pixels, offered by the Width/Height sliders.
    min_crop = 10
    if img_width < min_crop + 2 or img_height < min_crop + 2:
        st.error(f"Please upload an image of at least {min_crop + 2}x{min_crop + 2} pixels.")
        st.stop()
    # Cap the offsets so that more than `min_crop` pixels always remain after
    # them. Otherwise, pushing a position slider to the far edge would give the
    # Width/Height slider a maximum at or below its minimum, i.e. an invalid
    # slider range and a degenerate crop.
    max_x = img_width - min_crop - 1
    max_y = img_height - min_crop - 1

    col1, col2 = st.columns(2)
    with col1:
        x_pos = st.slider("X position (Left-Right)", 0, max_x, min(img_width // 4, max_x),
                          help="Move the slider to adjust the crop's left-right position.")
        available_width = img_width - x_pos
        crop_width = st.slider("Width", min_crop, available_width,
                               max(min_crop, min(img_width // 2, available_width)),
                               help="Adjust the width of the crop.")
    with col2:
        y_pos = st.slider("Y position (Up-Down)", 0, max_y, min(img_height // 4, max_y),
                          help="Move the slider to adjust the crop's up-down position.")
        available_height = img_height - y_pos
        crop_height = st.slider("Height", min_crop, available_height,
                                max(min_crop, min(img_height // 2, available_height)),
                                help="Adjust the height of the crop.")

    preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
    st.image(preview_image, caption="Crop Preview", use_column_width=True)
    st.divider()

    # --- Step 3: per-card generation parameters ------------------------------
    st.subheader("⚙️ Step 3: Set Your Festive Border Design with Flux + Fireworks!")
    st.markdown("""
- Choose from a selection of holiday-themed borders like **snowflakes**, **Christmas lights**, or even **New Year's Eve fireworks**—all generated through **Flux models on Fireworks!** ✨
- Want to get creative? No problem—enter your own **custom prompt** to create a border that's as unique as you are, all powered by **Fireworks' inference API**.
""")
    st.write("Customize the parameters for generating each holiday card. Each parameter will influence the final design and style of the card.")

    holiday_prompts = [
        "A border of Festive snowflakes and winter patterns for a holiday card border",
        "A border of Joyful Christmas ornaments and lights decorating the edges",
        "A border of Warm and cozy fireplace scene with stockings and garlands",
        "A border of Colorful Hanukkah menorahs and dreidels along the border",
        "A border of New Year's Eve fireworks with stars and confetti framing the image"
    ]

    for i in range(4):
        with st.expander(f"Holiday Card {i + 1} Parameters"):
            st.write(f"### Holiday Card {i + 1}")
            # Start from this card's last-used values, filling defaults on first use.
            card_params = st.session_state.card_params[i]
            card_params.setdefault("prompt", f"Custom Prompt {i + 1}")
            card_params.setdefault("guidance_scale", 3.5)
            card_params.setdefault("num_inference_steps", 30)
            card_params.setdefault("seed", i * 100)
            card_params.setdefault("controlnet_conditioning_scale", 0.5)
            card_params.setdefault("control_mode", 0)

            selected_prompt = st.selectbox(f"Choose a holiday-themed prompt for Holiday Card {i + 1}", options=["Custom"] + holiday_prompts)
            custom_prompt = st.text_input(f"Enter custom prompt for Holiday Card {i + 1}", value=card_params["prompt"]) if selected_prompt == "Custom" else selected_prompt
            guidance_scale = st.slider(f"Guidance Scale for Holiday Card {i + 1}", min_value=0.0, max_value=20.0, value=card_params["guidance_scale"], step=0.1,
                                       help="Adjusts how strongly the model follows the prompt. Higher values mean stronger adherence to the prompt.")
            num_inference_steps = st.slider(f"Number of Inference Steps for Holiday Card {i + 1}", min_value=1, max_value=100, value=card_params["num_inference_steps"], step=1,
                                            help="More inference steps can lead to a higher quality image but will take longer to generate.")
            seed = st.slider(f"Random Seed for Holiday Card {i + 1}", min_value=0, max_value=1000, value=card_params["seed"],
                             help="The seed value allows you to recreate the same image each time, or explore variations by changing the seed.")
            controlnet_conditioning_scale = st.slider(f"ControlNet Conditioning Scale for Holiday Card {i + 1}", min_value=0.0, max_value=1.0, value=card_params["controlnet_conditioning_scale"], step=0.1,
                                                      help="Controls how much the ControlNet input influences the output. A lower value reduces its influence.")
            # control_mode selects the ControlNet Union conditioning type rather
            # than a strength (strength is controlnet_conditioning_scale). The
            # numbering (0 canny, 1 tile, 2 depth, ...) follows the FLUX.1-dev
            # ControlNet Union model card; verify against the Fireworks API docs.
            control_mode = st.slider(f"Control Mode for Holiday Card {i + 1}", min_value=0, max_value=2, value=card_params["control_mode"],
                                     help="Selects the ControlNet Union conditioning type: 0 = canny, 1 = tile, 2 = depth. The control image sent is a Canny edge map, so 0 matches it.")

            st.session_state.card_params[i] = {
                "prompt": custom_prompt,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "seed": seed,
                "controlnet_conditioning_scale": controlnet_conditioning_scale,
                "control_mode": control_mode
            }

    st.divider()
    st.subheader("Just hit play!")
    if st.button("Generate Holiday Cards"):
        # The crop is identical for every card, so cut it once up front.
        cropped_original = original_image.crop((x_pos, y_pos, x_pos + crop_width, y_pos + crop_height))
        cropped_width, cropped_height = cropped_original.size
        image_files = []
        metadata = []

        with st.spinner("Processing..."):
            col1, col2 = st.columns(2)
            col3, col4 = st.columns(2)
            columns = [col1, col2, col3, col4]
            for i, params in enumerate(st.session_state.card_params):
                prompt = params['prompt']
                guidance_scale = params['guidance_scale']
                num_inference_steps = params['num_inference_steps']
                seed = params['seed']
                controlnet_conditioning_scale = params['controlnet_conditioning_scale']
                control_mode = params['control_mode']

                generated_image, processed_image, _ = call_control_net_api(
                    st.session_state.uploaded_file, prompt, control_mode=control_mode,
                    guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
                    seed=seed, controlnet_conditioning_scale=controlnet_conditioning_scale
                )
                if generated_image is None:
                    st.error(f"Failed to generate holiday card {i + 1}. Please try again.")
                    continue

                # The border was requested at the closest supported aspect
                # ratio; stretch it back to the upload's size, then paste the
                # crop in the centre. resize() returns a new image, so no copy
                # is needed before pasting.
                final_image = generated_image.resize(original_image.size)
                center_x = (final_image.width - cropped_width) // 2
                center_y = (final_image.height - cropped_height) // 2
                final_image.paste(cropped_original, (center_x, center_y))

                img_byte_arr = BytesIO()
                final_image.save(img_byte_arr, format="PNG")
                img_byte_arr.seek(0)
                image_files.append((f"holiday_card_{i + 1}.png", img_byte_arr))

                card_metadata = {
                    "Card": f"Holiday Card {i + 1}",
                    "Prompt": prompt,
                    "Guidance Scale": guidance_scale,
                    "Inference Steps": num_inference_steps,
                    "Seed": seed,
                    "ControlNet Conditioning Scale": controlnet_conditioning_scale,
                    "Control Mode": control_mode
                }
                metadata.append(card_metadata)
                # Persist so the card can be shown again after later reruns.
                st.session_state.generated_cards[i] = {
                    "image": final_image,
                    "metadata": card_metadata
                }

                columns[i].image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)
                columns[i].write(f"**Prompt:** {prompt}")
                columns[i].write(f"**Guidance Scale:** {guidance_scale}")
                columns[i].write(f"**Inference Steps:** {num_inference_steps}")
                columns[i].write(f"**Seed:** {seed}")
                columns[i].write(f"**ControlNet Conditioning Scale:** {controlnet_conditioning_scale}")
                columns[i].write(f"**Control Mode:** {control_mode}")

        if image_files:
            # Bundle every successful card plus a plain-text summary of its settings.
            zip_buffer = BytesIO()
            with zipfile.ZipFile(zip_buffer, "w") as zf:
                for file_name, img_data in image_files:
                    zf.writestr(file_name, img_data.getvalue())
                metadata_str = "\n\n".join(
                    f"{m['Card']}:\n"
                    f"Prompt: {m['Prompt']}\n"
                    f"Guidance Scale: {m['Guidance Scale']}\n"
                    f"Inference Steps: {m['Inference Steps']}\n"
                    f"Seed: {m['Seed']}\n"
                    f"ControlNet Conditioning Scale: {m['ControlNet Conditioning Scale']}\n"
                    f"Control Mode: {m['Control Mode']}"
                    for m in metadata
                )
                zf.writestr("metadata.txt", metadata_str)
            zip_buffer.seek(0)

            # --- Step 4: offer everything as one ZIP download ----------------
            st.subheader("Step 4: Download & Share Your Masterpiece! 📥")
            st.markdown("""
- Once you're happy with your cards, simply hit the download button to save them as **PNG** images, bundled with their settings in a single **ZIP** file.
- You can also view and download any **previously generated holiday cards** from this session!
""")
            st.download_button(
                label="Download all images and metadata as ZIP",
                data=zip_buffer,
                file_name="holiday_cards.zip",
                mime="application/zip"
            )
st.divider() |
st.subheader("Previously Generated Holiday Cards") |
if any(st.session_state.generated_cards): |
col1, col2 = st.columns(2) |
col3, col4 = st.columns(2) |
columns = [col1, col2, col3, col4] |
image_files = [] |
metadata = [] |
for i, card in enumerate(st.session_state.generated_cards): |
if card and "metadata" in card: |
columns[i].image(card['image'], caption=f"Holiday Card {i + 1} (Saved)") |
card_metadata = card['metadata'] |
columns[i].write(f"**Prompt:** {card_metadata['Prompt']}") |
columns[i].write(f"**Guidance Scale:** {card_metadata['Guidance Scale']}") |
columns[i].write(f"**Inference Steps:** {card_metadata['Inference Steps']}") |
columns[i].write(f"**Seed:** {card_metadata['Seed']}") |
columns[i].write(f"**ControlNet Conditioning Scale:** {card_metadata['ControlNet Conditioning Scale']}") |
columns[i].write(f"**Control Mode:** {card_metadata['Control Mode']}") |
img_byte_arr = BytesIO() |
card['image'].save(img_byte_arr, format="PNG") |
img_byte_arr.seek(0) |
image_files.append((f"holiday_card_{i + 1}.png", img_byte_arr)) |
metadata.append(card_metadata) |
if image_files: |
zip_buffer = BytesIO() |
with zipfile.ZipFile(zip_buffer, "w") as zf: |
for file_name, img_data in image_files: |
zf.writestr(file_name, img_data.getvalue()) |
metadata_str = "\n\n".join([f"{m['Card']}:\nPrompt: {m['Prompt']}\nGuidance Scale: {m['Guidance Scale']}\nInference Steps: {m['Inference Steps']}\nSeed: {m['Seed']}\nControlNet Conditioning Scale: {m['ControlNet Conditioning Scale']}\nControl Mode: {m['Control Mode']}" for m in metadata]) |
zf.writestr("metadata.txt", metadata_str) |