Spaces:
Running
Running
import base64 | |
import os | |
import gradio as gr | |
from google import genai | |
from google.genai import types | |
from google.genai.types import HarmBlockThreshold | |
from PIL import Image | |
from io import BytesIO | |
import tempfile | |
from dotenv import load_dotenv | |
import warnings | |
import io | |
# Load environment variables from .env file.
# swap_clothing() later reads GEMINI_API_KEY from os.environ, so this must run
# at import time, before any request is handled.
load_dotenv()
def _format_warnings(warning_list):
    """Render captured warnings as a "Warnings:" bullet section ("" if none)."""
    if not warning_list:
        return ""
    bullets = "".join(f"- {caught.message}\n" for caught in warning_list)
    return f"\nWarnings:\n{bullets}"


def _save_as_temp_jpeg(image, temp_files):
    """Write *image* to a fresh temporary .jpg file and return its path.

    The path is appended to *temp_files* before anything is written, so the
    caller's cleanup removes the file even when saving fails.
    """
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as handle:
        path = handle.name
    temp_files.append(path)
    # The handle is closed before saving so the path can be reopened on
    # platforms that refuse a second open of a still-open temporary file.
    if image.mode != "RGB":
        # JPEG cannot store alpha or palette modes; normalise first.
        image = image.convert("RGB")
    image.save(path)
    return path


def _build_contents(person_file, clothing_file):
    """Build the multi-part user message sent to the model.

    The person image is deliberately referenced twice: once when it is
    introduced and once more after the final instruction.
    """
    prompt = '''
    Edit the person's clothing by swapping it with the clothing in the clothing image.
    Retain the same face, facial features, pose and background from the person image.
    The output image should be an image of the person wearing the clothing from the clothing image with the style of clothing image.
    The image pose and background should be the same as the person image but with the new clothing:
    '''
    return [
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text="This is the person image. Do not change the face or features of the person. Pay attention and retain the face, environment, background, pose, facial features."),
                types.Part.from_uri(
                    file_uri=person_file.uri,
                    mime_type=person_file.mime_type,
                ),
                types.Part.from_text(text="This is the clothing image. Swap the clothing onto the person image."),
                types.Part.from_uri(
                    file_uri=clothing_file.uri,
                    mime_type=clothing_file.mime_type,
                ),
                types.Part.from_text(text=prompt),
                types.Part.from_uri(
                    file_uri=person_file.uri,
                    mime_type=person_file.mime_type,
                ),
            ],
        ),
    ]


def _build_generation_config():
    """Return the generation settings (sampling, modalities, safety) for the request."""
    unblocked_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    return types.GenerateContentConfig(
        temperature=0.099,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=[
            "image",
            "text",
        ],
        safety_settings=[
            types.SafetySetting(
                category=category,
                threshold=HarmBlockThreshold.BLOCK_NONE,
            )
            for category in unblocked_categories
        ],
        response_mime_type="text/plain",
    )


def _decode_inline_image(inline_data):
    """Turn an inline-data response part into a fully loaded PIL image."""
    raw = inline_data.data
    if not isinstance(raw, bytes):
        raw = base64.b64decode(raw)
    image = Image.open(BytesIO(raw))
    # Force decoding now so a corrupt payload is reported by the caller's
    # error handling instead of failing later when the image is displayed.
    image.load()
    return image


def _cleanup_files(client, temp_files, uploaded_files):
    """Best-effort removal of local temp files and remote uploads; never raises."""
    for path in temp_files:
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass  # Already gone - nothing to clean up.
        except OSError as cleanup_error:
            print(f"Could not remove temporary file {path}: {cleanup_error}")
    for uploaded in uploaded_files:
        remote_name = getattr(uploaded, "name", None)
        try:
            # files.delete() takes the file's resource name as a keyword
            # argument, not its URI.
            client.files.delete(name=remote_name)
        except Exception as cleanup_error:
            print(f"Could not delete uploaded file {remote_name}: {cleanup_error}")


