# File size: 19,789 Bytes
# 7c1b343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import streamlit as st
import cv2
import requests
from io import BytesIO
from PIL import Image, ImageDraw
import numpy as np
import os
from dotenv import load_dotenv
import zipfile


# Read FIREWORKS_API_KEY from ../env/.env (path built from this script's
# directory); override=True lets values from that file replace variables that
# are already set in the process environment.
dotenv_path = os.path.join(os.path.dirname(__file__), '../env/.env')
load_dotenv(dotenv_path, override=True)
api_key = os.environ.get("FIREWORKS_API_KEY")

# Every API call needs the key, so stop rendering the app when it is missing or empty.
if api_key is None or api_key == "":
    st.error("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
    st.stop()

# Aspect ratios accepted by the generation endpoint, keyed by (width, height).
# Dict order matters: on a tie, the earliest entry wins.
VALID_ASPECT_RATIOS = {
    (1, 1): "1:1",
    (21, 9): "21:9",
    (16, 9): "16:9",
    (3, 2): "3:2",
    (5, 4): "5:4",
    (4, 5): "4:5",
    (2, 3): "2:3",
    (9, 16): "9:16",
    (9, 21): "9:21",
}

def get_closest_aspect_ratio(width, height):
    """Return the ``"W:H"`` label of the supported ratio nearest to ``width / height``.

    Nearness is the absolute difference between the two width/height quotients.
    ``height`` must be non-zero (the quotient is computed directly).
    """
    target = width / height

    def distance(ratio_key):
        ratio_w, ratio_h = ratio_key
        return abs(ratio_w / ratio_h - target)

    best_key = min(VALID_ASPECT_RATIOS, key=distance)
    return VALID_ASPECT_RATIOS[best_key]

def process_image(uploaded_image, low_threshold=100, high_threshold=200):
    """Turn an uploaded image into a Canny edge map used as the ControlNet input.

    Args:
        uploaded_image: Path or file-like object accepted by ``PIL.Image.open``.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``
            (default 100, the value that used to be hard-coded).
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``
            (default 200, the value that used to be hard-coded).

    Returns:
        A ``(jpeg_bytes, edge_image)`` pair: a ``BytesIO`` holding the edge map
        encoded as JPEG and rewound to offset 0, and the same edge map as an
        RGB ``PIL.Image``.
    """
    # cv2.Canny runs on a single channel, so convert to 8-bit grayscale first.
    grayscale = np.array(Image.open(uploaded_image).convert('L'))
    edges = cv2.Canny(grayscale, low_threshold, high_threshold)
    # Expand the single-channel edge mask to three (RGB) channels before encoding.
    edge_image = Image.fromarray(cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB))
    jpeg_bytes = BytesIO()
    edge_image.save(jpeg_bytes, format='JPEG')
    jpeg_bytes.seek(0)  # rewind so whoever reads the buffer starts at the beginning
    return jpeg_bytes, edge_image

