File size: 6,996 Bytes
1a2d550
 
 
 
 
e8d1c65
1a2d550
 
f5572da
 
 
1a2d550
 
1b6c50f
 
 
1a2d550
e6da377
1b6c50f
 
f5572da
1b6c50f
 
1a2d550
 
1b6c50f
 
1a2d550
1b6c50f
1a2d550
 
 
1b6c50f
1a2d550
 
 
 
 
 
1b6c50f
 
 
 
 
 
 
1a2d550
 
 
 
 
1b6c50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a2d550
1b6c50f
3685e3d
1b6c50f
1a2d550
 
 
 
 
 
cbd4c84
1a2d550
1b6c50f
 
1a2d550
 
3685e3d
 
 
 
 
 
1b6c50f
 
e6da377
47f9727
1a2d550
47f9727
83ecd2f
e6da377
47f9727
e6da377
 
 
 
1b6c50f
 
1a2d550
1b6c50f
 
 
 
e6da377
1b6c50f
83ecd2f
4c18034
e6da377
 
1b6c50f
e6da377
1b6c50f
 
e6da377
 
6c4c5f5
 
1b6c50f
47f9727
6c4c5f5
1a2d550
1b6c50f
 
 
ed4625b
 
662515a
1b6c50f
 
 
 
 
 
 
 
 
 
844ed4d
 
 
1b6c50f
 
19ffd31
bc7c0d9
1b6c50f
 
 
 
bc7c0d9
 
 
1b6c50f
 
 
 
bc7c0d9
 
1b6c50f
bc7c0d9
1b6c50f
 
ed4625b
1b6c50f
 
ff6bd0c
ed4625b
e6da377
1b6c50f
ed4625b
1b6c50f
 
ed4625b
 
e6da377
1b6c50f
19ffd31
ac3fd77
 
1b6c50f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import base64
import os
import mimetypes
from google import genai
from google.genai import types
import gradio as gr
import io
from PIL import Image
from huggingface_hub import login

GEMINI_KEY = os.environ.get("GEMINI_API_KEY")

def save_binary_file(file_name, data):
    """Write *data* (bytes) to *file_name*, overwriting any existing file.

    Args:
        file_name: Destination path.
        data: Raw bytes to write.
    """
    # Context manager guarantees the handle is closed even if write() raises,
    # unlike the previous open()/write()/close() sequence.
    with open(file_name, "wb") as f:
        f.write(data)

def generate_image(prompt, image=None):
    """Generate or edit an image with the Gemini image-generation model.

    Args:
        prompt: Text description of the image to generate or the edit to apply.
        image: Optional ``PIL.Image.Image``; when given, it is sent alongside
            the prompt as inline PNG data so the model can edit it.

    Returns:
        A ``(pil_image, status)`` tuple: ``pil_image`` is the generated
        ``PIL.Image.Image`` (or ``None`` if the model returned only text) and
        ``status`` is a human-readable message (save confirmation, or the
        model's full text response when no image was produced).
    """
    # Initialize client with the API key read from the environment at import.
    client = genai.Client(
        api_key=GEMINI_KEY,
    )

    model = "gemini-2.0-flash-exp-image-generation"
    parts = [types.Part.from_text(text=prompt)]

    # If an input image is provided, attach it as inline PNG bytes.
    if image is not None:
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format="PNG")
        # Use the SDK's typed constructor instead of a hand-built dict so the
        # payload shape is validated by the library.
        parts.append(
            types.Part.from_bytes(
                data=img_byte_arr.getvalue(),
                mime_type="image/png",
            )
        )

    contents = [
        types.Content(
            role="user",
            parts=parts,
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=[
            "image",
            "text",
        ],
        safety_settings=[
            types.SafetySetting(
                category="HARM_CATEGORY_CIVIC_INTEGRITY",
                threshold="OFF",
            ),
        ],
        response_mime_type="text/plain",
    )

    # Stream the response; return as soon as the first image part arrives.
    response = client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    )

    full_text_response = ""  # Accumulates text chunks for the no-image fallback.
    for chunk in response:
        if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
            continue
        part = chunk.candidates[0].content.parts[0]
        if part.inline_data:
            inline_data = part.inline_data
            # guess_extension() returns None for unrecognized MIME types;
            # fall back to .png to avoid a filename like "generated_imageNone".
            file_extension = mimetypes.guess_extension(inline_data.mime_type) or ".png"
            filename = f"generated_image{file_extension}"
            save_binary_file(filename, inline_data.data)

            # Convert the returned binary data to a PIL Image for display.
            img = Image.open(io.BytesIO(inline_data.data))
            # BUG FIX: this previously returned the literal text
            # "Image saved as (unknown)" — the filename was never interpolated.
            return img, f"Image saved as {filename}"
        elif chunk.text:
            full_text_response += chunk.text
            print("Chunk Text Response:", chunk.text)  # Debug trace per chunk

    print("Full Text Response from Gemini:", full_text_response)  # Debug trace
    return None, full_text_response

