|
from PIL import Image, ImageDraw, ImageFont |
|
import os |
|
import base64 |
|
import json |
|
import requests |
|
from io import BytesIO |
|
import threading |
|
from datetime import datetime |
|
import paho.mqtt.client as mqtt |
|
import gradio as gr |
|
from api import predict_image |
|
|
|
|
|
# Most recently received/uploaded raw frame; overwritten for every new image.
IMAGE_PATH = "received_image.jpg"

# Directory holding one JPEG per received/uploaded image.
IMAGE_HISTORY_DIR = "image_history"

# Upper bound on the number of entries kept in image_history.
MAX_HISTORY_SIZE = 100

# MQTT connection settings. The literals are only defaults: set the MQTT_*
# environment variables to avoid depending on credentials committed to source.
MQTT_CONFIG = {
    "broker": os.environ.get("MQTT_BROKER", "47.254.33.128"),
    "port": int(os.environ.get("MQTT_PORT", "1883")),
    "topic": os.environ.get("MQTT_TOPIC", "x1/bugs"),
    "username": os.environ.get("MQTT_USERNAME", "my"),
    "password": os.environ.get("MQTT_PASSWORD", "my123456"),
}

# Shared mutable state, read and written by the MQTT callbacks, the worker
# threads started per message, and the Gradio event handlers.
mqtt_client = None
latest_image_info = {"path": None, "date": None, "objnum": None}
image_history = []  # list of (history file path, timestamp label)
mqtt_status = "<span style='color: red;'>MQTT Disconnected</span>"  # HTML badge
current_prompt = "all"
current_task = "<OD>"

# Maps the task names shown in the UI dropdown to model task tokens.
task_name = {
    "detect all objects": "<OD>",
    "detect by vocabulary": "<OPEN_VOCABULARY_DETECTION>",
    "detect by phrase": "<CAPTION_TO_PHRASE_GROUNDING>",
}

os.makedirs(IMAGE_HISTORY_DIR, exist_ok=True)
|
|
|
|
|
def on_connect(client, userdata, flags, rc):
    """paho-mqtt connect callback: subscribe on success and update the status badge.

    ``rc`` is the broker's connect result code; 0 means the connection was
    accepted. Any other value leaves the badge in the red "Disconnected" state.
    """
    global mqtt_status
    if rc != 0:
        mqtt_status = "<span style='color: red;'>MQTT Disconnected</span>"
        return
    client.subscribe(MQTT_CONFIG["topic"])
    mqtt_status = "<span style='color: green;'>MQTT Connected</span>"
|
|
|
def on_disconnect(client, userdata, rc):
    """paho-mqtt disconnect callback: flip the status badge back to red.

    The callback arguments are unused; any disconnect, clean or not, is
    reported the same way.
    """
    global mqtt_status
    disconnected_badge = "<span style='color: red;'>MQTT Disconnected</span>"
    mqtt_status = disconnected_badge
|
|
|
def on_message(client, userdata, msg):
    """paho-mqtt message callback: hand the message to a fresh worker thread.

    Processing (decode + model inference) happens in ``handle_message`` on a
    separate thread, so the MQTT network loop is not blocked while it runs.
    """
    worker = threading.Thread(target=handle_message, args=(msg,))
    worker.start()
