|
import io
import os
import uuid
from datetime import datetime

import requests
import streamlit as st
from openai import OpenAI
from PIL import Image
|
|
|
def preprocess_image(uploaded_file, max_size=1024, max_bytes=4 * 1024 * 1024):
    """
    Preprocess an uploaded image so the DALL-E 2 variation endpoint accepts it.

    The variation endpoint requires a square PNG smaller than 4MB, so this:

    - Converts the image to RGB (PNG cannot store every PIL mode, e.g. CMYK)
    - Center-crops it to a square
    - Downscales it so its edge is at most ``max_size`` pixels
    - Saves it as an optimized PNG, shrinking further until under ``max_bytes``

    Args:
        uploaded_file: A path or binary file-like object that ``PIL.Image.open``
            can read (in this app, the object returned by ``st.file_uploader``).
        max_size: Maximum edge length, in pixels, of the saved image.
        max_bytes: Maximum size, in bytes, of the saved PNG.

    Returns:
        Path of the processed PNG inside the ``temp`` directory. The caller is
        responsible for deleting it.

    Raises:
        ValueError: if the image cannot be shrunk below ``max_bytes``.
    """
    os.makedirs("temp", exist_ok=True)

    with Image.open(uploaded_file) as source:
        # convert() returns a fully loaded copy, so `image` stays valid after
        # `source` is closed. Normalising every mode to RGB (not only RGBA)
        # also covers CMYK JPEGs, which PNG cannot store.
        image = source.convert("RGB")

    # The variation endpoint rejects non-square images: keep the centred square.
    side = min(image.width, image.height)
    left = (image.width - side) // 2
    top = (image.height - side) // 2
    image = image.crop((left, top, left + side, top + side))

    if side > max_size:
        image = image.resize((max_size, max_size), Image.Resampling.LANCZOS)

    # The random suffix keeps concurrent sessions that upload within the same
    # second from overwriting each other's temp file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    temp_path = f"temp/processed_image_{timestamp}_{uuid.uuid4().hex}.png"
    image.save(temp_path, "PNG", optimize=True)

    # Shrink by 10% per pass until the encoded PNG fits under the size limit.
    while os.path.getsize(temp_path) > max_bytes:
        if image.width <= 1:
            # Cannot get any smaller; bail out instead of looping forever.
            os.remove(temp_path)
            raise ValueError(f"Could not shrink image below {max_bytes} bytes")
        new_side = max(1, int(image.width * 0.9))
        image = image.resize((new_side, new_side), Image.Resampling.LANCZOS)
        image.save(temp_path, "PNG", optimize=True)

    return temp_path
|
|
|
def save_image_from_url(image_url, index, timeout=60):
    """
    Download the image at ``image_url`` and save it under ``generated_variations``.

    Args:
        image_url: URL of the image to fetch (in this app, a URL taken from the
            OpenAI variation response).
        index: Position of the variation; embedded in the filename.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        Path of the saved PNG file.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never written to disk as a ``.png``.
        requests.RequestException: on connection errors or timeout.
    """
    # Without a timeout, requests can block forever on a stalled connection.
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()

    os.makedirs("generated_variations", exist_ok=True)

    # The random suffix prevents two sessions saving in the same second (with
    # the same index) from overwriting each other's file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = (
        f"generated_variations/{timestamp}_{uuid.uuid4().hex[:8]}"
        f"_variation_{index}.png"
    )

    with open(output_path, "wb") as f:
        f.write(response.content)

    return output_path
|
|
|
def _request_variations(api_key, image_path, num_variations, size):
    """Send the PNG at ``image_path`` to DALL-E 2 and return ``response.data``."""
    client = OpenAI(api_key=api_key)
    # A context manager guarantees the handle is closed before the caller
    # deletes the file (an open handle makes os.remove fail on Windows).
    with open(image_path, "rb") as image_file:
        response = client.images.create_variation(
            model="dall-e-2",
            image=image_file,
            n=num_variations,
            size=size,
        )
    return response.data


def _show_variations(variations):
    """Save each returned variation locally and render it with a download button."""
    st.subheader("Generated Variations")

    if not variations:
        # st.columns(0) raises, so handle an empty result explicitly.
        st.warning("No variations were returned.")
        return

    cols = st.columns(len(variations))
    for idx, (col, image_data) in enumerate(zip(cols, variations)):
        saved_path = save_image_from_url(image_data.url, idx)
        with col:
            st.image(saved_path, caption=f"Variation {idx+1}", use_container_width=True)
            with open(saved_path, "rb") as file:
                st.download_button(
                    label=f"Download Variation {idx+1}",
                    data=file.read(),
                    file_name=f"variation_{idx+1}.png",
                    mime="image/png",
                )


def main():
    """
    Render the Streamlit UI for generating DALL-E 2 image variations.

    Flow: read the API key from the sidebar, accept an image upload, and when
    "Generate Variations" is clicked, preprocess the image into a temp PNG,
    request the variations, and display them with download buttons. The temp
    PNG is always deleted afterwards, whether or not the request succeeded.
    """
    st.title("OpenAI Image Variation Generator")

    st.sidebar.header("Settings")
    api_key = st.sidebar.text_input("Enter OpenAI API Key", type="password")

    if not api_key:
        st.warning("Please enter your OpenAI API key in the sidebar to continue.")
        return

    st.write("Upload an image to generate variations using DALL-E 2")

    st.info("Please upload a PNG, JPG, or JPEG image. The image will be automatically processed to meet OpenAI's requirements (PNG format, < 4MB).")
    uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])

    col1, col2 = st.columns(2)
    with col1:
        num_variations = st.slider("Number of variations", min_value=1, max_value=4, value=1)
    with col2:
        size_options = ["1024x1024", "512x512", "256x256"]
        selected_size = st.selectbox("Image size", size_options)

    if uploaded_file is None:
        return

    try:
        st.subheader("Uploaded Image")
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_container_width=True)
    except Exception as e:
        st.error(f"Error loading image: {str(e)}")
        return

    if not st.button("Generate Variations"):
        return

    temp_path = None
    try:
        with st.spinner("Processing image..."):
            temp_path = preprocess_image(uploaded_file)

        file_size_mb = os.path.getsize(temp_path) / (1024 * 1024)
        st.success(f"Image processed successfully! File size: {file_size_mb:.2f}MB")

        with st.spinner("Generating variations..."):
            variations = _request_variations(api_key, temp_path, num_variations, selected_size)

        _show_variations(variations)

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        if "invalid_request_error" in str(e):
            st.info("Please ensure your image meets OpenAI's requirements: PNG format, less than 4MB, and appropriate content.")

    finally:
        # Best-effort cleanup on every path; previously the temp file leaked
        # whenever preprocessing or the API call failed.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
|
|
if __name__ == "__main__":
    # Entry point: build the app when this file is executed as a script.
    main()