def swap_clothing(person_image, clothing_image):
    """
    Generate an image where the person from the first image is wearing clothing from the second image.

    Both images are written to temporary JPEG files, uploaded to Gemini, and
    sent together with the swap instructions. Temporary files and uploads are
    removed again whether or not the request succeeds. Warnings raised while
    talking to the API are captured and appended to the returned text without
    changing the process-wide warning filters.

    Args:
        person_image: The image containing the person (PIL image, or None)
        clothing_image: The image containing the clothing to swap (PIL image, or None)

    Returns:
        A tuple ``(image, text)``: the generated PIL image (or None when no
        image was produced) and any model text, warnings, or error details.
    """
    # Validate inputs before creating any resources.
    if person_image is None or clothing_image is None:
        return None, "Please upload both images."

    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        return None, "GEMINI_API_KEY not found in environment variables."

    temp_files = []
    uploaded_files = []
    client = None
    output_image = None
    messages = []

    with warnings.catch_warnings(record=True) as warning_list:
        # Set the filter inside the context so it is undone on exit.
        warnings.simplefilter("always")
        try:
            # Create a fresh client instance for each request
            client = genai.Client(api_key=api_key)

            # Save both images locally first so a bad image fails before any
            # network call is made.
            for image in (person_image, clothing_image):
                _save_as_temp_jpeg(image, temp_files)

            # Register each upload as soon as it succeeds so a failure on the
            # second upload still lets the first one be deleted.
            for path in temp_files:
                uploaded_files.append(client.files.upload(file=path))

            response = client.models.generate_content(
                model="models/gemini-2.0-flash-exp",
                contents=_build_contents(uploaded_files[0], uploaded_files[1]),
                config=_build_generation_config(),
            )

            warnings_section = _format_warnings(warning_list)

            candidates = getattr(response, "candidates", None)
            content = getattr(candidates[0], "content", None) if candidates else None
            if not content:
                # Keep any captured warnings alongside the failure message.
                return None, (
                    "The model did not generate a valid response. Please try again with different images."
                    + warnings_section
                )

            messages.append(warnings_section)
            # parts may be None (e.g. an empty candidate); treat as no parts.
            for part in content.parts or []:
                if part.text is not None:
                    messages.append(part.text + "\n")
                elif part.inline_data is not None:
                    try:
                        output_image = _decode_inline_image(part.inline_data)
                    except Exception as img_error:
                        messages.append(f"Error processing image: {str(img_error)}\n")
        except Exception as e:
            # Top-level boundary for the UI: report the failure as text.
            error_details = f"Error: {str(e)}\n\nType: {type(e).__name__}"
            if warning_list:
                error_details += "\n" + _format_warnings(warning_list)
            print(f"Exception occurred: {error_details}")
            return None, error_details
        finally:
            _cleanup_files(client, temp_files, uploaded_files)

    return output_image, "".join(messages)
# Create the Gradio interface | |
def create_interface():
    """Assemble the try-on UI as a Gradio Blocks app and return it without launching.

    The layout is a single row with two columns: the two upload widgets and
    the "Generate" button on one side, the resulting image and the response
    text on the other. Clicking the button calls ``swap_clothing`` with the
    two uploaded images and routes its two return values to the outputs.
    """
    # Both upload widgets hand PIL images in RGB mode to the callback.
    upload_options = {"type": "pil", "image_mode": "RGB"}

    with gr.Blocks(title="Virtual Clothing Try-On") as demo:
        gr.Markdown("# Virtual Clothing Try-On")
        gr.Markdown("Upload a photo of yourself and a photo of clothing you'd like to try on!")

        with gr.Row():
            with gr.Column():
                person_input = gr.Image(label="Your Photo", **upload_options)
                clothing_input = gr.Image(label="Clothing Photo", **upload_options)
                generate_button = gr.Button("Generate")
            with gr.Column():
                result_image = gr.Image(label="Result", type="pil")
                response_box = gr.Textbox(label="Response", lines=3)

        # Wire the button to the generation callback.
        callback_inputs = [person_input, clothing_input]
        callback_outputs = [result_image, response_box]
        generate_button.click(
            fn=swap_clothing,
            inputs=callback_inputs,
            outputs=callback_outputs,
        )

    return demo
if __name__ == "__main__":
    # Build the UI and start serving it only when this file is executed as a
    # script; importing the module defines the functions without launching.
    app = create_interface()
    app.launch()