|
|
|
def handle_message(msg):
    """Decode one MQTT image message, run detection on it and record the result.

    The JSON payload must contain ``values.image`` (base64 image data, either
    bare or prefixed as a data URL ``"data:...;base64,"``) and
    ``values.localtime`` (used to name the history file and as its label).
    This runs on a background worker thread with no caller to report to, so
    any error is logged (with traceback) and swallowed.
    """
    import traceback

    try:
        print("Received message")
        data = json.loads(msg.payload)
        values = data["values"]
        # Accept both "data:image/...;base64,<b64>" and a bare base64 string;
        # base64 never contains ',', so the last comma-separated piece is the data.
        image_data = values["image"].split(",", 1)[-1]
        localtime = values["localtime"]

        image = Image.open(BytesIO(base64.b64decode(image_data)))
        # JPEG cannot store RGBA / LA / palette ("P") images, and the coloured
        # annotations need RGB channels, so normalise everything to RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(IMAGE_PATH)

        # localtime comes from the untrusted payload: drop any directory
        # components so the file cannot land outside IMAGE_HISTORY_DIR, and fall
        # back to the local clock if nothing usable is left.
        stamp = localtime.replace(' ', '_').replace(':', '-')
        file_stem = os.path.basename(stamp.replace("\\", "/")) or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        image_history_path = os.path.join(IMAGE_HISTORY_DIR, f"{file_stem}.jpg")
        image.save(image_history_path)

        prediction = predict_image_json(image, current_task, current_prompt)
        # annotate_image draws onto `image` in place, after both saves above.
        annotated_image_path = annotate_image(image, prediction, current_task)
        detected_objects = predicted_objects_num(prediction, current_task)

        latest_image_info.update({
            "path": annotated_image_path,
            "date": localtime,
            "objnum": detected_objects,
        })

        image_history.append((image_history_path, localtime))
        manage_history_size()
    except Exception as e:
        # Top-level boundary of a worker thread: report and drop this message.
        print(f"Error processing message: {e}")
        traceback.print_exc()
|
|
|
def convert_to_od_format(data):
    """Reshape an open-vocabulary detection result into the ``<OD>`` layout.

    Copies ``data['bboxes']`` to ``'bboxes'`` and ``data['bboxes_labels']`` to
    ``'labels'``; a missing key becomes an empty list.
    """
    return {
        'bboxes': data.get('bboxes', []),
        'labels': data.get('bboxes_labels', []),
    }
|
|
|
def predict_image_json(image, task, prompt):
    """Run ``predict_image`` and normalise its result to the ``<OD>`` layout.

    Args:
        image: PIL image to analyse.
        task: model task token (a value of ``task_name``).
        prompt: text prompt; forced to ``""`` for ``<OD>``, which takes none.

    Returns:
        The prediction dict from ``predict_image``. For
        ``<OPEN_VOCABULARY_DETECTION>`` the ``prediction[task]`` entry is
        rewritten in place to ``{'bboxes': ..., 'labels': ...}``.
    """
    if task == "<OD>":
        prompt = ""
    prediction = predict_image(image, task, prompt)
    if task == "<OPEN_VOCABULARY_DETECTION>":
        # This task reports labels under 'bboxes_labels'; rename so callers
        # can always read prediction[task]["labels"].
        prediction[task] = convert_to_od_format(prediction[task])
    return prediction
