|
import streamlit as st |
|
import cv2 |
|
import requests |
|
from io import BytesIO |
|
from PIL import Image, ImageDraw |
|
import numpy as np |
|
import os |
|
from dotenv import load_dotenv |
|
import zipfile |
|
|
|
|
|
|
|
# Credentials are read from ../env/.env, resolved relative to this script's directory.
# override=True makes values in that file take precedence over variables that are
# already set in the process environment.
dotenv_path = os.path.join(os.path.dirname(__file__), '../env/.env')
load_dotenv(dotenv_path, override=True)

api_key = os.environ.get("FIREWORKS_API_KEY")

# A missing or empty key makes every API call fail, so halt this script run up front.
if api_key in (None, ""):
    st.error("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
    st.stop()
|
|
|
# Aspect ratios this app may request from the generation API, keyed by
# (width, height) and mapped to the label string sent in the request.
VALID_ASPECT_RATIOS = {
    (1, 1): "1:1",
    (21, 9): "21:9",
    (16, 9): "16:9",
    (3, 2): "3:2",
    (5, 4): "5:4",
    (4, 5): "4:5",
    (2, 3): "2:3",
    (9, 16): "9:16",
    (9, 21): "9:21",
}


def get_closest_aspect_ratio(width, height):
    """Return the label of the supported aspect ratio nearest to ``width / height``.

    Distance is the absolute difference between the two width/height quotients.
    On a tie, the entry listed first in VALID_ASPECT_RATIOS wins.
    Raises ZeroDivisionError when ``height`` is 0.
    """
    target = width / height
    _, best_label = min(
        VALID_ASPECT_RATIOS.items(),
        key=lambda item: abs((item[0][0] / item[0][1]) - target),
    )
    return best_label
|
|
|
def process_image(uploaded_image):
    """Build a Canny edge map of ``uploaded_image`` for use as the ControlNet control image.

    The image is converted to grayscale and run through ``cv2.Canny`` with thresholds
    100/200. The result is expanded to three channels.

    Returns:
        (jpeg_buffer, edge_image): a BytesIO holding the edge map encoded as JPEG
        (rewound to position 0, ready for upload), and the same edge map as an
        RGB PIL image.
    """
    grayscale = Image.open(uploaded_image).convert('L')
    edge_map = cv2.Canny(np.array(grayscale), 100, 200)

    # Canny yields a single channel; replicate it to RGB before JPEG encoding.
    edge_image = Image.fromarray(cv2.cvtColor(edge_map, cv2.COLOR_GRAY2RGB))

    jpeg_buffer = BytesIO()
    edge_image.save(jpeg_buffer, format='JPEG')
    jpeg_buffer.seek(0)
    return jpeg_buffer, edge_image
|
|
|
def call_control_net_api(uploaded_image, prompt, control_mode=0, guidance_scale=3.5, num_inference_steps=30, seed=0, controlnet_conditioning_scale=1.0, timeout=120):
    """Generate an image from the Fireworks Flux ControlNet-Union workflow endpoint.

    A Canny edge map of ``uploaded_image`` (see process_image) is uploaded as the
    control image. It is sent together with ``prompt`` and the sampling parameters.
    The request asks for the entry of VALID_ASPECT_RATIOS closest to the upload's
    own aspect ratio.

    Args:
        uploaded_image: file-like object or path readable by ``PIL.Image.open``.
        prompt: text prompt for the generation.
        control_mode, guidance_scale, num_inference_steps, seed,
        controlnet_conditioning_scale: forwarded unchanged as form fields.
        timeout: seconds to wait for the HTTP request (passed to ``requests.post``).
            ``None`` waits indefinitely.

    Returns:
        (generated_image, edge_image, original_image) as PIL images on success.
        On any failure, returns (None, None, None) after reporting the error with
        ``st.error``. Failures include transport errors, a non-200 status, or a body
        that is not a decodable image.
    """
    control_image, processed_image = process_image(uploaded_image)

    files = {'control_image': ('control_image.jpg', control_image, 'image/jpeg')}

    original_image = Image.open(uploaded_image)
    width, height = original_image.size
    aspect_ratio = get_closest_aspect_ratio(width, height)

    data = {
        'prompt': prompt,
        'control_mode': control_mode,
        'aspect_ratio': aspect_ratio,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'seed': seed,
        'controlnet_conditioning_scale': controlnet_conditioning_scale,
    }

    headers = {
        'accept': 'image/jpeg',
        'authorization': f'Bearer {api_key}',
    }

    # Without a timeout, a stalled connection would block this Streamlit run forever.
    # Transport errors are reported the same way as HTTP errors instead of crashing
    # the caller's generation loop.
    try:
        response = requests.post('https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-controlnet-union/control_net',
                                 files=files, data=data, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        st.error(f"Request failed: {exc}")
        return None, None, None

    if response.status_code != 200:
        st.error(f"Request failed with status code: {response.status_code}, Response: {response.text}")
        return None, None, None

    # A 200 whose body is not an image (e.g. a JSON error payload) makes PIL raise
    # UnidentifiedImageError, an OSError subclass.
    try:
        generated_image = Image.open(BytesIO(response.content))
    except OSError as exc:
        st.error(f"Could not decode the generated image: {exc}")
        return None, None, None

    return generated_image, processed_image, original_image
|
|
|
def draw_crop_preview(image, x, y, width, height):
    """Outline a crop region on ``image`` with a 2-pixel red rectangle and return it.

    ``image`` is modified in place, so pass a copy to keep the original untouched.
    The outlined region covers exactly the pixels that
    ``image.crop((x, y, x + width, y + height))`` keeps.
    """
    draw = ImageDraw.Draw(image)
    # PIL rectangle bounds are inclusive of both endpoints, whereas crop boxes exclude
    # the right/bottom edge. Using x + width directly made the outline one pixel too
    # large and pushed it off-canvas when the crop touched the image edge.
    # max(..., 1) keeps x1 >= x0 / y1 >= y0 for degenerate zero-size crops.
    right = x + max(width, 1) - 1
    bottom = y + max(height, 1) - 1
    draw.rectangle([x, y, right, bottom], outline="red", width=2)
    return image
|
|
|
|
|
|
|
# Page header: brand logo, app title, and an overview of the card-design workflow.
logo_image = Image.open("img/fireworksai_logo.png")
st.image(logo_image)

st.title("🎨 Holiday Card Generator - Part A: Design & Ideation 🎨")

# Markdown shown under the title; kept in a named constant so the call site stays short.
INTRO_MARKDOWN = """Welcome to the first part of your holiday card creation journey! 🌟 Here, you’ll play around with different styles, prompts, and parameters to design the perfect card border before adding a personal message in Part B. Let your creativity flow! 🎉



### How it works:



1. **🖼️ Upload Your Image:** Choose the image that will be the center of your card.

2. **✂️ Crop It:** Adjust the crop to highlight the most important part of your image.

3. **💡 Choose Your Style:** Select from festive border themes or input your own custom prompt to design something unique.

4. **⚙️ Fine-Tune Parameters:** Experiment with guidance scales, seeds, inference steps, and more for the perfect aesthetic.

5. **👀 Preview & Download:** See your generated holiday cards, tweak them until they’re just right, and download the final designs and metadata in a neat ZIP file!



Once you’ve got the perfect look, head over to **Part B** to add your personal message and finalize your holiday card! 💌



"""
st.markdown(INTRO_MARKDOWN)
|
|
|
|
|
# --- Step 1: explain and collect the source image. ---
st.divider()
st.subheader("🖼️ Step 1: Upload Your Picture!")
st.markdown("""

- Click on the **Upload Image** button to select the image you want to feature on your holiday card.

- Accepted formats: **JPG** or **PNG**.

- Once uploaded, your image will appear on the screen—easy peasy!

""")
st.write("Upload the image that will be used for generating your holiday card.")

# Per-session state is created only once. Streamlit reruns this script on every
# interaction, so existing values must not be overwritten. Factories make sure
# each list is freshly built.
_SESSION_DEFAULTS = (
    ('uploaded_file', lambda: None),
    ('card_params', lambda: [{} for _ in range(4)]),
    ('generated_cards', lambda: [None for _ in range(4)]),
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
st.divider()

# Keep the most recent non-None upload in session state; a None from the widget
# leaves the stored file unchanged.
if uploaded_file is not None:
    st.session_state.uploaded_file = uploaded_file
|
|
|
|
|
def _format_cards_metadata(metadata):
    """Render per-card metadata dicts as the plain-text report written to metadata.txt."""
    return "\n\n".join([f"{m['Card']}:\nPrompt: {m['Prompt']}\nGuidance Scale: {m['Guidance Scale']}\nInference Steps: {m['Inference Steps']}\nSeed: {m['Seed']}\nControlNet Conditioning Scale: {m['ControlNet Conditioning Scale']}\nControl Mode: {m['Control Mode']}" for m in metadata])


def _build_cards_zip(image_files, metadata):
    """Pack encoded card images plus a metadata.txt report into an in-memory ZIP.

    Args:
        image_files: list of (file_name, BytesIO) pairs holding encoded images.
        metadata: list of per-card metadata dicts (keys as used by _format_cards_metadata).

    Returns:
        A BytesIO containing the ZIP archive, rewound to position 0.
    """
    zip_buffer = BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zf:
        for file_name, img_data in image_files:
            zf.writestr(file_name, img_data.getvalue())
        zf.writestr("metadata.txt", _format_cards_metadata(metadata))
    zip_buffer.seek(0)
    return zip_buffer


# Everything below depends on an uploaded image.
if st.session_state.uploaded_file is not None:
    original_image = Image.open(st.session_state.uploaded_file)
    st.image(original_image, caption="Uploaded Image", use_column_width=True)

    img_width, img_height = original_image.size

    # --- Step 2: choose the crop that gets pasted onto the centre of each card. ---
    st.divider()
    st.subheader("✂️ Step 2: Crop It Like It's Hot!")
    st.markdown("""

- Adjust the **crop** sliders to select the perfect area of your image.

- This cropped section will be the centerpiece of your festive card. 🎨

- A preview will show you exactly how your crop looks before moving to the next step!

""")
    st.write("Select the area you want to crop from the original image. This cropped portion will be used in the final card.")

    # Smallest crop edge (pixels) the Width/Height sliders allow.
    min_crop_size = 10
    # Stop the position sliders early enough that a crop of at least min_crop_size
    # still fits. Letting X/Y reach the image edge left the Width/Height sliders with
    # max_value < min_value, and a default that could fall below min_crop_size.
    max_x_pos = max(img_width - min_crop_size, 0)
    max_y_pos = max(img_height - min_crop_size, 0)

    col1, col2 = st.columns(2)
    with col1:
        x_pos = st.slider("X position (Left-Right)", 0, max_x_pos, min(img_width // 4, max_x_pos),
                          help="Move the slider to adjust the crop's left-right position.")
        crop_width = st.slider("Width", min_crop_size, img_width - x_pos,
                               max(min_crop_size, min(img_width // 2, img_width - x_pos)),
                               help="Adjust the width of the crop.")
    with col2:
        y_pos = st.slider("Y position (Up-Down)", 0, max_y_pos, min(img_height // 4, max_y_pos),
                          help="Move the slider to adjust the crop's up-down position.")
        crop_height = st.slider("Height", min_crop_size, img_height - y_pos,
                                max(min_crop_size, min(img_height // 2, img_height - y_pos)),
                                help="Adjust the height of the crop.")

    preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
    st.image(preview_image, caption="Crop Preview", use_column_width=True)
    st.divider()

    # --- Step 3: per-card prompt and sampling parameters. ---
    st.subheader("⚙️ Step 3: Set Your Festive Border Design with Flux + Fireworks!")
    st.markdown("""

- Choose from a selection of holiday-themed borders like **snowflakes**, **Christmas lights**, or even **New Year's Eve fireworks**—all generated through **Flux models on Fireworks!** ✨

- Want to get creative? No problem—enter your own **custom prompt** to create a border that's as unique as you are, all powered by **Fireworks' inference API**.

""")
    st.write("Customize the parameters for generating each holiday card. Each parameter will influence the final design and style of the card.")

    holiday_prompts = [
        "A border of Festive snowflakes and winter patterns for a holiday card border",
        "A border of Joyful Christmas ornaments and lights decorating the edges",
        "A border of Warm and cozy fireplace scene with stockings and garlands",
        "A border of Colorful Hanukkah menorahs and dreidels along the border",
        "A border of New Year's Eve fireworks with stars and confetti framing the image"
    ]

    for i in range(4):
        with st.expander(f"Holiday Card {i + 1} Parameters"):
            st.write(f"### Holiday Card {i + 1}")

            # Fill in defaults the first time this card's parameters are shown.
            card_params = st.session_state.card_params[i]
            card_params.setdefault("prompt", f"Custom Prompt {i + 1}")
            card_params.setdefault("guidance_scale", 3.5)
            card_params.setdefault("num_inference_steps", 30)
            card_params.setdefault("seed", i * 100)
            card_params.setdefault("controlnet_conditioning_scale", 0.5)
            card_params.setdefault("control_mode", 0)

            selected_prompt = st.selectbox(f"Choose a holiday-themed prompt for Holiday Card {i + 1}", options=["Custom"] + holiday_prompts)
            custom_prompt = st.text_input(f"Enter custom prompt for Holiday Card {i + 1}", value=card_params["prompt"]) if selected_prompt == "Custom" else selected_prompt

            guidance_scale = st.slider(f"Guidance Scale for Holiday Card {i + 1}", min_value=0.0, max_value=20.0, value=card_params["guidance_scale"], step=0.1,
                                       help="Adjusts how strongly the model follows the prompt. Higher values mean stronger adherence to the prompt.")
            num_inference_steps = st.slider(f"Number of Inference Steps for Holiday Card {i + 1}", min_value=1, max_value=100, value=card_params["num_inference_steps"], step=1,
                                            help="More inference steps can lead to a higher quality image but will take longer to generate.")
            seed = st.slider(f"Random Seed for Holiday Card {i + 1}", min_value=0, max_value=1000, value=card_params["seed"],
                             help="The seed value allows you to recreate the same image each time, or explore variations by changing the seed.")
            controlnet_conditioning_scale = st.slider(f"ControlNet Conditioning Scale for Holiday Card {i + 1}", min_value=0.0, max_value=1.0, value=card_params["controlnet_conditioning_scale"], step=0.1,
                                                      help="Controls how much the ControlNet input influences the output. A lower value reduces its influence.")
            # NOTE(review): the help text describes control_mode as None/Partial/Full.
            # ControlNet-Union models may instead number modes by conditioning type,
            # e.g. 0 = canny. This app always sends a Canny edge map. Confirm the
            # meaning against the Fireworks API docs before relying on this text.
            control_mode = st.slider(f"Control Mode for Holiday Card {i + 1}", min_value=0, max_value=2, value=card_params["control_mode"],
                                     help="Choose how much ControlNet should influence the final image. 0: None, 1: Partial, 2: Full.")

            st.session_state.card_params[i] = {
                "prompt": custom_prompt,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "seed": seed,
                "controlnet_conditioning_scale": controlnet_conditioning_scale,
                "control_mode": control_mode
            }

    st.divider()

    st.subheader("Just hit play!")
    if st.button("Generate Holiday Cards"):
        with st.spinner("Processing..."):
            col1, col2 = st.columns(2)
            col3, col4 = st.columns(2)
            columns = [col1, col2, col3, col4]

            image_files = []
            metadata = []

            for i, params in enumerate(st.session_state.card_params):
                prompt = params['prompt']
                guidance_scale = params['guidance_scale']
                num_inference_steps = params['num_inference_steps']
                seed = params['seed']
                controlnet_conditioning_scale = params['controlnet_conditioning_scale']
                control_mode = params['control_mode']

                generated_image, processed_image, _ = call_control_net_api(
                    st.session_state.uploaded_file, prompt, control_mode=control_mode,
                    guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
                    seed=seed, controlnet_conditioning_scale=controlnet_conditioning_scale
                )

                if generated_image is not None:
                    # The API returns the nearest supported aspect ratio; stretch the
                    # result back to the upload's exact size.
                    generated_image = generated_image.resize(original_image.size)
                    final_image = generated_image.copy()

                    # Paste the user's crop, unscaled, in the centre of the generated border.
                    cropped_original = original_image.crop((x_pos, y_pos, x_pos + crop_width, y_pos + crop_height))
                    cropped_width, cropped_height = cropped_original.size
                    center_x = (final_image.width - cropped_width) // 2
                    center_y = (final_image.height - cropped_height) // 2
                    final_image.paste(cropped_original, (center_x, center_y))

                    img_byte_arr = BytesIO()
                    final_image.save(img_byte_arr, format="PNG")
                    img_byte_arr.seek(0)
                    image_files.append((f"holiday_card_{i + 1}.png", img_byte_arr))

                    metadata.append({
                        "Card": f"Holiday Card {i + 1}",
                        "Prompt": prompt,
                        "Guidance Scale": guidance_scale,
                        "Inference Steps": num_inference_steps,
                        "Seed": seed,
                        "ControlNet Conditioning Scale": controlnet_conditioning_scale,
                        "Control Mode": control_mode
                    })

                    # Persist so the card survives reruns (e.g. after a download click).
                    st.session_state.generated_cards[i] = {
                        "image": final_image,
                        "metadata": metadata[-1]
                    }

                    columns[i].image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)
                    columns[i].write(f"**Prompt:** {prompt}")
                    columns[i].write(f"**Guidance Scale:** {guidance_scale}")
                    columns[i].write(f"**Inference Steps:** {num_inference_steps}")
                    columns[i].write(f"**Seed:** {seed}")
                    columns[i].write(f"**ControlNet Conditioning Scale:** {controlnet_conditioning_scale}")
                    columns[i].write(f"**Control Mode:** {control_mode}")
                else:
                    st.error(f"Failed to generate holiday card {i + 1}. Please try again.")

        if image_files:
            zip_buffer = _build_cards_zip(image_files, metadata)
            st.subheader("Step 4: Download & Share Your Masterpiece! 📥")
            st.markdown("""

- Once you're happy with your cards, simply hit the download button to save them, together with their generation settings, as a **ZIP** file.

- You can also view and download any **previously generated holiday cards** from this session!

""")
            st.download_button(
                label="Download all images and metadata as ZIP",
                data=zip_buffer,
                file_name="holiday_cards.zip",
                mime="application/zip"
            )

    # --- Cards kept in session state from earlier runs. ---
    st.divider()
    st.subheader("Previously Generated Holiday Cards")
    if any(st.session_state.generated_cards):
        col1, col2 = st.columns(2)
        col3, col4 = st.columns(2)
        columns = [col1, col2, col3, col4]

        image_files = []
        metadata = []

        for i, card in enumerate(st.session_state.generated_cards):
            if card and "metadata" in card:
                columns[i].image(card['image'], caption=f"Holiday Card {i + 1} (Saved)")
                card_metadata = card['metadata']
                columns[i].write(f"**Prompt:** {card_metadata['Prompt']}")
                columns[i].write(f"**Guidance Scale:** {card_metadata['Guidance Scale']}")
                columns[i].write(f"**Inference Steps:** {card_metadata['Inference Steps']}")
                columns[i].write(f"**Seed:** {card_metadata['Seed']}")
                columns[i].write(f"**ControlNet Conditioning Scale:** {card_metadata['ControlNet Conditioning Scale']}")
                columns[i].write(f"**Control Mode:** {card_metadata['Control Mode']}")

                img_byte_arr = BytesIO()
                card['image'].save(img_byte_arr, format="PNG")
                img_byte_arr.seek(0)
                image_files.append((f"holiday_card_{i + 1}.png", img_byte_arr))
                metadata.append(card_metadata)

        # This ZIP was previously built but never offered to the user.
        # A distinct label and key avoid a duplicate-widget clash with the Step 4
        # button, which can render in the same run.
        if image_files:
            st.download_button(
                label="Download saved cards and metadata as ZIP",
                data=_build_cards_zip(image_files, metadata),
                file_name="holiday_cards.zip",
                mime="application/zip",
                key="download_saved_holiday_cards"
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|