Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
from PIL import Image, ImageDraw | |
import base64 | |
from io import BytesIO | |
import re | |
import os | |
# Example (screenshot, instruction) pairs offered in the UI's "Example Tasks" panel.
examples = [
    dict(image=image_path, prompt=instruction)
    for image_path, instruction in (
        ("./assets/example_desktop.png", "switch off the wired connection"),
        ("./assets/example_web.png", "view all branches"),
        ("./assets/example_mobile.jpg", "share the screenshot"),
    )
]
# OpenAI-compatible client for the Aria-UI inference endpoint. Both settings are
# read from the environment (e.g. Space secrets); fail fast with an actionable
# message instead of a bare KeyError when either is missing or empty.
_missing_settings = [
    name for name in ("aria_ui_api_key", "aria_ui_api_base") if not os.environ.get(name)
]
if _missing_settings:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_settings)
    )
openai_api_key = os.environ["aria_ui_api_key"]
openai_api_base = os.environ["aria_ui_api_base"]
from openai import OpenAI  # noqa: E402  third-party client, imported after the env check
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
# Ask the server which models it serves and use the first one for every request
# (presumably the endpoint hosts only the Aria-UI model — confirm).
models = client.models.list()
if not models.data:
    raise RuntimeError(f"No models are available at {openai_api_base}")
model = models.data[0].id
def encode_pil_image_to_base64(image: Image.Image) -> str:
    """Return *image* as a base64 string of its RGB JPEG encoding.

    The input image is not modified; an RGB copy is encoded into an in-memory
    buffer and the bytes are base64-encoded into a UTF-8 ``str``.
    """
    with BytesIO() as buffer:
        image.convert("RGB").save(buffer, format="JPEG")
        jpeg_bytes = buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode("utf-8")
