# Aria-UI / app.py
# Aria-UI's picture
# Update app.py
# 43661ec verified
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import base64
from io import BytesIO
import re
import os
# Sample (screenshot path, instruction) pairs offered in the demo's example panel.
# Each entry is a dict with an "image" path and a "prompt" instruction.
examples = [
    dict(image=screenshot_path, prompt=instruction)
    for screenshot_path, instruction in (
        ("./assets/example_desktop.png", "switch off the wired connection"),
        ("./assets/example_web.png", "view all branches"),
        ("./assets/example_mobile.jpg", "share the screenshot"),
    )
]
# Client for the OpenAI-compatible endpoint configured through the
# aria_ui_api_key / aria_ui_api_base environment variables.
# Both are required: os.environ[...] raises KeyError when a variable is unset.
openai_api_key = os.environ["aria_ui_api_key"]
openai_api_base = os.environ["aria_ui_api_base"]
from openai import OpenAI # Requires the third-party `openai` client library to be installed
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
# Every request uses the first model the endpoint lists.
# NOTE(review): presumably this contacts the server at import time and assumes
# at least one model is listed (data[0] raises IndexError otherwise) - confirm.
models = client.models.list()
model = models.data[0].id
def encode_pil_image_to_base64(image: Image.Image) -> str:
    """Return ``image`` encoded as a base64 string of its JPEG bytes.

    The image is converted to RGB before saving, so the result is always a
    three-channel JPEG regardless of the input mode.

    Args:
        image: Image to serialize; the caller's object is not modified.

    Returns:
        The base64 text (UTF-8 decoded) of the JPEG-encoded image.
    """
    rgb_image = image.convert("RGB")
    with BytesIO() as jpeg_buffer:
        rgb_image.save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode("utf-8")
def request_aria_ui(image: Image.Image, prompt: str) -> str:
    """Ask the configured model where ``prompt`` points to inside ``image``.

    The image is sent inline as a base64 JPEG data URL together with a text
    part that asks for relative (0-1000) point coordinates.

    Args:
        image: GUI screenshot to query.
        prompt: Instruction or element description appended to the question.

    Returns:
        The message content of the first completion choice, unparsed.
    """
    encoded_image = encode_pil_image_to_base64(image)

    question_part = {
        "type": "text",
        "text": "<image>Given a GUI image, what are the relative (0-1000) pixel point coordinates for the element corresponding to the following instruction or description: " + prompt,
    }
    screenshot_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
    }
    # Extra fields forwarded as-is in the request body.
    # NOTE(review): presumably server-specific image options plus greedy
    # decoding settings - confirm against the serving backend.
    backend_options = {"split_image": True, "image_max_size": 980, "temperature": 0, "top_k": 1}

    completion = client.chat.completions.create(
        messages=[{"role": "user", "content": [question_part, screenshot_part]}],
        model=model,
        max_tokens=512,
        stop=["<|im_end|>"],
        extra_body=backend_options,
    )
    return completion.choices[0].message.content
def _extract_coords_from_response(response: str) -> tuple[int, int]:
    """Parse the two integer coordinates contained in a model reply.

    Triple-backtick fences are removed and every run of digits in what
    remains is collected; the reply must contain exactly two such runs.

    Args:
        response: Raw text returned by the model.

    Returns:
        The two numbers, in order of appearance, as ``(x, y)``.

    Raises:
        ValueError: If the reply does not contain exactly two digit runs.
    """
    cleaned = response.replace("```", "").strip()
    digit_runs = [found.group() for found in re.finditer(r'\d+', cleaned)]
    if len(digit_runs) == 2:
        first, second = digit_runs
        return int(first), int(second)
    raise ValueError(f"Expected exactly 2 coordinates, found {len(digit_runs)} numbers in response: {response}")
def image_grounding(image: Image.Image, prompt: str) -> Image.Image:
    """Ground ``prompt`` in ``image`` and mark the predicted point on a copy.

    The model reply is parsed into a pair of relative (0-1000) coordinates,
    scaled to pixel coordinates of ``image``, and a click indicator
    (``assets/click.png``) framed by a red rectangle is drawn there.

    Args:
        image: GUI screenshot; it is copied, never modified.
        prompt: Instruction or element description to locate.

    Returns:
        A copy of ``image`` with the click indicator and its frame drawn on it.

    Raises:
        ValueError: If any step fails (request, parsing, or drawing). The
            original exception is chained as ``__cause__``.
    """
    try:
        # Ask the model and parse its relative (0-1000) coordinates.
        response = request_aria_ui(image, prompt)
        norm_x, norm_y = _extract_coords_from_response(response)

        # Convert relative coordinates to pixel coordinates.
        width, height = image.size
        abs_x = int(norm_x * width / 1000)
        abs_y = int(norm_y * height / 1000)

        # Load the indicator and release the file handle right away. Converting
        # to RGBA guarantees an alpha band, which paste() needs because the
        # indicator is also used as its own mask below.
        with Image.open("assets/click.png") as click_file:
            click_image = click_file.convert("RGBA")

        # Indicator width is 3% of the longer image side, keeping its aspect
        # ratio. Clamp to 1 px so very small screenshots do not produce a
        # zero-sized resize target.
        long_side = max(width, height)
        target_width = max(1, int(long_side * 0.03))
        aspect_ratio = click_image.width / click_image.height
        target_height = max(1, int(target_width / aspect_ratio))
        click_image = click_image.resize((target_width, target_height))

        # Place the indicator so that the point 30% in from its top-left
        # corner sits on the predicted coordinates.
        click_x = abs_x - int(click_image.width * 0.3)
        click_y = abs_y - int(click_image.height * 0.3)

        output_image = image.copy()
        draw = ImageDraw.Draw(output_image)
        bbox = [
            click_x,                       # left
            click_y,                       # top
            click_x + click_image.width,   # right
            click_y + click_image.height,  # bottom
        ]
        # Outline thickness is 10% of the indicator width, at least 1 px so
        # the frame is still drawn for small indicators.
        outline_width = max(1, int(click_image.width * 0.1))
        draw.rectangle(bbox, outline='red', width=outline_width)
        output_image.paste(click_image, (click_x, click_y), click_image)
        return output_image
    except Exception as e:
        # Keep the single ValueError contract but preserve the root cause.
        raise ValueError(f"An error occurred: {e}") from e