# Handles one full chat turn for the Gradio UI.
def chat_handler(prompt, user_image, chat_history):
    """Record the user's input, run image generation, and append the reply.

    Args:
        prompt: Text prompt typed by the user (may be empty).
        user_image: Optional uploaded ``PIL.Image.Image``.
        chat_history: Mutable list of ``{"role", "content"}`` messages.

    Returns:
        ``(chat_history, uploaded_image, generated_image, "")`` — the empty
        string clears the prompt textbox.
    """

    def _to_data_uri(pil_img):
        # Encode a PIL image as a base64 PNG data URI for inline HTML display.
        buf = io.BytesIO()
        pil_img.save(buf, format="PNG")
        return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"

    # The text prompt, when present, becomes its own user message.
    if prompt:
        chat_history.append({"role": "user", "content": prompt})

    # The uploaded image, when present, is shown as a separate small
    # thumbnail message (full-resolution data in a 100px container).
    if user_image is not None:
        upload_uri = _to_data_uri(user_image)
        upload_html = gr.HTML(f'<img src="{upload_uri}" alt="Uploaded Image" style="width:100px; height:100px; object-fit:contain;">')
        chat_history.append({"role": "user", "content": upload_html})

    # With neither prompt nor image there is nothing to generate.
    if not prompt and not user_image:
        chat_history.append({"role": "assistant", "content": "Please provide a prompt or an image."})
        return chat_history, user_image, None, ""

    # Run generation, falling back to a generic prompt when only an image was given.
    img, status = generate_image(prompt or "Generate an image", user_image)

    thumbnail_data_uri = None
    if img:
        img = img.convert("RGB")  # Force RGB mode for consistency
        thumbnail_data_uri = _to_data_uri(img)
        print("Image Data URI:", thumbnail_data_uri)  # Print to console
        assistant_reply = gr.HTML(f'<img src="{thumbnail_data_uri}" alt="Generated Image" style="width:100px; height:100px; object-fit:contain;">')
    else:
        assistant_reply = status  # No image produced: surface the text status.

    # The assistant message is either a gr.HTML thumbnail or plain text.
    chat_history.append({"role": "assistant", "content": assistant_reply})

    return chat_history, user_image, img, ""

# Create Gradio interface.
# Layout (top to bottom): chat transcript, side-by-side image panes,
# then the input column (image upload, prompt box, button).
with gr.Blocks(title="Image Editing Chatbot") as demo:
    gr.Markdown("# Image Editing Chatbot")
    gr.Markdown("Upload an image and/or type a prompt to generate or edit an image using Google's Gemini model")

    # Chat transcript; type="messages" expects {"role", "content"} dicts,
    # which is what chat_handler appends.
    chatbot = gr.Chatbot(
        label="Chat",
        height=300,
        type="messages",
        avatar_images=(None, None)
    )

    # Full-size views of the uploaded and generated images, side by side
    # (the chat itself only shows 100px thumbnails).
    with gr.Row():
        uploaded_image_output = gr.Image(label="Uploaded Image")
        generated_image_output = gr.Image(label="Generated Image")

    # Input area
    with gr.Row():
        with gr.Column(scale=2):  # Increased scale for better spacing
            image_input = gr.Image(
                label="Upload Image",
                type="pil",  # chat_handler expects a PIL image, not a filepath
                scale=1,
                height=150,  # Increased height for better visibility
                container=True,  # Ensure the component has a container for padding
                elem_classes="p-4"  # Add padding via CSS class (4 units of padding)
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your image description here...",
                lines=3,
                elem_classes="mt-2"  # Add margin-top for spacing between components
            )
            generate_btn = gr.Button("Generate Image", elem_classes="mt-2")  # Add margin-top for the button

    # State to maintain chat history across turns (mutated in place by chat_handler).
    chat_state = gr.State([])

    # Button click runs a chat turn; the last output ("") clears the prompt box.
    generate_btn.click(
        fn=chat_handler,
        inputs=[prompt_input, image_input, chat_state],
        outputs=[chatbot, uploaded_image_output, generated_image_output, prompt_input]
    )

    # Also allow Enter key to submit (same handler and wiring as the button).
    prompt_input.submit(
        fn=chat_handler,
        inputs=[prompt_input, image_input, chat_state],
        outputs=[chatbot, uploaded_image_output, generated_image_output, prompt_input]
    )

if __name__ == "__main__":
    demo.launch()