def call_control_net_api(uploaded_image, prompt, control_mode=0, guidance_scale=3.5, num_inference_steps=30, seed=0, controlnet_conditioning_scale=1.0, timeout=300):
    """Generate an image with the Fireworks FLUX ControlNet endpoint.

    A Canny edge map of ``uploaded_image`` (built by ``process_image``) is sent
    as the control image. The request's ``aspect_ratio`` is the supported ratio
    closest to the uploaded image's own (``get_closest_aspect_ratio``).

    Args:
        uploaded_image: File-like object (here, a Streamlit upload) that
            ``PIL.Image.open`` can read.
        prompt: Text prompt for the generation.
        control_mode: Sent as the ``control_mode`` form field.
        guidance_scale: Sent as the ``guidance_scale`` form field.
        num_inference_steps: Sent as the ``num_inference_steps`` form field.
        seed: Sent as the ``seed`` form field.
        controlnet_conditioning_scale: Sent as the
            ``controlnet_conditioning_scale`` form field.
        timeout: Seconds ``requests`` waits on the connection/response before
            raising. Without a timeout a stalled request would block the
            Streamlit script indefinitely.

    Returns:
        ``(generated_image, edge_map_image, original_image)`` when the API
        answers HTTP 200. Otherwise ``(None, None, None)``, after the failure
        (HTTP error status or network/timeout error) has been shown with
        ``st.error``.
    """
    endpoint = 'https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-controlnet-union/control_net'

    control_image, processed_image = process_image(uploaded_image)
    files = {'control_image': ('control_image.jpg', control_image, 'image/jpeg')}
    original_image = Image.open(uploaded_image)
    width, height = original_image.size
    aspect_ratio = get_closest_aspect_ratio(width, height)
    data = {
        'prompt': prompt,
        'control_mode': control_mode,
        'aspect_ratio': aspect_ratio,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'seed': seed,
        'controlnet_conditioning_scale': controlnet_conditioning_scale
    }
    headers = {
        'accept': 'image/jpeg',
        'authorization': f'Bearer {api_key}',
    }

    try:
        response = requests.post(endpoint, files=files, data=data, headers=headers, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        # Connection failures and timeouts take the same "report and return
        # None" path as HTTP errors instead of crashing the whole page.
        st.error(f"Request failed: {exc}")
        return None, None, None

    if response.status_code != 200:
        st.error(f"Request failed with status code: {response.status_code}, Response: {response.text}")
        return None, None, None
    return Image.open(BytesIO(response.content)), processed_image, original_image

def draw_crop_preview(image, x, y, width, height):
    """Outline the proposed crop box on ``image`` in red and return it.

    The box runs from (x, y) to (x + width, y + height) with a 2-pixel red
    outline. ``image`` is drawn on in place and the same object is returned,
    so pass a copy when the original must stay untouched.
    """
    box = [x, y, x + width, y + height]
    canvas = ImageDraw.Draw(image)
    canvas.rectangle(box, outline="red", width=2)
    return image


# Page header: Fireworks logo, title and an overview of the workflow steps.
logo_image = Image.open("img/fireworksai_logo.png")  # relative to the current working directory
st.image(logo_image)

st.title("🎨 Holiday Card Generator - Part A: Design & Ideation 🎨")

_INTRO_MARKDOWN = """Welcome to the first part of your holiday card creation journey! 🌟 Here, you’ll play around with different styles, prompts, and parameters to design the perfect card border before adding a personal message in Part B. Let your creativity flow! 🎉

### How it works:

1. **🖼️ Upload Your Image:** Choose the image that will be the center of your card.
2. **✂️ Crop It:** Adjust the crop to highlight the most important part of your image.
3. **💡 Choose Your Style:** Select from festive border themes or input your own custom prompt to design something unique.
4. **⚙️ Fine-Tune Parameters:** Experiment with guidance scales, seeds, inference steps, and more for the perfect aesthetic.
5. **👀 Preview & Download:** See your generated holiday cards, tweak them until they’re just right, and download the final designs and metadata in a neat ZIP file!

Once you’ve got the perfect look, head over to **Part B** to add your personal message and finalize your holiday card! 💌

"""
st.markdown(_INTRO_MARKDOWN)

# Step 1: upload section header and instructions.
st.divider()
st.subheader("🖼️ Step 1: Upload Your Picture!")
st.markdown(""" 
- Click on the **Upload Image** button to select the image you want to feature on your holiday card.  
- Accepted formats: **JPG** or **PNG**.
- Once uploaded, your image will appear on the screen—easy peasy!
""")
st.write("Upload the image that will be used for generating your holiday card.")

# Seed session state on the first run only; later reruns keep what is stored.
# Each factory builds a fresh value, so no list is shared between keys.
_session_defaults = {
    'uploaded_file': lambda: None,
    'card_params': lambda: [{} for _ in range(4)],        # one parameter dict per card
    'generated_cards': lambda: [None for _ in range(4)],  # generated image + metadata per card
}
for _state_key, _make_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Keep the most recent upload in session state so it survives reruns.
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
st.divider()

if uploaded_file is not None:
    st.session_state.uploaded_file = uploaded_file

# Everything below needs an image, so it renders only once one has been uploaded.
if st.session_state.uploaded_file is not None:
    original_image = Image.open(st.session_state.uploaded_file)
    st.image(original_image, caption="Uploaded Image", use_column_width=True)

    img_width, img_height = original_image.size

    # Add a divider and subheader for the crop section
    st.divider()
    st.subheader("✂️ Step 2: Crop It Like It's Hot!")
    st.markdown(""" 
    - Adjust the **crop** sliders to select the perfect area of your image.  
    - This cropped section will be the centerpiece of your festive card. 🎨  
    - A preview will show you exactly how your crop looks before moving to the next step!
""")
    st.write("Select the area you want to crop from the original image. This cropped portion will be used in the final card.")

    # Smallest crop width/height the size sliders accept.
    min_crop_size = 10
    # Cap the position sliders so at least `min_crop_size` pixels remain past
    # them. Without the cap, x/y could reach the image edge, leaving the size
    # sliders with an inverted range (min 10 > max img_width - x_pos).
    # Images smaller than min_crop_size in a dimension are still not handled.
    max_x_pos = max(img_width - min_crop_size, 0)
    max_y_pos = max(img_height - min_crop_size, 0)

    col1, col2 = st.columns(2)
    with col1:
        x_pos = st.slider("X position (Left-Right)", 0, max_x_pos, min(img_width // 4, max_x_pos),
                          help="Move the slider to adjust the crop's left-right position.")
        crop_width = st.slider("Width", min_crop_size, img_width - x_pos,
                               max(min_crop_size, min(img_width // 2, img_width - x_pos)),
                               help="Adjust the width of the crop.")
    with col2:
        y_pos = st.slider("Y position (Up-Down)", 0, max_y_pos, min(img_height // 4, max_y_pos),
                          help="Move the slider to adjust the crop's up-down position.")
        crop_height = st.slider("Height", min_crop_size, img_height - y_pos,
                                max(min_crop_size, min(img_height // 2, img_height - y_pos)),
                                help="Adjust the height of the crop.")

    # draw_crop_preview draws in place, so give it a copy; original_image is
    # cropped and pasted later and must stay clean.
    preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
    st.image(preview_image, caption="Crop Preview", use_column_width=True)
    st.divider()

    st.subheader("⚙️ Step 3: Set Your Festive Border Design with Flux + Fireworks!")
    st.markdown(""" 
        - Choose from a selection of holiday-themed borders like **snowflakes**, **Christmas lights**, or even **New Year's Eve fireworks**—all generated through **Flux models on Fireworks!** ✨
        - Want to get creative? No problem—enter your own **custom prompt** to create a border that's as unique as you are, all powered by **Fireworks' inference API**.
""")
    st.write("Customize the parameters for generating each holiday card. Each parameter will influence the final design and style of the card.")

    # Define the list of suggested holiday prompts
    holiday_prompts = [
        "A border of Festive snowflakes and winter patterns for a holiday card border",
        "A border of Joyful Christmas ornaments and lights decorating the edges",
        "A border of Warm and cozy fireplace scene with stockings and garlands",
        "A border of Colorful Hanukkah menorahs and dreidels along the border",
        "A border of New Year's Eve fireworks with stars and confetti framing the image"
    ]

    # Define input fields for each holiday card's parameters in expanders
    for i in range(4):
        with st.expander(f"Holiday Card {i + 1} Parameters"):
            st.write(f"### Holiday Card {i + 1}")

            # Get the card params from session_state if they exist
            card_params = st.session_state.card_params[i]
            card_params.setdefault("prompt", f"Custom Prompt {i + 1}")
            card_params.setdefault("guidance_scale", 3.5)
            card_params.setdefault("num_inference_steps", 30)
            card_params.setdefault("seed", i * 100)
            card_params.setdefault("controlnet_conditioning_scale", 0.5)
            card_params.setdefault("control_mode", 0)

            # Dropdown to choose a suggested holiday prompt or enter custom prompt
            selected_prompt = st.selectbox(f"Choose a holiday-themed prompt for Holiday Card {i + 1}", options=["Custom"] + holiday_prompts)
            custom_prompt = st.text_input(f"Enter custom prompt for Holiday Card {i + 1}", value=card_params["prompt"]) if selected_prompt == "Custom" else selected_prompt

            # Parameter sliders for each holiday card with tooltips
            guidance_scale = st.slider(f"Guidance Scale for Holiday Card {i + 1}", min_value=0.0, max_value=20.0, value=card_params["guidance_scale"], step=0.1,
                                       help="Adjusts how strongly the model follows the prompt. Higher values mean stronger adherence to the prompt.")
            num_inference_steps = st.slider(f"Number of Inference Steps for Holiday Card {i + 1}", min_value=1, max_value=100, value=card_params["num_inference_steps"], step=1,
                                            help="More inference steps can lead to a higher quality image but will take longer to generate.")
            seed = st.slider(f"Random Seed for Holiday Card {i + 1}", min_value=0, max_value=1000, value=card_params["seed"],
                             help="The seed value allows you to recreate the same image each time, or explore variations by changing the seed.")
            controlnet_conditioning_scale = st.slider(f"ControlNet Conditioning Scale for Holiday Card {i + 1}", min_value=0.0, max_value=1.0, value=card_params["controlnet_conditioning_scale"], step=0.1,
                                                      help="Controls how much the ControlNet input influences the output. A lower value reduces its influence.")
            # NOTE(review): nothing in this file confirms the 0/1/2 meanings in
            # this help text; verify them against the Fireworks ControlNet API docs.
            control_mode = st.slider(f"Control Mode for Holiday Card {i + 1}", min_value=0, max_value=2, value=card_params["control_mode"],
                                     help="Choose how much ControlNet should influence the final image. 0: None, 1: Partial, 2: Full.")

            # Update session state with the latest parameters
            st.session_state.card_params[i] = {
                "prompt": custom_prompt,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "seed": seed,
                "controlnet_conditioning_scale": controlnet_conditioning_scale,
                "control_mode": control_mode
            }

    st.divider()  # Add a divider before the "Generate" button

    # Generate the holiday cards
    st.subheader("Just hit play!")
    if st.button("Generate Holiday Cards"):
        with st.spinner("Processing..."):
            # Create a column layout for displaying cards side by side
            col1, col2 = st.columns(2)
            col3, col4 = st.columns(2)
            columns = [col1, col2, col3, col4]  # To display images in a 2x2 grid

            # The crop box is the same for every card, so cut it out once.
            cropped_original = original_image.crop((x_pos, y_pos, x_pos + crop_width, y_pos + crop_height))
            cropped_width, cropped_height = cropped_original.size

            # Store image bytes and metadata in lists
            image_files = []
            metadata = []

            # Loop through each card's parameters and generate the holiday card
            for i, params in enumerate(st.session_state.card_params):
                prompt = params['prompt']
                guidance_scale = params['guidance_scale']
                num_inference_steps = params['num_inference_steps']
                seed = params['seed']
                controlnet_conditioning_scale = params['controlnet_conditioning_scale']
                control_mode = params['control_mode']

                # Generate the holiday card using the current parameters
                generated_image, _, _ = call_control_net_api(
                    st.session_state.uploaded_file, prompt, control_mode=control_mode,
                    guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
                    seed=seed, controlnet_conditioning_scale=controlnet_conditioning_scale
                )

                if generated_image is not None:
                    # Resize generated_image to match original_image size
                    generated_image = generated_image.resize(original_image.size)

                    # Create a copy of the generated image
                    final_image = generated_image.copy()

                    # Top-left corner that centres the crop on the generated image
                    center_x = (final_image.width - cropped_width) // 2
                    center_y = (final_image.height - cropped_height) // 2

                    # Paste the cropped portion of the original image onto the generated image at the calculated center
                    final_image.paste(cropped_original, (center_x, center_y))

                    # Save image to BytesIO
                    img_byte_arr = BytesIO()
                    final_image.save(img_byte_arr, format="PNG")
                    img_byte_arr.seek(0)
                    image_files.append((f"holiday_card_{i + 1}.png", img_byte_arr))

                    # Store metadata
                    metadata.append({
                        "Card": f"Holiday Card {i + 1}",
                        "Prompt": prompt,
                        "Guidance Scale": guidance_scale,
                        "Inference Steps": num_inference_steps,
                        "Seed": seed,
                        "ControlNet Conditioning Scale": controlnet_conditioning_scale,
                        "Control Mode": control_mode
                    })

                    # Persist image and metadata in session state
                    st.session_state.generated_cards[i] = {
                        "image": final_image,
                        "metadata": metadata[-1]
                    }

                    # Display the final holiday card in one of the columns
                    columns[i].image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)

                    # Display the parameters used for this card
                    columns[i].write(f"**Prompt:** {prompt}")
                    columns[i].write(f"**Guidance Scale:** {guidance_scale}")
                    columns[i].write(f"**Inference Steps:** {num_inference_steps}")
                    columns[i].write(f"**Seed:** {seed}")
                    columns[i].write(f"**ControlNet Conditioning Scale:** {controlnet_conditioning_scale}")
                    columns[i].write(f"**Control Mode:** {control_mode}")
                else:
                    st.error(f"Failed to generate holiday card {i + 1}. Please try again.")

            # Create ZIP file for download
            if image_files:
                zip_buffer = BytesIO()
                with zipfile.ZipFile(zip_buffer, "w") as zf:
                    # Add images to the ZIP file
                    for file_name, img_data in image_files:
                        zf.writestr(file_name, img_data.getvalue())

                    # Add metadata to the ZIP file
                    metadata_str = "\n\n".join([f"{m['Card']}:\nPrompt: {m['Prompt']}\nGuidance Scale: {m['Guidance Scale']}\nInference Steps: {m['Inference Steps']}\nSeed: {m['Seed']}\nControlNet Conditioning Scale: {m['ControlNet Conditioning Scale']}\nControl Mode: {m['Control Mode']}" for m in metadata])
                    zf.writestr("metadata.txt", metadata_str)

                zip_buffer.seek(0)
                st.subheader("Step 4: Download & Share Your Masterpiece! 📥")
                # The button below downloads a ZIP (all cards + metadata.txt),
                # so the instructions describe exactly that.
                st.markdown(""" 
                - Once you're happy with your cards, simply hit the download button to save all of them, together with their generation parameters, as a **ZIP** file.
                - You can also view your **previously generated holiday cards** from this session further down the page!
                            """)
                st.download_button(
                    label="Download all images and metadata as ZIP",
                    data=zip_buffer,
                    file_name="holiday_cards.zip",
                    mime="application/zip"
                )

# Re-display the cards persisted in session state so they stay visible across
# Streamlit reruns.
st.divider()
st.subheader("Previously Generated Holiday Cards")
if any(st.session_state.generated_cards):
    col1, col2 = st.columns(2)
    col3, col4 = st.columns(2)
    columns = [col1, col2, col3, col4]  # Display in 2x2 grid

    # NOTE: this section used to also PNG-encode every saved card and build a
    # ZIP archive on each rerun. The archive was never offered to the user (its
    # download button is commented out), so that per-rerun work was dropped.
    # Rebuild the archive only together with a re-enabled download button.
    for i, card in enumerate(st.session_state.generated_cards):
        # Skip slots that were never generated (None) or have no metadata.
        if not card or "metadata" not in card:
            continue
        card_metadata = card['metadata']
        columns[i].image(card['image'], caption=f"Holiday Card {i + 1} (Saved)")
        columns[i].write(f"**Prompt:** {card_metadata['Prompt']}")
        columns[i].write(f"**Guidance Scale:** {card_metadata['Guidance Scale']}")
        columns[i].write(f"**Inference Steps:** {card_metadata['Inference Steps']}")
        columns[i].write(f"**Seed:** {card_metadata['Seed']}")
        columns[i].write(f"**ControlNet Conditioning Scale:** {card_metadata['ControlNet Conditioning Scale']}")
        columns[i].write(f"**Control Mode:** {card_metadata['Control Mode']}")

# """         zip_buffer.seek(0)
#         st.divider()
#         st.subheader("Save images & metadata")
#         st.download_button(
#             label="Download all previously generated images and metadata as ZIP",
#             data=zip_buffer,
#             file_name="holiday_cards.zip",
#             mime="application/zip"
#         )
#  """