def resize_image_with_max_size(image: Image.Image, max_size: int = 1920) -> Image.Image:
    """Resize image to have a maximum dimension of max_size while maintaining aspect ratio.

    Args:
        image: Image to constrain.
        max_size: Largest allowed width or height, in pixels.

    Returns:
        ``image`` itself when both sides already fit within ``max_size``;
        otherwise a new LANCZOS-resampled image whose longer side equals
        ``max_size``.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image
    # The shorter side is clamped to 1 px: for extreme aspect ratios the
    # scaled value truncates to 0, which is not a valid resize target.
    if width > height:
        new_width = max_size
        new_height = max(1, int(height * (max_size / width)))
    else:
        new_height = max_size
        new_width = max(1, int(width * (max_size / height)))
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Gradio app
def gradio_interface(input_image, prompt):
    """Gradio callback: downscale the uploaded image and run grounding on it.

    Args:
        input_image: The uploaded PIL image, or ``None`` when nothing was
            uploaded.
        prompt: GUI instruction typed by the user.

    Returns:
        The (possibly downscaled) image annotated with the grounding result.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an attribute error on ``None``.
    """
    if input_image is None:
        raise gr.Error("Please upload a GUI image before processing.")
    # Log the size before and after downscaling.
    print(input_image.size)
    input_image = resize_image_with_max_size(input_image)
    print(input_image.size)
    output_image = image_grounding(input_image, prompt)
    return output_image
# Gradio UI: header, then a three-column row (inputs | result | examples).
with gr.Blocks() as demo:
    # Centered project logo.
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <div style="display: flex; justify-content: center;">
                <img src="https://raw.githubusercontent.com/AriaUI/Aria-UI/refs/heads/main/assets/logo_long.png" alt="Aria-UI" style="height: 76px; margin-bottom: 10px;"/>
            </div>
        </div>
        """
    )
    # Project links, laid out as a one-row centered Markdown table.
    gr.Markdown(
        "| [🤗 Aria-UI Models](https://huggingface.co/Aria-UI/Aria-UI-base)"
        " • [🤗 Aria-UI Dataset](https://huggingface.co/datasets/Aria-UI/Aria-UI_Data)"
        " • [🌐 Project Page](https://ariaui.github.io)"
        " • [📝 Paper](https://arxiv.org/abs/2412.16256) |\n"
        "|:---------------------------------------------------------------------------------------------------------:|"
    )
    gr.Markdown("# Aria-UI: Visual Grounding for GUI Instructions")
    gr.Markdown("🚀🚀 Upload a GUI image and enter an instruction. Aria-UI will try its best to ground the instruction to a specific element in the image. 🎯🎯")

    with gr.Row():
        # Inputs column (narrower than the result column).
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload GUI Image", height=600)
            prompt_input = gr.Textbox(label="Enter GUI Instruction")
            submit_button = gr.Button("Process")

        # Result column (widest).
        with gr.Column(scale=3):
            output_image = gr.Image(label="Grounding Result", height=500)

        # Examples column, listed vertically.
        with gr.Column(scale=2):
            gr.Examples(
                examples=[
                    [example["image"], example["prompt"]]
                    for example in examples
                ],
                inputs=[image_input, prompt_input],
                outputs=[output_image],
                fn=gradio_interface,
                cache_examples=False,
                label="Example Tasks",
                examples_per_page=5,
            )

    # Event listeners must be registered inside the Blocks context.
    submit_button.click(
        fn=gradio_interface,
        inputs=[image_input, prompt_input],
        outputs=[output_image],
    )
# Start the web server for the demo.
# NOTE(review): "0.0.0.0" presumably exposes the app on all network
# interfaces, as needed inside a container - confirm for the deployment.
launch_options = {
    "server_name": "0.0.0.0",
    "server_port": 7860,
    "ssr_mode": False,
    "debug": True,
}
demo.launch(**launch_options)