def request_aria_ui(image: Image.Image, prompt: str) -> str:
    """Ask the Aria-UI endpoint where the element described by *prompt* is in *image*.

    Sends a single user message made of a grounding question (text part) and the
    screenshot as a JPEG data URL (image part), then returns the text content of
    the first completion choice unchanged.
    """
    data_url = "data:image/jpeg;base64," + encode_pil_image_to_base64(image)
    question = (
        "<image>Given a GUI image, what are the relative (0-1000) pixel point "
        "coordinates for the element corresponding to the following instruction "
        "or description: " + prompt
    )
    user_content = [
        {"type": "text", "text": question},
        {"type": "image_url", "image_url": {"url": data_url}},
    ]
    completion = client.chat.completions.create(
        messages=[{"role": "user", "content": user_content}],
        model=model,
        max_tokens=512,
        stop=["<|im_end|>"],
        # Non-standard options are forwarded verbatim in the request body for the
        # serving backend to interpret.
        extra_body={"split_image": True, "image_max_size": 980, "temperature": 0, "top_k": 1},
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def _extract_coords_from_response(response: str) -> tuple[int, int]: | |
resp = response.replace("```", "").strip() | |
numbers = re.findall(r'\d+', resp) | |
if len(numbers) != 2: | |
raise ValueError(f"Expected exactly 2 coordinates, found {len(numbers)} numbers in response: {response}") | |
return int(numbers[0]), int(numbers[1]) | |
def _load_click_indicator(long_side: int, path: str = "assets/click.png") -> Image.Image:
    """Load the click cursor as RGBA, resized to ~3% of *long_side* wide (min 1 px).

    The file is opened in a ``with`` block so its handle is released, and the
    image is converted to RGBA so it can always serve as its own paste mask.
    """
    with Image.open(path) as raw:
        indicator = raw.convert("RGBA")
    target_width = max(1, int(long_side * 0.03))
    aspect_ratio = indicator.width / indicator.height
    target_height = max(1, int(target_width / aspect_ratio))
    return indicator.resize((target_width, target_height))


def image_grounding(image: Image.Image, prompt: str) -> Image.Image:
    """Ground *prompt* in *image* and return a copy with the predicted point marked.

    The model's relative (0-1000) point is scaled to pixel coordinates, a red box
    is drawn around the cursor's footprint, and the click cursor image is pasted
    there. The input image is not modified.

    Raises:
        ValueError: wrapping any failure (API request, response parsing, asset
            loading, drawing); the original exception is chained as ``__cause__``.
    """
    try:
        response = request_aria_ui(image, prompt)
        norm_x, norm_y = _extract_coords_from_response(response)

        # Relative 0-1000 coordinates -> absolute pixels, per axis.
        width, height = image.size
        abs_x = int(norm_x * width / 1000)
        abs_y = int(norm_y * height / 1000)

        # Indicator size scales with the image's longer side.
        click_image = _load_click_indicator(max(width, height))

        # Offset the cursor so the point 30% in from its top-left corner lands on
        # the predicted location.
        click_x = abs_x - int(click_image.width * 0.3)
        click_y = abs_y - int(click_image.height * 0.3)

        output_image = image.copy()
        draw = ImageDraw.Draw(output_image)
        bbox = [
            click_x,                        # left
            click_y,                        # top
            click_x + click_image.width,    # right
            click_y + click_image.height,   # bottom
        ]
        # Clamp so a tiny indicator still gets a visible outline.
        draw.rectangle(bbox, outline="red", width=max(1, int(click_image.width * 0.1)))
        output_image.paste(click_image, (click_x, click_y), click_image)
        return output_image
    except Exception as e:
        raise ValueError(f"An error occurred: {e}") from e
def resize_image_with_max_size(image: Image.Image, max_size: int = 1920) -> Image.Image:
    """Resize image to have a maximum dimension of max_size while maintaining aspect ratio.

    Images whose sides are both <= *max_size* are returned unchanged (the same
    object). Otherwise the longer side becomes *max_size* and the shorter side is
    scaled proportionally with truncation, clamped to at least 1 px so extreme
    aspect ratios never request an invalid zero-sized image. Resampling uses LANCZOS.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image
    if width > height:
        new_width = max_size
        new_height = max(1, int(height * (max_size / width)))
    else:
        new_height = max_size
        new_width = max(1, int(width * (max_size / height)))
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Gradio app
def gradio_interface(input_image, prompt):
    """Gradio callback: validate inputs, downscale the screenshot, and ground *prompt*.

    Args:
        input_image: the uploaded PIL image, or ``None`` when nothing was uploaded.
        prompt: the GUI instruction typed by the user.

    Returns:
        The annotated image produced by ``image_grounding``.

    Raises:
        gr.Error: when the image or instruction is missing, or grounding fails,
            so the message is shown in the UI rather than a generic error.
    """
    if input_image is None:
        raise gr.Error("Please upload a GUI image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a GUI instruction.")
    # Cap the longest side (resize_image_with_max_size defaults to 1920 px).
    input_image = resize_image_with_max_size(input_image)
    try:
        return image_grounding(input_image, prompt)
    except ValueError as e:
        raise gr.Error(str(e)) from e
with gr.Blocks() as demo:
    # Centered project logo, loaded from the Aria-UI GitHub repository.
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <div style="display: flex; justify-content: center;">
                <img src="https://raw.githubusercontent.com/AriaUI/Aria-UI/refs/heads/main/assets/logo_long.png" alt="Aria-UI" style="height: 76px; margin-bottom: 10px;"/>
            </div>
        </div>
        """
    )
    # NOTE(review): the emoji in the two Markdown strings below were reconstructed
    # from mis-encoded text (UTF-8 read as ISO-8859-7); 🤗, • and 🎯 decode
    # unambiguously, 🌐 / 📝 / 👉 are best guesses — confirm against upstream.
    gr.Markdown("""| [🤗 Aria-UI Models](https://huggingface.co/Aria-UI/Aria-UI-base) • [🤗 Aria-UI Dataset](https://huggingface.co/datasets/Aria-UI/Aria-UI_Data) • [🌐 Project Page](https://ariaui.github.io) • [📝 Paper](https://arxiv.org/abs/2412.16256) |
|:---------------------------------------------------------------------------------------------------------:|""")
    gr.Markdown("# Aria-UI: Visual Grounding for GUI Instructions")
    gr.Markdown("👉👉 Upload a GUI image and enter an instruction. Aria-UI will try its best to ground the instruction to a specific element in the image. 🎯🎯")
    # Three columns: inputs (narrow), result (wide), examples (narrow).
    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload GUI Image", height=600)
            prompt_input = gr.Textbox(label="Enter GUI Instruction")
            submit_button = gr.Button("Process")
        with gr.Column(scale=3):
            output_image = gr.Image(label="Grounding Result", height=500)
        with gr.Column(scale=2):
            gr.Examples(
                examples=[
                    [
                        example["image"],
                        example["prompt"]
                    ]
                    for example in examples
                ],
                inputs=[image_input, prompt_input],
                outputs=[output_image],
                fn=gradio_interface,
                cache_examples=False,  # run on click rather than precomputing at startup
                label="Example Tasks",
                examples_per_page=5,
            )
    submit_button.click(
        fn=gradio_interface,
        inputs=[image_input, prompt_input],
        outputs=[output_image]
    )
# Start the server only when run as a script (e.g. `python app.py` on Spaces),
# so importing this module does not block on a running server.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (needed inside a container)
        server_port=7860,
        ssr_mode=False,
        debug=True,
    )