# Author: DexterSptizu
# Commit message: Update app.py
# Commit: 638a59f (verified)
import streamlit as st
from openai import OpenAI
import requests
from PIL import Image
import io
import os
from datetime import datetime
def _encode_png(image):
    """Encode a PIL image as an optimized PNG and return it as a rewound BytesIO."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG', optimize=True)
    buffer.seek(0)
    return buffer


def convert_to_png(image_file):
    """Convert an uploaded image to an RGBA PNG no larger than 4MB.

    The image is opened with PIL, converted to RGBA and encoded as an
    optimized PNG. While the encoded result is larger than 4MB, the image is
    scaled down by 10% per pass (aspect ratio preserved) and re-encoded.

    Args:
        image_file: A path or binary file-like object that ``Image.open``
            can read.

    Returns:
        io.BytesIO: The encoded PNG data, positioned at offset 0.
    """
    max_bytes = 4 * 1024 * 1024

    image = Image.open(image_file)

    # Convert straight to RGBA so that transparency carried by the source
    # (alpha bands, palette transparency) is kept. Going through RGB first
    # would discard it, which matters for mask images.
    try:
        image = image.convert('RGBA')
    except ValueError:
        # Pillow raises ValueError for mode pairs it cannot convert
        # directly; fall back to the two-step route via RGB.
        image = image.convert('RGB').convert('RGBA')

    byte_arr = _encode_png(image)

    # Shrink by 10% per pass until the encoded PNG fits the size limit.
    while byte_arr.getbuffer().nbytes > max_bytes:
        width, height = image.size
        if width <= 1 and height <= 1:
            # Nothing left to shrink; return what we have instead of looping.
            break
        # Clamp to 1 pixel so resize() is never asked for a zero dimension.
        new_size = (max(1, int(width * 0.9)), max(1, int(height * 0.9)))
        image = image.resize(new_size, Image.Resampling.LANCZOS)
        byte_arr = _encode_png(image)

    return byte_arr
def validate_image(uploaded_file):
    """Check that an uploaded file does not exceed the 4MB size limit.

    Args:
        uploaded_file: Object exposing a ``size`` attribute in bytes.

    Returns:
        tuple: ``(True, "OK")`` when the file is within the limit, otherwise
        ``(False, <reason>)``.
    """
    size_limit = 4 * 1024 * 1024  # 4MB in bytes
    too_large = uploaded_file.size > size_limit
    reason = "File size must be less than 4MB" if too_large else "OK"
    return (not too_large), reason
def save_uploaded_file(uploaded_file, folder="uploads"):
    """Convert an uploaded image to PNG and save it under ``folder``.

    The file name combines a timestamp with a random suffix so that two
    files saved within the same second (for example the original image and
    its mask) never overwrite each other.

    Args:
        uploaded_file: Image source passed on to ``convert_to_png``.
        folder: Directory to write into; created if it does not exist.

    Returns:
        str: Path of the PNG file that was written.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(folder, exist_ok=True)

    # Convert to PNG (resized by convert_to_png if it is too large)
    image_bytes = convert_to_png(uploaded_file)

    # Timestamp keeps names readable; the random suffix makes them unique.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_suffix = os.urandom(4).hex()
    file_path = os.path.join(folder, f"{timestamp}_{unique_suffix}.png")

    with open(file_path, "wb") as f:
        f.write(image_bytes.getvalue())
    return file_path
def download_image(url, folder="generated", timeout=30):
    """Download an image from ``url`` and save it as a ``.png`` file.

    Args:
        url: Address of the image to fetch.
        folder: Directory to write into; created if it does not exist.
        timeout: Seconds to wait for the server before ``requests`` gives
            up, so a stalled connection cannot hang the app indefinitely.

    Returns:
        str | None: Path of the saved file, or ``None`` when the server
        answered with a status code other than 200.

    Raises:
        requests.RequestException: On connection errors or timeouts.
    """
    os.makedirs(folder, exist_ok=True)

    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        return None

    # The random suffix keeps several images downloaded within the same
    # second from overwriting one another.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_suffix = os.urandom(4).hex()
    file_path = os.path.join(
        folder, f"generated_image_{timestamp}_{unique_suffix}.png"
    )

    with open(file_path, "wb") as f:
        f.write(response.content)
    return file_path
