# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import google.genai as genai | |
from google.genai import types | |
from PIL import Image | |
import os | |
import textract | |
# Available models (including experimental and recent ones), mapped to how
# uploaded files are handled for each one (see process_files):
#   "text"       -> every upload is converted to text with textract
#   "multimodal" -> image uploads are passed through as PIL images; other
#                   files are still converted to text
model_types = {
    "gemini-2.5-flash-preview-04-17": "text",
    "gemini-2.5-pro-preview-03-25": "text",
    "gemini-2.0-flash": "text",
    "gemini-2.0-flash-lite": "text",
    "gemini-2.0-flash-thinking-exp-01-21": "text",
    "gemini-1.5-pro": "text",
    "gemini-2.0-flash-exp-image-generation": "multimodal",
}

# Dropdown choices, in declaration order. Derived from model_types so every
# selectable model is guaranteed to have an input-handling entry.
models = list(model_types)
# Function to validate API key
def validate_api_key(api_key):
    """Check a Gemini API key by making one authenticated request.

    Args:
        api_key: The key typed by the user (may be empty).

    Returns:
        A ``(is_valid, status_message)`` tuple for display in the UI.
    """
    # Reject blank keys before building a client: genai.Client can fall back
    # to key environment variables when no usable key is given, so a blank
    # field could otherwise be reported as valid.
    if not api_key or not api_key.strip():
        return False, "Invalid API Key: no key was entered."
    try:
        client = genai.Client(api_key=api_key)
        client.models.list()  # any authenticated call fails on a bad key
    except Exception as e:
        return False, f"Invalid API Key: {e}"
    return True, "API Key is valid."
# Function to process uploaded files

# File extensions passed through as images when the model type is "multimodal".
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp')


def process_files(files, model_type):
    """Convert uploaded files into input parts for ``generate_content``.

    Args:
        files: Iterable of local file paths (as delivered by ``gr.File``).
        model_type: ``"multimodal"`` to pass images through as PIL images;
            any other value converts every file to text with textract.

    Returns:
        A list with one entry per file: a PIL image, the extracted text, or an
        ``"Error ..."`` string explaining why the file could not be read.
        Errors are reported in-band rather than raised, so one bad upload
        does not abort the whole request.
    """
    inputs = []
    for file_path in files:
        name = os.path.basename(file_path)
        if model_type == "multimodal" and file_path.lower().endswith(IMAGE_EXTENSIONS):
            try:
                # Decode eagerly inside ``with`` so the file handle is closed
                # now; the loaded image stays usable after the block.
                with Image.open(file_path) as img:
                    img.load()
                inputs.append(img)
            except Exception as e:
                inputs.append(f"Error loading image {name}: {e}")
        else:
            try:
                text = textract.process(file_path).decode('utf-8')
                inputs.append(text)
            except Exception as e:
                inputs.append(f"Error extracting text from {name}: {e}")
    return inputs
# Shared helpers for the chat and single-response handlers


def _prepare_inputs(text, files, model_type):
    """Build the ``contents`` list: non-blank prompt text, then processed files."""
    processed = process_files(files, model_type) if files else []
    # The Gemini API rejects empty text parts, so blank text is not sent.
    prompt = [text] if text and text.strip() else []
    return prompt + processed


def _build_config(model, temperature, top_p, max_tokens):
    """Build a GenerateContentConfig from the UI's sampling controls."""
    options = {
        "temperature": temperature,
        "top_p": top_p,
        # gr.Number may deliver a float such as 512.0; the API field is an int.
        "max_output_tokens": int(max_tokens) if max_tokens else None,
    }
    if "image-generation" in model:
        # Image-generation models only return images when IMAGE is requested.
        options["response_modalities"] = ["TEXT", "IMAGE"]
    return types.GenerateContentConfig(**options)


def _save_inline_image(blob):
    """Write an inline image ``Blob`` to a temp file and return its path."""
    import mimetypes
    import tempfile

    suffix = mimetypes.guess_extension(blob.mime_type or "") or ".png"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        handle.write(blob.data)
    return handle.name


def _parse_response(response):
    """Collect text and image references from a generate_content response.

    Returns:
        ``(text, images)``: ``text`` joins every text part. ``images`` holds a
        local temp-file path for each ``inline_data`` part and the URI of each
        ``file_data`` part.
    """
    texts, images = [], []
    # candidates, content and parts may each be None (e.g. a blocked prompt).
    for candidate in getattr(response, "candidates", None) or []:
        content = getattr(candidate, "content", None)
        for part in getattr(content, "parts", None) or []:
            text = getattr(part, "text", None)
            inline = getattr(part, "inline_data", None)
            remote = getattr(part, "file_data", None)
            if text:
                texts.append(text)
            elif inline is not None and inline.data:
                images.append(_save_inline_image(inline))
            elif remote is not None and remote.file_uri:
                images.append(remote.file_uri)
    return "".join(texts), images


