|
import io
import os
import uuid
from datetime import datetime

import requests
import streamlit as st
from openai import OpenAI
from PIL import Image
|
|
|
def _encode_png(image):
    """Encode a PIL image as an optimized PNG and return a rewound BytesIO."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG", optimize=True)
    buffer.seek(0)
    return buffer


def convert_to_png(image_file, max_bytes=4 * 1024 * 1024):
    """Convert an uploaded image to an RGBA PNG smaller than ``max_bytes``.

    The image is converted to RGBA and re-encoded as an optimized PNG. While
    the encoded result is not smaller than ``max_bytes``, the image is
    downscaled by 10% per step (LANCZOS resampling) and re-encoded.

    Args:
        image_file: A path or binary file-like object accepted by
            ``PIL.Image.open`` (e.g. a Streamlit ``UploadedFile``).
        max_bytes: Exclusive upper bound on the encoded PNG size in bytes.
            Defaults to 4 MiB.

    Returns:
        io.BytesIO: The PNG bytes, with the stream position rewound to 0.
    """
    with Image.open(image_file) as source:
        # Modes that can carry transparency (LA, P, PA) are converted straight
        # to RGBA so their alpha survives; routing them through RGB would drop
        # it, which breaks masks whose transparent pixels mark the edit area.
        if source.mode in ("RGBA", "RGB", "LA", "P", "PA"):
            image = source.convert("RGBA")
        else:
            image = source.convert("RGB").convert("RGBA")

    png_bytes = _encode_png(image)
    while png_bytes.getbuffer().nbytes >= max_bytes:
        width, height = image.size
        if width <= 1 and height <= 1:
            # Nothing left to shrink; return the best effort instead of looping forever.
            break
        # max(1, ...) keeps both dimensions valid for resize().
        new_size = (max(1, int(width * 0.9)), max(1, int(height * 0.9)))
        image = image.resize(new_size, Image.Resampling.LANCZOS)
        png_bytes = _encode_png(image)

    return png_bytes
|
|
|
def validate_image(uploaded_file, max_bytes=4 * 1024 * 1024):
    """Check that an uploaded file is strictly smaller than ``max_bytes``.

    Args:
        uploaded_file: Object exposing its size in bytes as ``.size``
            (e.g. a Streamlit ``UploadedFile``).
        max_bytes: Exclusive size limit in bytes. Defaults to 4 MiB.

    Returns:
        tuple[bool, str]: ``(True, "OK")`` when the file is under the limit,
        otherwise ``(False, reason)``.
    """
    # ">=" so the check matches the "less than" wording of the message.
    if uploaded_file.size >= max_bytes:
        limit_mb = max_bytes / (1024 * 1024)
        return False, f"File size must be less than {limit_mb:g}MB"
    return True, "OK"
|
|
|
def save_uploaded_file(uploaded_file, folder="uploads"):
    """Convert an uploaded image to PNG and write it under ``folder``.

    The image is passed through ``convert_to_png`` before being written. Every
    call gets a unique file name, so saving the original image and the mask
    back-to-back never overwrites one with the other.

    Args:
        uploaded_file: Image file object accepted by ``convert_to_png``.
        folder: Destination directory; created if it does not exist.

    Returns:
        str: Path of the written PNG file.
    """
    os.makedirs(folder, exist_ok=True)

    image_bytes = convert_to_png(uploaded_file)

    # A seconds-resolution timestamp alone collides when two uploads are saved
    # within the same second, so a random suffix keeps the name unique.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_path = os.path.join(folder, f"{timestamp}_{uuid.uuid4().hex}.png")

    with open(file_path, "wb") as f:
        f.write(image_bytes.getvalue())

    return file_path
|
|
|
def download_image(url, folder="generated", timeout=60):
    """Download an image from ``url`` and save it as a PNG under ``folder``.

    Args:
        url: HTTP(S) URL of the image to fetch.
        folder: Destination directory; created if it does not exist.
        timeout: Seconds ``requests`` waits for the server before raising,
            so a stalled download cannot hang the app indefinitely.

    Returns:
        str | None: Path of the saved file when the server answers with
        HTTP 200, otherwise ``None``.

    Raises:
        requests.RequestException: On connection failures or timeouts.
    """
    os.makedirs(folder, exist_ok=True)

    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        return None

    # Several images can be downloaded within the same second; the random
    # suffix stops later downloads from overwriting earlier ones.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"generated_image_{timestamp}_{uuid.uuid4().hex}.png"
    file_path = os.path.join(folder, file_name)

    with open(file_path, "wb") as f:
        f.write(response.content)

    return file_path
|
|
|
def _render_sidebar():
    """Draw the configuration sidebar and return the user's choices.

    Returns:
        tuple[str, str, int]: ``(api_key, size_option, num_images)``.
    """
    with st.sidebar:
        st.header("Configuration")
        api_key = st.text_input("Enter your OpenAI API key", type="password")
        st.markdown("""
        ### How to get an API key
        1. Go to [OpenAI API Keys](https://platform.openai.com/api-keys)
        2. Create a new secret key
        3. Copy and paste it here
        """)

        size_option = st.selectbox(
            "Select image size:",
            ["1024x1024", "512x512", "256x256"],
        )

        num_images = st.slider("Number of images to generate", 1, 4, 1)

        st.markdown("""
        ### Requirements:
        - Original image and mask must be less than 4MB
        - Images will be automatically converted to PNG format
        - Images larger than 4MB will be automatically resized
        """)

    return api_key, size_option, num_images


def _image_upload_column(title, uploader_label, warning_prefix, error_noun):
    """Render one upload widget with a size warning and a preview.

    Args:
        title: Subheader text, also used as the preview caption.
        uploader_label: Label of the ``st.file_uploader`` widget.
        warning_prefix: Prefix of the oversize warning (e.g. "Mask image").
        error_noun: Word used in the preview error message (e.g. "mask").

    Returns:
        The uploaded file object, or ``None`` when nothing was uploaded.
    """
    st.subheader(title)
    uploaded = st.file_uploader(uploader_label, type=["png", "jpg", "jpeg"])
    if not uploaded:
        return uploaded

    valid, message = validate_image(uploaded)
    if not valid:
        st.warning(f"{warning_prefix}: {message}. Image will be automatically resized.")

    try:
        preview = Image.open(uploaded)
        st.image(preview, caption=title, use_container_width=True)
        st.caption(f"Original size: {uploaded.size/1024/1024:.2f}MB")
    except Exception as e:
        st.error(f"Error loading {error_noun}: {e}")

    return uploaded


def _remove_files(paths):
    """Best-effort deletion of temporary files; ``None`` entries are skipped."""
    for path in paths:
        if path and os.path.exists(path):
            try:
                os.remove(path)
            except OSError:
                # Cleanup must never mask the outcome of the edit request.
                pass


def _generate_and_show(api_key, original_image, mask_image, prompt, size_option, num_images):
    """Save both uploads as PNGs, call the DALL-E 2 edit endpoint and show results.

    Any exception is reported with ``st.error``. The temporary PNGs written by
    ``save_uploaded_file`` are always removed afterwards.
    """
    # Pre-bind so the cleanup in ``finally`` works even if saving fails.
    original_path = mask_path = None
    try:
        with st.spinner("Processing images and generating edited version..."):
            original_path = save_uploaded_file(original_image)
            mask_path = save_uploaded_file(mask_image)

            # convert_to_png may downscale each file independently, but the
            # edit endpoint needs the mask to match the image's dimensions.
            with Image.open(original_path) as orig_png, Image.open(mask_path) as mask_png:
                orig_size, mask_size = orig_png.size, mask_png.size
            if orig_size != mask_size:
                st.error(
                    "The original image and mask must have the same dimensions "
                    f"(got {orig_size[0]}x{orig_size[1]} and {mask_size[0]}x{mask_size[1]})."
                )
                return

            client = OpenAI(api_key=api_key)

            # Close the handles once the request is done so the temp files
            # can be deleted (an open handle blocks os.remove on Windows).
            with open(original_path, "rb") as image_fh, open(mask_path, "rb") as mask_fh:
                response = client.images.edit(
                    model="dall-e-2",
                    image=image_fh,
                    mask=mask_fh,
                    prompt=prompt,
                    n=num_images,
                    size=size_option,
                )

            st.subheader("Generated Images")
            cols = st.columns(num_images)

            for idx, image_data in enumerate(response.data):
                saved_image_path = download_image(image_data.url)
                if not saved_image_path:
                    continue

                with cols[idx]:
                    st.image(saved_image_path, caption=f"Generated Image {idx+1}")
                    with open(saved_image_path, "rb") as file:
                        st.download_button(
                            label=f"Download Image {idx+1}",
                            data=file,
                            file_name=f"edited_image_{idx+1}.png",
                            mime="image/png",
                        )
    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        _remove_files([original_path, mask_path])


def main():
    """Render the Streamlit UI for editing an image with DALL-E 2.

    The UI collects the API key and generation options in the sidebar, plus
    the original image, the mask and the prompt. When "Generate Edited Image"
    is clicked, it checks the inputs and runs the edit.
    """
    st.title("🎨 Image Editor with DALL-E 2")

    api_key, size_option, num_images = _render_sidebar()

    col1, col2 = st.columns(2)
    with col1:
        original_image = _image_upload_column(
            "Original Image", "Upload original image", "Original image", "image"
        )
    with col2:
        mask_image = _image_upload_column(
            "Mask Image", "Upload mask image", "Mask image", "mask"
        )

    prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Describe the changes you want to make to the image...",
        help="Be specific about what you want to add or modify in the masked area",
    )

    if not st.button("Generate Edited Image"):
        return

    if not api_key:
        st.error("Please enter your OpenAI API key in the sidebar.")
        return

    if not original_image or not mask_image:
        st.error("Please upload both an original image and a mask image.")
        return

    if not prompt:
        st.error("Please enter a prompt describing the desired changes.")
        return

    _generate_and_show(api_key, original_image, mask_image, prompt, size_option, num_images)
|
|
|
if __name__ == "__main__":
    # Launch the app when this file is executed as the main script
    # (e.g. via `streamlit run <this file>`).
    main()