# iimran's picture
# Update app.py
# 9a3f2d8 verified
import os
import tempfile
import logging
import requests
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Tuple, List, Dict, Any, Union
from io import BytesIO
from PIL import Image
import gradio as gr
from google import genai
from google.genai import types
# Configure logging: INFO level, each record stamped with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name='image_transformer')
@dataclass
class TransformResult:
"""Class to hold the result of an image transformation"""
image_path: Optional[str] = None
text_output: str = ""
success: bool = True
error_message: str = ""
class ImageTransformer:
    """Transform images by sending them, with a text instruction, to the Gemini API.

    ``process_request`` is the Gradio-facing entry point; ``transform_image``
    performs the upload / generate / collect round-trip.
    """

    def __init__(self, model_name: str = "gemini-2.0-flash-exp"):
        """Remember which Gemini model to call.

        Args:
            model_name: Model identifier passed to ``generate_content_stream``.
        """
        self.model_name = model_name
        logger.info(f"ImageTransformer initialized with model: {model_name}")

    def write_binary_data(self, filepath: str, data: bytes) -> None:
        """Write ``data`` to ``filepath``, replacing any existing content.

        Raises:
            Whatever ``open``/``write`` raise; the failure is logged, then re-raised.
        """
        try:
            with open(filepath, "wb") as f:
                f.write(data)
        except Exception as e:
            logger.error(f"Failed to write data to {filepath}: {e}", exc_info=True)
            raise
        logger.info(f"Successfully wrote data to {filepath}")

    def initialize_client(self, api_key: Optional[str]) -> genai.Client:
        """Create a Gemini API client.

        Args:
            api_key: Key entered by the user. When None/blank, the
                ``GEMINI_API_KEY`` environment variable is used instead.

        Returns:
            A ``genai.Client`` built with the stripped key.

        Raises:
            ValueError: If neither source yields a non-blank key.
        """
        # Strip both sources so a whitespace-only value counts as "missing"
        # instead of being sent to the API as an empty key.
        key = (api_key or "").strip() or os.environ.get("GEMINI_API_KEY", "").strip()
        if not key:
            logger.error("No API key provided and GEMINI_API_KEY not found in environment")
            raise ValueError("API key is required. Either provide one or set GEMINI_API_KEY environment variable.")
        logger.info("Initializing Gemini client")
        return genai.Client(api_key=key)

    def create_request_content(self, file_data: Dict[str, Any], instruction_text: str) -> List[types.Content]:
        """Build the single user turn: the uploaded file followed by the instruction.

        Args:
            file_data: Mapping with ``"uri"`` and ``"mime_type"`` of the uploaded file.
            instruction_text: The user's transformation instruction.
        """
        logger.info(f"Creating request content with instruction: {instruction_text}")
        return [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_uri(
                        file_uri=file_data["uri"],
                        mime_type=file_data["mime_type"],
                    ),
                    types.Part.from_text(text=instruction_text),
                ],
            ),
        ]

    def create_request_config(self) -> types.GenerateContentConfig:
        """Return the generation config; asks for both image and text output."""
        logger.info("Creating request configuration")
        return types.GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
            response_modalities=["image", "text"],
            response_mime_type="text/plain",
        )

    def transform_image(self, input_image_path: str, instruction: str, api_key: str) -> TransformResult:
        """Send one image plus an instruction to Gemini and collect the reply.

        Args:
            input_image_path: Local path of the image to upload.
            instruction: Natural-language transformation request.
            api_key: Gemini API key (blank falls back to ``GEMINI_API_KEY``).

        Returns:
            A TransformResult. On success ``image_path`` is a temporary ``.png``
            file that the caller owns and should delete, or None when the model
            answered with text only. Exceptions are not propagated; they are
            reported through ``success=False`` and ``error_message``.
        """
        result = TransformResult()
        client = None
        uploaded_file = None
        try:
            client = self.initialize_client(api_key)

            logger.info(f"Uploading file: {input_image_path}")
            uploaded_file = client.files.upload(file=input_image_path)

            contents = self.create_request_content(
                {"uri": uploaded_file.uri, "mime_type": uploaded_file.mime_type},
                instruction,
            )
            config = self.create_request_config()

            logger.info("Sending request to Gemini API")
            response_stream = client.models.generate_content_stream(
                model=self.model_name,
                contents=contents,
                config=config,
            )

            # Streamed text arrives as arbitrary fragments, so join them as-is
            # (adding a separator per chunk would split words mid-sentence).
            text_fragments: List[str] = []
            for chunk in response_stream:
                # Inspect every part: a chunk may carry text before its image.
                for part in self._chunk_parts(chunk):
                    if part.inline_data:
                        logger.info(f"Received image data ({part.inline_data.mime_type})")
                        result.image_path = self._save_image(part.inline_data.data)
                        break
                    if part.text:
                        text_fragments.append(part.text)
                if result.image_path:
                    # First image wins; stop consuming the stream.
                    break
            result.text_output = "".join(text_fragments)
        except Exception as e:
            logger.error(f"Error in transform_image: {e}", exc_info=True)
            result.success = False
            result.error_message = str(e)
            return result
        finally:
            if uploaded_file is not None:
                self._delete_uploaded_file(client, uploaded_file)

        if not result.image_path and result.text_output:
            logger.info(f"No image generated. Text output: {result.text_output[:100]}...")
        return result

    def process_request(self, input_image, instruction: str, api_key: str) -> Tuple[Optional[List[Image.Image]], str]:
        """Gradio click handler: validate inputs, transform, and format the outputs.

        Args:
            input_image: PIL image from the upload widget, or an http(s) URL string.
            instruction: Transformation instructions typed by the user.
            api_key: Gemini API key from the password box.

        Returns:
            ``(gallery_images, message)``: a one-image list and ``""`` on success;
            otherwise ``None`` and a user-facing message (error or model text).
        """
        input_path = None
        output_path = None
        try:
            if input_image is None:
                return None, "Please upload an image to transform."
            if not instruction or instruction.strip() == "":
                return None, "Please provide transformation instructions."

            input_path = self._save_input_image(input_image)
            result = self.transform_image(input_path, instruction, api_key)

            if not result.success:
                return None, f"Error: {result.error_message}"
            if not result.image_path:
                logger.info("No image generated, returning text response")
                return None, result.text_output or "No output generated. Try adjusting your instructions."

            output_path = result.image_path
            output_image = self._load_output_image(output_path)
            logger.info(f"Successfully processed image: {output_path}")
            return [output_image], ""
        except Exception as e:
            logger.error(f"Error in process_request: {e}", exc_info=True)
            return None, f"Error: {str(e)}"
        finally:
            # Both temp files are private to this request; the returned PIL
            # image is fully loaded in memory, so the files are no longer needed.
            for path in (input_path, output_path):
                if path:
                    self._remove_file(path)

    # ----- private helpers -------------------------------------------------

    def _chunk_parts(self, chunk) -> list:
        """Return the content parts of a stream chunk's first candidate, or []."""
        if not chunk.candidates:
            return []
        content = chunk.candidates[0].content
        if not content or not content.parts:
            return []
        return content.parts

    def _save_image(self, data: bytes) -> str:
        """Write image bytes to a new temporary ``.png`` file and return its path."""
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            output_path = tmp.name
        logger.info(f"Created temporary output file: {output_path}")
        try:
            self.write_binary_data(output_path, data)
        except Exception:
            self._remove_file(output_path)
            raise
        return output_path

    def _save_input_image(self, input_image) -> str:
        """Persist the user's image (PIL image or http(s) URL) to a temp PNG; return its path."""
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            input_path = tmp.name
        try:
            if isinstance(input_image, str) and input_image.startswith(("http://", "https://")):
                # URL (e.g. from an example row): download and re-save as PNG.
                logger.info(f"Downloading image from URL: {input_image}")
                response = requests.get(input_image, timeout=10)
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as img:
                    img.save(input_path)
                logger.info(f"Saved downloaded image to temporary file: {input_path}")
            else:
                # PIL image from the upload widget.
                input_image.save(input_path)
                logger.info(f"Saved uploaded image to temporary file: {input_path}")
        except Exception:
            self._remove_file(input_path)
            raise
        return input_path

    def _load_output_image(self, path: str) -> Image.Image:
        """Load an image fully into memory (closing the file); RGBA is flattened to RGB."""
        with Image.open(path) as img:
            img.load()
            if img.mode == "RGBA":
                return img.convert("RGB")
            return img.copy()

    def _delete_uploaded_file(self, client, uploaded_file) -> None:
        """Best-effort deletion of an upload from the Gemini Files API."""
        try:
            logger.info("Cleanup: deleting uploaded file from Gemini Files API")
            client.files.delete(name=uploaded_file.name)
        except Exception as e:
            # Cleanup failure must not mask the transformation result.
            logger.warning(f"Could not delete uploaded file {getattr(uploaded_file, 'name', '?')}: {e}")

    def _remove_file(self, path: str) -> None:
        """Delete a local temp file, logging (not raising) on failure."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"Could not remove temporary file {path}: {e}")
def build_ui() -> gr.Blocks:
    """Build the Gradio interface.

    Layout: header, two help accordions, an input column (image, API key,
    instructions, Transform button), an output column (gallery and text box),
    clickable examples, and a footer. The Transform button calls
    ``ImageTransformer.process_request`` on a single shared transformer.
    """
    logger.info("Building UI")
    # Create transformer instance (shared by all sessions; it holds only the model name)
    transformer = ImageTransformer()
    # Custom CSS; the class names below are attached via elem_classes / HTML.
    custom_css = """
    /* Main theme colors */
    :root {
        --primary-color: #3a506b;
        --secondary-color: #5bc0be;
        --accent-color: #ffd166;
        --background-color: #f8f9fa;
        --text-color: #1c2541;
        --border-radius: 8px;
        --box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
    }
    /* Global styles */
    body {
        font-family: 'Inter', system-ui, -apple-system, BlinkMacSystemFont, sans-serif;
        background-color: var(--background-color);
        color: var(--text-color);
    }
    /* Header styling */
    .app-header {
        display: flex;
        align-items: center;
        gap: 20px;
        padding: 16px 24px;
        background: linear-gradient(135deg, var(--primary-color), #1c2541);
        color: white;
        border-radius: var(--border-radius);
        margin-bottom: 24px;
        box-shadow: var(--box-shadow);
    }
    .app-header img {
        width: 48px;
        height: 48px;
        border-radius: 50%;
        background-color: white;
        padding: 6px;
    }
    .app-header h1 {
        margin: 0;
        font-size: 1.8rem;
        font-weight: 700;
    }
    .app-header p {
        margin: 4px 0 0 0;
        opacity: 0.9;
        font-size: 0.9rem;
    }
    .app-header a {
        color: var(--accent-color);
        text-decoration: none;
        transition: opacity 0.2s;
    }
    .app-header a:hover {
        opacity: 0.8;
        text-decoration: underline;
    }
    /* Accordion styling */
    .accordion-container {
        margin-bottom: 20px;
        border: 1px solid rgba(0, 0, 0, 0.1);
        border-radius: var(--border-radius);
        overflow: hidden;
    }
    .accordion-header {
        background-color: var(--primary-color);
        color: white;
        padding: 12px 16px;
        font-weight: 600;
    }
    .accordion-content {
        padding: 16px;
        background-color: white;
    }
    /* Main content area */
    .main-container {
        display: flex;
        gap: 24px;
        margin-bottom: 24px;
    }
    /* Input column */
    .input-column {
        flex: 1;
        background-color: white;
        padding: 20px;
        border-radius: var(--border-radius);
        box-shadow: var(--box-shadow);
    }
    /* Output column */
    .output-column {
        flex: 1;
        background-color: white;
        padding: 20px;
        border-radius: var(--border-radius);
        box-shadow: var(--box-shadow);
    }
    /* Button styling */
    .generate-button {
        background-color: var(--secondary-color) !important;
        color: white !important;
        border: none !important;
        border-radius: var(--border-radius) !important;
        padding: 12px 24px !important;
        font-weight: 600 !important;
        cursor: pointer !important;
        transition: background-color 0.2s !important;
        width: 100% !important;
        margin-top: 16px !important;
    }
    .generate-button:hover {
        background-color: #4ca8a6 !important;
    }
    /* Image upload area */
    .image-upload {
        border: 2px dashed rgba(0, 0, 0, 0.1);
        border-radius: var(--border-radius);
        padding: 20px;
        text-align: center;
        transition: border-color 0.2s;
    }
    .image-upload:hover {
        border-color: var(--secondary-color);
    }
    /* Input fields */
    input[type="text"], input[type="password"], textarea {
        width: 100%;
        padding: 10px 12px;
        border: 1px solid rgba(0, 0, 0, 0.1);
        border-radius: var(--border-radius);
        margin-bottom: 16px;
        font-family: inherit;
    }
    input[type="text"]:focus, input[type="password"]:focus, textarea:focus {
        border-color: var(--secondary-color);
        outline: none;
    }
    /* Examples section */
    .examples-header {
        margin: 32px 0 16px 0;
        font-weight: 600;
        color: var(--primary-color);
    }
    /* Footer */
    .app-footer {
        text-align: center;
        padding: 16px;
        margin-top: 32px;
        color: rgba(0, 0, 0, 0.5);
        font-size: 0.8rem;
    }
    """
    # Gradio interface
    with gr.Blocks(css=custom_css) as app:
        # Header
        gr.HTML(
            """
            <div class="app-header">
                <div>
                    <img src="https://img.icons8.com/fluency/96/000000/paint-3d.png" alt="App logo">
                </div>
                <div>
                    <h1>ImageWizard</h1>
                    <p>Transform images with AI | <a href="https://aistudio.google.com/apikey">Get API Key</a></p>
                </div>
            </div>
            """
        )
        # API key information
        with gr.Accordion("🔑 API Key Required", open=True):
            # The key passes through this app's server on its way to Google,
            # so the text avoids claiming it is sent "directly".
            gr.HTML(
                """
                <div class="accordion-content">
                    <p><strong>You need a Gemini API key to use this application.</strong></p>
                    <ol>
                        <li>Visit <a href="https://aistudio.google.com/apikey" target="_blank">Google AI Studio</a> to get your free API key</li>
                        <li>Enter the key in the API Key field below</li>
                        <li>Your key is not stored; it is only used to send your request to Google's API</li>
                    </ol>
                </div>
                """
            )
        # Usage instructions
        with gr.Accordion("📝 How To Use", open=False):
            gr.HTML(
                """
                <div class="accordion-content">
                    <h3>How to transform your images:</h3>
                    <ol>
                        <li><strong>Upload an Image:</strong> Click the upload area to select an image (PNG or JPG recommended)</li>
                        <li><strong>Enter your API Key:</strong> Paste your Gemini API key in the designated field</li>
                        <li><strong>Write Instructions:</strong> Clearly describe how you want to transform the image</li>
                        <li><strong>Generate:</strong> Click the Transform button and wait for results</li>
                    </ol>
                    <p><strong>Tips for better results:</strong></p>
                    <ul>
                        <li>Be specific with your instructions (e.g., "change the background to a beach scene" rather than "change the background")</li>
                        <li>If you get text instead of an image, try rephrasing your instructions</li>
                        <li>For best results, use images with clear subjects and simple backgrounds</li>
                    </ul>
                    <p><strong>Please Note:</strong> Do not upload or generate inappropriate content</p>
                </div>
                """
            )
        # Main container
        with gr.Row(elem_classes="main-container"):
            # Input column
            with gr.Column(elem_classes="input-column"):
                image_input = gr.Image(
                    type="pil",
                    label="Upload Your Image",
                    image_mode="RGBA",
                    elem_classes="image-upload",
                )
                api_key_input = gr.Textbox(
                    lines=1,
                    placeholder="Enter your Gemini API Key here",
                    label="Gemini API Key",
                    type="password",
                )
                instruction_input = gr.Textbox(
                    lines=3,
                    placeholder="Describe how you want to transform the image...",
                    label="Transformation Instructions",
                )
                # elem_classes hooks up the .generate-button CSS rule defined above.
                transform_btn = gr.Button(
                    "Transform Image",
                    variant="primary",
                    elem_classes="generate-button",
                )
            # Output column
            with gr.Column(elem_classes="output-column"):
                output_gallery = gr.Gallery(
                    label="Transformed Image",
                    elem_classes="gallery-container",
                )
                output_text = gr.Textbox(
                    label="Text Output",
                    placeholder="If no image is generated, text output will appear here.",
                    elem_classes="text-output",
                )
        # Set up the interaction
        transform_btn.click(
            fn=transformer.process_request,
            inputs=[image_input, instruction_input, api_key_input],
            outputs=[output_gallery, output_text],
        )
        # Examples section
        gr.Markdown("## Try These Examples", elem_classes="examples-header")
        # Examples using publicly available images (Pexels, Unsplash, etc.).
        # One value per entry in `inputs` below: (image URL, instruction).
        examples = [
            ["https://images.pexels.com/photos/268533/pexels-photo-268533.jpeg", "Change this landscape to night time with stars"],
            ["https://images.pexels.com/photos/1933873/pexels-photo-1933873.jpeg", "Add text that says 'DREAM BIG' in elegant font"],
            ["https://images.pexels.com/photos/1629781/pexels-photo-1629781.jpeg", "Remove the person from this photo"],
            ["https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg", "Make this dog look like it's wearing a superhero cape"],
            ["https://images.unsplash.com/photo-1555396273-367ea4eb4db5", "Add a neon glow effect around the coffee cup"],
            ["https://images.unsplash.com/photo-1501504905252-473c47e087f8", "Make this whiteboard text more legible and colorful"],
        ]
        gr.Examples(
            examples=examples,
            inputs=[image_input, instruction_input],
        )
        # Footer (styled by the .app-footer CSS rule instead of duplicated inline styles)
        gr.HTML(
            """
            <div class="app-footer">
                <p>ImageWizard © 2025 | Powered by Google Gemini and Gradio</p>
            </div>
            """
        )
    return app
# Main application entry point
def main():
    """Build the Gradio UI, enable request queueing, and start the server."""
    logger.info("Starting Image Transformer application")
    demo = build_ui()
    # queue() configures the Blocks in place; max_size bounds how many events may wait.
    demo.queue(max_size=50)
    demo.launch()


if __name__ == "__main__":
    main()