# gemini-image / app.py
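"""Gradio chatbot that generates or edits images with Google's Gemini
image-generation model via the google-genai SDK."""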
import base64
import io
import mimetypes
import os

import gradio as gr
from google import genai
from google.genai import types
from PIL import Image
# Read the Gemini API key from the environment (e.g. a Hugging Face Space secret)
GEMINI_KEY = os.environ.get("GEMINI_API_KEY")
def save_binary_file(file_name, data):
    """Write raw bytes to disk (used to persist generated images)."""
    with open(file_name, "wb") as f:
        f.write(data)
def generate_image(prompt, image=None):
    """Call the Gemini image-generation model and return (PIL image, status text)."""
    # Initialize the client with the API key
    client = genai.Client(api_key=GEMINI_KEY)
    model = "gemini-2.0-flash-exp-image-generation"
    parts = [types.Part.from_text(text=prompt)]
    # If an input image is provided, attach it to the request
    if image:
        # Convert the PIL Image to PNG bytes
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format="PNG")
        img_bytes = img_byte_arr.getvalue()
        # Attach the image as an inline-data part
        parts.append(types.Part.from_bytes(data=img_bytes, mime_type="image/png"))
    contents = [
        types.Content(
            role="user",
            parts=parts,
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=[
            "image",
            "text",
        ],
        safety_settings=[
            types.SafetySetting(
                category="HARM_CATEGORY_CIVIC_INTEGRITY",
                threshold="OFF",
            ),
        ],
        response_mime_type="text/plain",
    )
    # Stream the generated content
    response = client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    )
    full_text_response = ""  # Accumulates any text the model streams back
    # Process the streamed chunks
    for chunk in response:
        if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
            continue
        part = chunk.candidates[0].content.parts[0]
        if part.inline_data:
            inline_data = part.inline_data
            # Fall back to .png if the MIME type is unrecognized
            file_extension = mimetypes.guess_extension(inline_data.mime_type) or ".png"
            filename = f"generated_image{file_extension}"  # Fixed output filename
            save_binary_file(filename, inline_data.data)
            # Convert the binary data back into a PIL Image
            img = Image.open(io.BytesIO(inline_data.data))
            return img, f"Image saved as {filename}"
        elif chunk.text:
            full_text_response += chunk.text
            print("Chunk Text Response:", chunk.text)  # Debug: per-chunk text
    # No image part arrived: return whatever text the model produced
    print("Full Text Response from Gemini:", full_text_response)  # Debug
    return None, full_text_response
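# A minimal standalone sketch of calling generate_image outside Gradio
# (assumes GEMINI_API_KEY is set; the prompt and filename are illustrative):
#   img, status = generate_image("A watercolor fox in the snow")
#   if img is not None:
#       img.save("fox.png")
#   print(status)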
# Handle one round of chat interaction
def chat_handler(prompt, user_image, chat_history):
    # Add the user's text prompt to the chat history
    if prompt:
        chat_history.append({"role": "user", "content": prompt})
    if user_image is not None:
        # Show the uploaded image as its own user message, embedded as a
        # base64 data URI inside a small fixed-size container
        buffered = io.BytesIO()
        user_image.save(buffered, format="PNG")
        user_image_base64 = base64.b64encode(buffered.getvalue()).decode()
        user_image_data_uri = f"data:image/png;base64,{user_image_base64}"
        chat_history.append({"role": "user", "content": gr.HTML(f'<img src="{user_image_data_uri}" alt="Uploaded Image" style="width:100px; height:100px; object-fit:contain;">')})
    # With no input at all, prompt the user and return early
    if not prompt and not user_image:
        chat_history.append({"role": "assistant", "content": "Please provide a prompt or an image."})
        return chat_history, user_image, None, ""
    # Generate an image from the user input
    img, status = generate_image(prompt or "Generate an image", user_image)
    if img:
        # Embed the full-resolution result in a small container
        img = img.convert("RGB")  # Force RGB mode for consistency
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        thumbnail_base64 = base64.b64encode(buffered.getvalue()).decode()
        thumbnail_data_uri = f"data:image/png;base64,{thumbnail_base64}"
        assistant_message_content = gr.HTML(f'<img src="{thumbnail_data_uri}" alt="Generated Image" style="width:100px; height:100px; object-fit:contain;">')
    else:
        assistant_message_content = status  # No image: show the text status
    # The assistant message is either gr.HTML (an image) or plain text
    chat_history.append({"role": "assistant", "content": assistant_message_content})
    return chat_history, user_image, img, ""
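# chat_handler returns (chat_history, uploaded_image, generated_image, ""),
# matching the four outputs wired to the button and textbox below; the empty
# string clears the prompt box after each turn.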
# Build the Gradio interface
with gr.Blocks(title="Image Editing Chatbot") as demo:
    gr.Markdown("# Image Editing Chatbot")
    gr.Markdown("Upload an image and/or type a prompt to generate or edit an image using Google's Gemini model.")
    # Chat display area for text messages and image thumbnails
    chatbot = gr.Chatbot(
        label="Chat",
        height=300,
        type="messages",
        avatar_images=(None, None),
    )
    # Separate full-size image outputs
    with gr.Row():
        uploaded_image_output = gr.Image(label="Uploaded Image")
        generated_image_output = gr.Image(label="Generated Image")
    # Input area
    with gr.Row():
        with gr.Column(scale=2):  # Wider column for better spacing
            image_input = gr.Image(
                label="Upload Image",
                type="pil",
                scale=1,
                height=150,
                container=True,
                elem_classes="p-4",  # Padding via CSS utility class
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your image description here...",
                lines=3,
                elem_classes="mt-2",  # Margin-top between components
            )
            generate_btn = gr.Button("Generate Image", elem_classes="mt-2")
    # State holding the chat history across turns
    chat_state = gr.State([])
    # Wire the button to the chat handler
    generate_btn.click(
        fn=chat_handler,
        inputs=[prompt_input, image_input, chat_state],
        outputs=[chatbot, uploaded_image_output, generated_image_output, prompt_input],
    )
    # Also allow the Enter key to submit
    prompt_input.submit(
        fn=chat_handler,
        inputs=[prompt_input, image_input, chat_state],
        outputs=[chatbot, uploaded_image_output, generated_image_output, prompt_input],
    )
if __name__ == "__main__":
    demo.launch()
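# To run locally (a sketch; package names assume the public PyPI releases):
#   pip install google-genai gradio pillow
#   GEMINI_API_KEY=your-key python app.py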