|
|
|
def annotate_image(image, prediction, task):
    """Draw the predicted boxes and labels onto ``image`` and save a copy.

    ``image`` is modified in place. Boxes are yellow outlines; each label is
    white text on a black background at the box's top-left corner. Line width
    and font size scale with the image's longer side (reference: 1000 px).

    Args:
        image: PIL image to draw on.
        prediction: dict whose ``prediction[task]`` holds parallel
            ``"bboxes"`` (``[x1, y1, x2, y2]``) and ``"labels"`` lists.
        task: key into ``prediction``.

    Returns:
        Path of the saved annotated JPEG (``IMAGE_PATH`` with an
        ``_annotated`` suffix).
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size
    scale = max(width, height) / 1000
    # Clamp to >= 1: below ~334 px int(3 * scale) is 0, and Pillow skips the
    # rectangle outline entirely when width == 0; a 0-point font is invalid too.
    font_size = max(1, int(30 * scale))
    line_width = max(1, int(3 * scale))
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", font_size)
    except IOError:
        # Font file not available on this machine: use Pillow's built-in font.
        font = ImageFont.load_default()

    result = prediction[task]
    for bbox, label in zip(result["bboxes"], result["labels"]):
        x1, y1, x2, y2 = bbox
        draw.rectangle([x1, y1, x2, y2], outline="yellow", width=line_width)
        # Fill the label's own bounding box so the text stays readable.
        text_bbox = draw.textbbox((x1, y1), label, font=font)
        draw.rectangle(text_bbox, fill="black")
        draw.text((x1, y1), label, fill="white", font=font)

    annotated_image_path = IMAGE_PATH.replace(".jpg", "_annotated.jpg")
    image.save(annotated_image_path)
    return annotated_image_path
|
|
|
def predicted_objects_num(prediction, task):
    """Return how many boxes the prediction for ``task`` contains."""
    boxes = prediction[task]["bboxes"]
    return len(boxes)
|
|
|
def start_mqtt_client(broker, port, topic, username, password):
    """(Re)create the global MQTT client, connect it and start its network loop.

    Any previous client is disconnected and its background loop stopped first.
    The given settings are written back into ``MQTT_CONFIG``, because
    ``on_connect`` subscribes to ``MQTT_CONFIG["topic"]``. Without this, a
    topic entered in the UI was silently ignored.

    Raises whatever ``mqtt.Client.connect`` raises (e.g. ``OSError`` for an
    unreachable broker).
    """
    global mqtt_client
    if mqtt_client is not None:
        # Also join the old network thread so it cannot keep running, or fire
        # on_disconnect late and mark the *new* connection as disconnected.
        mqtt_client.disconnect()
        mqtt_client.loop_stop()

    MQTT_CONFIG.update({
        "broker": broker,
        "port": port,
        "topic": topic,
        "username": username,
        "password": password,
    })

    # paho-mqtt >= 2.0 requires an explicit callback API version; the callbacks
    # in this module use the 1.x signatures. Older paho has no such enum.
    callback_api = getattr(mqtt, "CallbackAPIVersion", None)
    if callback_api is not None:
        client = mqtt.Client(callback_api.VERSION1)
    else:
        client = mqtt.Client()

    client.username_pw_set(username, password)
    client.on_connect = on_connect
    client.on_disconnect = on_disconnect
    client.on_message = on_message
    mqtt_client = client
    client.connect(broker, port, 60)
    client.loop_start()
|
|
|
def display_image():
    """Return ``(annotated image path, detected object count)`` for the latest image."""
    print("Displaying latest image...")
    info = latest_image_info
    return info["path"], info["objnum"]
|
|
|
def display_image_history():
    """Return a fresh list of ``(file path, label)`` pairs for the history gallery."""
    return [(file_path, label) for file_path, label in image_history]
|
|
|
def show_prediction_on_history(evt: gr.SelectData):
    """Re-run detection on the history image clicked in the gallery.

    The selected image becomes the current raw frame (copied to ``IMAGE_PATH``)
    and is annotated with the current task and prompt. ``latest_image_info``'s
    path and object count are updated; its date is left as it was.

    Returns:
        ``(annotated image path, detected object count)``.
    """
    entry = image_history[int(evt.index)]
    picked = Image.open(entry[0])
    picked.save(IMAGE_PATH)

    result = predict_image_json(picked, current_task, current_prompt)
    annotated_path = annotate_image(picked, result, current_task)
    count = predicted_objects_num(result, current_task)

    latest_image_info.update(path=annotated_path, objnum=count)
    return annotated_path, count
|
|
|
def update_mqtt_config(broker, port, topic, username, password): |
|
start_mqtt_client(broker, int(port), topic, username, password) |
|
return f"Connected to {broker}:{port}, subscribed to topic '{topic}'" |
|
|
|
def auto_connect():
    """Connect using the settings currently stored in ``MQTT_CONFIG``."""
    setting_keys = ("broker", "port", "topic", "username", "password")
    update_mqtt_config(*(MQTT_CONFIG[key] for key in setting_keys))
|
|
|
def _history_label(filename):
    """Rebuild the timestamp label a history file was saved under.

    History files are named ``<localtime>.jpg`` with ' ' -> '_' and ':' -> '-'.
    Only the time part (after the first '_') gets '-' turned back into ':', so
    '2024-01-01_12-30-00.jpg' -> '2024-01-01 12:30:00'. Names without '_' are
    returned without their extension.
    """
    stem = os.path.splitext(filename)[0]
    date_part, sep, time_part = stem.partition("_")
    if not sep:
        return stem
    return f"{date_part} {time_part.replace('_', ' ').replace('-', ':')}"


def history_image_load():
    """Rebuild ``image_history`` from the JPEGs in ``IMAGE_HISTORY_DIR``.

    Entries are sorted by their reconstructed timestamp label, then the
    history is trimmed to its size limit via ``manage_history_size``.
    """
    global image_history
    image_history = sorted(
        (
            (os.path.join(IMAGE_HISTORY_DIR, filename), _history_label(filename))
            for filename in os.listdir(IMAGE_HISTORY_DIR)
            if filename.endswith(".jpg")
        ),
        key=lambda entry: entry[1],
    )
    manage_history_size()
|
|
|
def get_mqtt_status():
    """Return the current MQTT connection status badge (an HTML ``<span>``)."""
    badge = mqtt_status
    return badge
|
|
|
def upload_image(filepath):
    """Handle a manual upload: store it, run detection, and record it in history.

    The upload is timestamped with the local clock (``%Y-%m-%d %H:%M:%S``),
    saved as the current raw frame and as a history JPEG, then annotated with
    the current task and prompt.

    Returns:
        ``(annotated image path, detected object count, history gallery list)``.
    """
    # Normalise to RGB: JPEG cannot store palette ("P"), "LA" or "RGBA" images,
    # so converting only RGBA crashed on e.g. palette PNGs and GIFs. The
    # context manager closes the uploaded file once the converted copy exists.
    with Image.open(filepath) as source:
        image = source.convert("RGB")
    image.save(IMAGE_PATH)

    localtime = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    image_history_path = os.path.join(IMAGE_HISTORY_DIR, f"{localtime.replace(' ', '_').replace(':', '-')}.jpg")
    image.save(image_history_path)

    prediction = predict_image_json(image, current_task, current_prompt)
    annotated_image_path = annotate_image(image, prediction, current_task)
    predicted_objects = predicted_objects_num(prediction, current_task)

    latest_image_info.update({
        "path": annotated_image_path,
        "date": localtime,
        "objnum": predicted_objects,
    })
    image_history.append((image_history_path, localtime))
    manage_history_size()
    return annotated_image_path, predicted_objects, display_image_history()
|
|
|
def manage_history_size():
    """Trim ``image_history`` to ``MAX_HISTORY_SIZE`` entries, oldest first.

    Each dropped entry's file is deleted; a file that is already gone (e.g.
    two entries sharing one timestamped name) is ignored. The previous version
    removed exactly two entries whenever the limit was exceeded. That left the
    history unbounded after loading many files, and it could pop from a list
    that was too short.
    """
    while len(image_history) > MAX_HISTORY_SIZE:
        oldest_path = image_history.pop(0)[0]
        try:
            os.remove(oldest_path)
        except FileNotFoundError:
            pass  # already deleted: nothing left to clean up
|
|
|
def commit_prompt(prompt):
    """Set the detection prompt and re-run detection on the current raw frame.

    An empty prompt means ``"all"``. If no image has been received or uploaded
    yet (``IMAGE_PATH`` does not exist), the prompt is still stored and the
    current display values are returned unchanged instead of raising.

    Returns:
        ``(annotated image path, detected object count)``.
    """
    global current_prompt
    print(f"Updating prompt to: {prompt}")
    if not prompt:
        prompt = "all"
    current_prompt = prompt

    try:
        image = Image.open(IMAGE_PATH)
    except FileNotFoundError:
        return latest_image_info["path"], latest_image_info["objnum"]

    prediction = predict_image_json(image, current_task, current_prompt)
    annotated_image_path = annotate_image(image, prediction, current_task)
    predicted_objects = predicted_objects_num(prediction, current_task)
    latest_image_info["path"] = annotated_image_path
    latest_image_info["objnum"] = predicted_objects
    return annotated_image_path, predicted_objects
|
|
|
def update_task(task, prompt):
    """Switch the detection task selected in the UI dropdown.

    Updates the global ``current_task`` and ``current_prompt``. Previously
    ``current_prompt`` was assigned without ``global``, so the prompt change
    was silently lost. ``<OD>`` takes no prompt. For other tasks, an empty
    prompt means ``"all"``, the same default ``commit_prompt`` applies.

    Returns:
        A Gradio update that shows the prompt textbox only for tasks that
        use a prompt.
    """
    global current_task, current_prompt
    task = task_name[task]
    current_task = task
    if task == "<OD>":
        current_prompt = ""
    else:
        current_prompt = prompt or "all"
    print(f"Updating task to: {task}, prompt to: {current_prompt}")
    return gr.update(visible=task != "<OD>")
|
|
|
# Build the Gradio UI. Everything below runs at import time: it builds the
# layout, registers the event handlers, loads the on-disk history, connects to
# MQTT and finally launches the web server.
with gr.Blocks(css="footer {visibility: hidden}") as iface:
    gr.Markdown("## MS Computer Vision Demo")
    # Status badge; its value is the HTML string held in the global mqtt_status.
    mqtt_status_output = gr.HTML(value=mqtt_status)

    # Collapsible MQTT connection form, pre-filled from MQTT_CONFIG.
    with gr.Accordion("MQTT Settings", open=False):
        with gr.Row():
            broker_input = gr.Textbox(label="MQTT Broker", value=MQTT_CONFIG["broker"])
            port_input = gr.Textbox(label="MQTT Port", value=str(MQTT_CONFIG["port"]))
            topic_input = gr.Textbox(label="MQTT Topic", value=MQTT_CONFIG["topic"])
        with gr.Row():
            username_input = gr.Textbox(label="MQTT Username", value=MQTT_CONFIG["username"])
            # NOTE(review): the textbox is pre-filled with the real password, so
            # its value is presumably sent to every browser that opens the page
            # (and the page is public because of share=True below) — confirm.
            password_input = gr.Textbox(label="MQTT Password", type="password", value=MQTT_CONFIG["password"])
        connect_button = gr.Button("Connect")
        # outputs=[]: the message string returned by update_mqtt_config is discarded.
        connect_button.click(
            fn=update_mqtt_config,
            inputs=[broker_input, port_input, topic_input, username_input, password_input],
            outputs=[]
        )

    with gr.Row():
        # Left: latest annotated image, object count, task/prompt controls.
        with gr.Column(scale=2):
            image_output = gr.Image(label="Latest Image")
            detected_objects_output = gr.Textbox(label="Detected Objects Count", placeholder="No objects detected", interactive=False)
            task_input = gr.Dropdown(
                label="Task",
                choices=list(task_name.keys()),
                value="detect all objects"
            )
            # Hidden initially; update_task shows it for tasks other than <OD>.
            prompt_input = gr.Textbox(label="Prompt(Optional)", placeholder="what is object want to detect?", visible=False)
            task_input.change(fn=update_task, inputs=[task_input, prompt_input], outputs=[prompt_input])
            # Commit re-runs detection on IMAGE_PATH with the typed prompt.
            commit_button = gr.Button("Commit")
            commit_button.click(fn=commit_prompt, inputs=[prompt_input], outputs=[image_output, detected_objects_output])
        # Right: history gallery and manual upload.
        with gr.Column(scale=1):
            history_output = gr.Gallery(label="History Image", columns=3)
            upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
            upload_button.upload(fn=upload_image, inputs=upload_button, outputs=[image_output, detected_objects_output, history_output])

    # Thin wrappers used as polling / change handlers below.
    def refresh_interface():
        return display_image()

    def refresh_history():
        return display_image_history()

    # Whenever the gallery value changes, re-show the latest image and count.
    # This presumably also fires after the 0.5 s history refresh below, which
    # is how images received over MQTT reach image_output — verify.
    history_output.change(fn=refresh_interface, outputs=[image_output, detected_objects_output])

    # Populate image_history from disk before the first gallery refresh.
    history_image_load()
    # Poll the in-memory history every 0.5 s so MQTT-received images show up.
    iface.load(fn=refresh_history, inputs=[], outputs=history_output, every=0.5)

    # Connect to the broker configured in MQTT_CONFIG at startup.
    auto_connect()
    # NOTE(review): no `every=` here, so the status badge is read once per page
    # load and is presumably not refreshed afterwards — confirm intended.
    iface.load(fn=get_mqtt_status, inputs=[], outputs=mqtt_status_output)
    # Clicking a gallery thumbnail re-runs detection on that history image.
    history_output.select(fn=show_prediction_on_history, outputs=[image_output, detected_objects_output])

# share=True asks Gradio to create a public share link in addition to the local server.
iface.launch(share=True)
|
|