def main():
    """Render the Streamlit UI and run a DALL-E 2 image edit on request.

    The sidebar collects the OpenAI API key, the output size and the number
    of images. The main area takes an original image, a mask image and a
    prompt. When the generate button is pressed and all inputs are present,
    both uploads are saved as PNG files, sent to ``client.images.edit``, and
    each returned image is downloaded, displayed and offered for download.
    The temporary upload files are removed afterwards, whether or not the
    request succeeded.
    """
    st.title("🎨 Image Editor with DALL-E 2")

    # Sidebar for API key and generation settings
    with st.sidebar:
        st.header("Configuration")
        api_key = st.text_input("Enter your OpenAI API key", type="password")
        st.markdown(
            "\n"
            "### How to get an API key\n"
            "1. Go to [OpenAI API Keys](https://platform.openai.com/api-keys)\n"
            "2. Create a new secret key\n"
            "3. Copy and paste it here\n"
        )
        # Size selection
        size_option = st.selectbox(
            "Select image size:",
            ["1024x1024", "512x512", "256x256"]
        )
        # Number of images
        num_images = st.slider("Number of images to generate", 1, 4, 1)

    # Main content
    st.markdown(
        "\n"
        "### Requirements:\n"
        "- Original image and mask must be less than 4MB\n"
        "- Images will be automatically converted to PNG format\n"
        "- Images larger than 4MB will be automatically resized\n"
    )

    # File uploaders
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Image")
        original_image = st.file_uploader(
            "Upload original image",
            type=["png", "jpg", "jpeg"]
        )
        if original_image:
            # Oversized files only trigger a warning; they are resized later.
            valid, message = validate_image(original_image)
            if not valid:
                st.warning(f"Original image: {message}. Image will be automatically resized.")
            try:
                # Display image preview
                image = Image.open(original_image)
                st.image(image, caption="Original Image", use_container_width=True)
                st.caption(f"Original size: {original_image.size/1024/1024:.2f}MB")
            except Exception as e:
                st.error(f"Error loading image: {str(e)}")
    with col2:
        st.subheader("Mask Image")
        mask_image = st.file_uploader(
            "Upload mask image",
            type=["png", "jpg", "jpeg"]
        )
        if mask_image:
            valid, message = validate_image(mask_image)
            if not valid:
                st.warning(f"Mask image: {message}. Image will be automatically resized.")
            try:
                # Display image preview (same sizing option as the original
                # image preview, instead of the deprecated use_column_width)
                image = Image.open(mask_image)
                st.image(image, caption="Mask Image", use_container_width=True)
                st.caption(f"Original size: {mask_image.size/1024/1024:.2f}MB")
            except Exception as e:
                st.error(f"Error loading mask: {str(e)}")

    # Prompt input
    prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Describe the changes you want to make to the image...",
        help="Be specific about what you want to add or modify in the masked area"
    )

    # Generate button
    if st.button("Generate Edited Image"):
        if not api_key:
            st.error("Please enter your OpenAI API key in the sidebar.")
            return
        if not original_image or not mask_image:
            st.error("Please upload both an original image and a mask image.")
            return
        if not prompt:
            st.error("Please enter a prompt describing the desired changes.")
            return

        # Paths are recorded as soon as each file exists, so the cleanup in
        # `finally` never touches a name that was not assigned.
        temp_paths = []
        try:
            with st.spinner("Processing images and generating edited version..."):
                # Save and convert uploaded files
                original_path = save_uploaded_file(original_image)
                temp_paths.append(original_path)
                mask_path = save_uploaded_file(mask_image)
                temp_paths.append(mask_path)

                # Initialize OpenAI client
                client = OpenAI(api_key=api_key)

                # Make the API call; the context manager closes both file
                # handles once the request has completed or failed.
                with open(original_path, "rb") as image_file, \
                        open(mask_path, "rb") as mask_file:
                    response = client.images.edit(
                        model="dall-e-2",
                        image=image_file,
                        mask=mask_file,
                        prompt=prompt,
                        n=num_images,
                        size=size_option
                    )

                # Display results
                st.subheader("Generated Images")
                cols = st.columns(num_images)
                for idx, image_data in enumerate(response.data):
                    # Download and save the generated image
                    saved_image_path = download_image(image_data.url)
                    if saved_image_path:
                        with cols[idx]:
                            st.image(saved_image_path, caption=f"Generated Image {idx+1}")
                            # Add download button for each image
                            with open(saved_image_path, "rb") as file:
                                st.download_button(
                                    label=f"Download Image {idx+1}",
                                    data=file,
                                    file_name=f"edited_image_{idx+1}.png",
                                    mime="image/png"
                                )
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
        finally:
            # Best-effort cleanup of the temporary upload files; a file that
            # is already gone or cannot be removed must not mask the result.
            for path in temp_paths:
                try:
                    os.remove(path)
                except OSError:
                    pass
# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()