# Chat submit function
def chat_submit_func(message, files, chat_history, model, temperature, top_p, max_tokens, api_key):
    """Send one chat message (plus optional files) to Gemini and record the exchange.

    Args:
        message: The user's text.
        files: Uploaded file paths from ``gr.File`` (may be None or empty).
        chat_history: Current Chatbot value as (user, bot) pairs.
        model: Gemini model id from the dropdown.
        temperature, top_p, max_tokens: Sampling settings from the UI.
        api_key: Gemini API key.

    Returns:
        ``(chat_history, status)``: the updated history and a status string.
        The status carries the file-conversion warning, or is "".

    NOTE(review): each call is stateless; earlier turns are not sent back to
    the model.
    """
    chat_history = list(chat_history or [])
    # .get(): an unlisted model id is handled as text-only instead of raising.
    model_type = model_types.get(model, "text")
    status = ""
    if files and model_type == "text":
        status = "Warning: Files are not supported for text-only models. Converting to text where possible."

    user_message = message or ""
    if files:
        user_message += "\nFiles: " + ", ".join(os.path.basename(f) for f in files)

    try:
        inputs = _prepare_inputs(message, files, model_type)
        if not inputs:
            return chat_history, "Please enter a message or upload a file."
        client = genai.Client(api_key=api_key)
        # generate_content accepts keyword arguments only.
        response = client.models.generate_content(
            model=model,
            contents=inputs,
            config=_build_config(model, temperature, top_p, max_tokens),
        )
        response_text, response_images = _parse_response(response)
    except Exception as e:
        chat_history.append((user_message, f"Error: {e}"))
        return chat_history, status

    # Remote URIs are embedded as markdown images; locally saved images are
    # added below as Chatbot file messages.
    local_images = [src for src in response_images if os.path.isfile(src)]
    lines = [response_text] if response_text else []
    lines += [f"![image]({src})" for src in response_images if not os.path.isfile(src)]
    bot_message = "\n".join(lines)
    if not bot_message and not local_images:
        bot_message = "(The model returned no content.)"

    chat_history.append((user_message, bot_message or None))
    chat_history.extend((None, (path,)) for path in local_images)
    return chat_history, status


# Single response submit function
def single_submit_func(prompt, files, model, temperature, top_p, max_tokens, api_key):
    """Generate one stateless response for the "Single Response" tab.

    Args:
        prompt: The user's prompt text.
        files: Uploaded file paths from ``gr.File`` (may be None or empty).
        model: Gemini model id from the dropdown.
        temperature, top_p, max_tokens: Sampling settings from the UI.
        api_key: Gemini API key.

    Returns:
        ``(text, images)`` for the response Textbox and Gallery. On failure the
        result is ``("Error: ...", [])``.
    """
    model_type = model_types.get(model, "text")
    warning = ""
    if files and model_type == "text":
        warning = "Warning: Files converted to text for text-only model."

    try:
        inputs = _prepare_inputs(prompt, files, model_type)
        if not inputs:
            return "Please enter a prompt or upload a file.", []
        client = genai.Client(api_key=api_key)
        response = client.models.generate_content(
            model=model,
            contents=inputs,
            config=_build_config(model, temperature, top_p, max_tokens),
        )
        response_text, response_images = _parse_response(response)
    except Exception as e:
        return f"Error: {e}", []

    if warning:
        response_text = f"{warning}\n\n{response_text}"
    return response_text, response_images
# Gradio interface
with gr.Blocks(title="Gemini API Interface") as app:
    # API Key Section
    api_key_input = gr.Textbox(label="Gemini API Key", type="password", placeholder="Enter your Gemini API Key")
    validate_btn = gr.Button("Validate API Key")
    key_status = gr.Textbox(label="API Key Status", interactive=False)
    key_validated = gr.State(False)  # result of the most recent validation

    # Model and Parameters Section (hidden until key is validated)
    with gr.Group(visible=False) as config_group:
        model_selector = gr.Dropdown(choices=models, label="Select Model", value=models[0])
        temperature = gr.Slider(0, 1, value=0.7, label="Temperature", step=0.01)
        top_p = gr.Slider(0, 1, value=0.9, label="Top P", step=0.01)
        # precision=0 makes the component deliver an int, which the API expects.
        max_tokens = gr.Number(value=512, label="Max Tokens", minimum=1, precision=0)

    # Tabs for Chat and Single Response (hidden until key is validated)
    with gr.Tabs(visible=False) as tabs:
        with gr.TabItem("Chat"):
            chat_display = gr.Chatbot(label="Chat History")
            chat_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            chat_files = gr.File(label="Upload Files", file_count="multiple")
            chat_submit_btn = gr.Button("Send")
            chat_status = gr.Textbox(label="Status", interactive=False)
        with gr.TabItem("Single Response"):
            single_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
            single_files = gr.File(label="Upload Files", file_count="multiple")
            single_submit_btn = gr.Button("Generate")
            single_text_output = gr.Textbox(label="Response Text", interactive=False)
            single_image_output = gr.Gallery(label="Response Images")

    # Validation logic
    def on_validate_key(api_key):
        """Validate *api_key* and show or hide the config group and tabs to match.

        Returns values for (key_status, key_validated, config_group, tabs).
        """
        is_valid, status = validate_api_key(api_key)
        return status, is_valid, gr.update(visible=is_valid), gr.update(visible=is_valid)

    # Each action is wired to both its button and Enter in its textbox,
    # sharing one fn/inputs/outputs spec so the two triggers cannot diverge.
    validation_event = dict(
        fn=on_validate_key,
        inputs=[api_key_input],
        outputs=[key_status, key_validated, config_group, tabs],
    )
    validate_btn.click(**validation_event)
    api_key_input.submit(**validation_event)

    # Chat submission
    chat_event = dict(
        fn=chat_submit_func,
        inputs=[chat_input, chat_files, chat_display, model_selector, temperature, top_p, max_tokens, api_key_input],
        outputs=[chat_display, chat_status],
    )
    chat_submit_btn.click(**chat_event)
    chat_input.submit(**chat_event)

    # Single response submission
    single_event = dict(
        fn=single_submit_func,
        inputs=[single_input, single_files, model_selector, temperature, top_p, max_tokens, api_key_input],
        outputs=[single_text_output, single_image_output],
    )
    single_submit_btn.click(**single_event)
    single_input.submit(**single_event)

if __name__ == "__main__":
    app.launch()