"""Gradio demo: run object detection on images received over MQTT or uploaded by hand.

Incoming images are saved to disk, passed to ``api.predict_image``, annotated
with the returned bounding boxes and shown in a Gradio interface together with
a gallery of previously received images.
"""

import base64
import json
import os
import re
import threading
from datetime import datetime
from io import BytesIO

import gradio as gr
import paho.mqtt.client as mqtt
import requests
from PIL import Image, ImageDraw, ImageFont

from api import predict_image

# Constants and configuration
IMAGE_PATH = "received_image.jpg"
IMAGE_HISTORY_DIR = "image_history"
MAX_HISTORY_SIZE = 100

# Defaults can be overridden through the environment so that broker
# credentials do not have to live in the source file.
MQTT_CONFIG = {
    "broker": os.environ.get("MQTT_BROKER", "47.254.33.128"),
    "port": int(os.environ.get("MQTT_PORT", "1883")),
    "topic": os.environ.get("MQTT_TOPIC", "x1/bugs"),
    "username": os.environ.get("MQTT_USERNAME", "my"),
    "password": os.environ.get("MQTT_PASSWORD", "my123456"),
}

# Global variables
mqtt_client = None
latest_image_info = {"path": None, "date": None, "objnum": None}
image_history = []  # list of (file path, display label) tuples, oldest first
mqtt_status = "MQTT Disconnected"
current_prompt = "all"
current_task = ""
# NOTE(review): every task identifier below is an empty string, which makes the
# three tasks indistinguishable and keeps the prompt box permanently hidden.
# This looks like markup-style tokens that were lost somewhere - confirm the
# intended values against the `api.predict_image` contract.
task_name = {
    "detect all objects": "",
    "detect by vocabulary": "",
    "detect by phrase": "",
}

# Create directories if not exist
os.makedirs(IMAGE_HISTORY_DIR, exist_ok=True)


# MQTT callback functions
def on_connect(client, userdata, flags, rc):
    """Subscribe to the configured topic once the broker accepts the connection.

    The topic is taken from the client's user data when ``start_mqtt_client``
    stored one there, otherwise from ``MQTT_CONFIG``.
    """
    global mqtt_status
    if rc == 0:
        if isinstance(userdata, dict) and "topic" in userdata:
            topic = userdata["topic"]
        else:
            topic = MQTT_CONFIG["topic"]
        client.subscribe(topic)
        mqtt_status = "MQTT Connected"
    else:
        mqtt_status = "MQTT Disconnected"


def on_disconnect(client, userdata, rc):
    """Record that the broker connection was lost."""
    global mqtt_status
    mqtt_status = "MQTT Disconnected"


def on_message(client, userdata, msg):
    """Hand the message to a worker thread so the MQTT network loop is not blocked."""
    threading.Thread(target=handle_message, args=(msg,)).start()


def _load_rgb(source):
    """Open an image from a path or file object and return a loaded RGB copy.

    Converting every mode (not only RGBA) keeps the later JPEG saves from
    failing on palette or alpha images, and the ``with`` block releases the
    underlying file as soon as the pixels are copied.
    """
    with Image.open(source) as opened:
        return opened.convert("RGB")


def _history_path(localtime):
    """Build the history file path for a timestamp string.

    Spaces and colons are mapped as before; any other character that is not a
    word character, dot or dash is replaced with ``_`` so that a timestamp
    taken from an MQTT payload cannot point outside ``IMAGE_HISTORY_DIR``.
    """
    stem = str(localtime).replace(" ", "_").replace(":", "-")
    stem = re.sub(r"[^\w.\-]", "_", stem)
    return os.path.join(IMAGE_HISTORY_DIR, f"{stem}.jpg")


def _save_incoming(image, localtime):
    """Save *image* as the current image and as a history file; return the history path."""
    image.save(IMAGE_PATH)
    history_path = _history_path(localtime)
    image.save(history_path)
    return history_path


def _detect_and_record(image, date=None):
    """Run the current task on *image*, annotate it and update ``latest_image_info``.

    Args:
        image: PIL image to analyse; it is drawn on in place.
        date: Optional timestamp to store alongside the result. When omitted
            the previously stored date is left untouched.

    Returns:
        Tuple of (annotated image path, number of detected objects).
    """
    prediction = predict_image_json(image, current_task, current_prompt)
    annotated_path = annotate_image(image, prediction, current_task)
    count = predicted_objects_num(prediction, current_task)
    latest_image_info["path"] = annotated_path
    latest_image_info["objnum"] = count
    if date is not None:
        latest_image_info["date"] = date
    return annotated_path, count


def handle_message(msg):
    """Decode an MQTT image message, run detection on it and add it to the history.

    The payload is parsed as JSON and must provide ``values.image`` (base64,
    optionally prefixed with a data-URL header) and ``values.localtime``.
    Errors are logged rather than raised because this runs in a worker thread.
    """
    try:
        print("Received message")
        values = json.loads(msg.payload)["values"]
        # Accept both "data:...;base64,<data>" and bare base64 payloads.
        encoded = values["image"].split(",", 1)[-1]
        localtime = values["localtime"]
        image = _load_rgb(BytesIO(base64.b64decode(encoded)))
        history_path = _save_incoming(image, localtime)
        _detect_and_record(image, date=localtime)
        image_history.append((history_path, localtime))
        manage_history_size()
    except Exception as e:  # thread boundary: log and keep the client alive
        print(f"Error processing message: {e}")


def convert_to_od_format(data):
    """Return a ``{'bboxes', 'labels'}`` dict built from a detection result.

    Labels are read from ``'bboxes_labels'``; when that key is absent the
    existing ``'labels'`` entry is kept so an already-normalised result passes
    through without losing its labels.
    """
    return {
        "bboxes": data.get("bboxes", []),
        "labels": data.get("bboxes_labels", data.get("labels", [])),
    }


def predict_image_json(image, task, prompt):
    """Call ``predict_image`` and normalise the result for *task*.

    Returns:
        The prediction mapping returned by ``predict_image``, with the entry
        for *task* rewritten by ``convert_to_od_format`` when the task
        identifier is empty.
    """
    if task == "":
        prompt = ""
    prediction = predict_image(image, task, prompt)
    if task == "":
        prediction[task] = convert_to_od_format(prediction[task])
    return prediction


def annotate_image(image, prediction, task):
    """Draw the predicted boxes and labels on *image* and save the result.

    Args:
        image: PIL image; RGB images are drawn on in place, other modes are
            converted to RGB first so colours and the JPEG save work.
        prediction: Mapping with ``prediction[task]["bboxes"]`` and
            ``prediction[task]["labels"]``.
        task: Key selecting the result inside *prediction*.

    Returns:
        Path of the annotated JPEG file.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    draw = ImageDraw.Draw(image)
    width, height = image.size
    scale = max(width, height) / 1000  # Scale factor based on image size
    # Clamp to 1 so small images still get a valid font size and a visible outline.
    font_size = max(1, int(30 * scale))
    line_width = max(1, int(3 * scale))
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", font_size)
    except OSError:
        font = ImageFont.load_default()

    result = prediction[task]
    for bbox, label in zip(result["bboxes"], result["labels"]):
        x1, y1, x2, y2 = bbox
        draw.rectangle([x1, y1, x2, y2], outline="yellow", width=line_width)
        # Black background behind the label keeps the white text readable.
        draw.rectangle(draw.textbbox((x1, y1), label, font=font), fill="black")
        draw.text((x1, y1), label, fill="white", font=font)

    annotated_image_path = IMAGE_PATH.replace(".jpg", "_annotated.jpg")
    image.save(annotated_image_path)
    return annotated_image_path


def predicted_objects_num(prediction, task):
    """Return the number of bounding boxes predicted for *task*."""
    return len(prediction[task]["bboxes"])


def start_mqtt_client(broker, port, topic, username, password):
    """(Re)create the global MQTT client and connect it to *broker*.

    Any previous client is disconnected and its network loop stopped. The
    requested *topic* is stored as client user data so ``on_connect``
    subscribes to it instead of the default topic.

    Raises:
        OSError: If the broker cannot be reached.
        ValueError: If paho rejects the host or port.
    """
    global mqtt_client
    if mqtt_client is not None:
        mqtt_client.disconnect()
        mqtt_client.loop_stop()  # otherwise the old network thread keeps running
    mqtt_client = mqtt.Client()
    mqtt_client.user_data_set({"topic": topic})
    mqtt_client.username_pw_set(username, password)
    mqtt_client.on_connect = on_connect
    mqtt_client.on_disconnect = on_disconnect
    mqtt_client.on_message = on_message
    mqtt_client.connect(broker, port, 60)
    mqtt_client.loop_start()


def display_image():
    """Return the latest annotated image path and its detected-object count."""
    print("Displaying latest image...")
    return latest_image_info["path"], latest_image_info["objnum"]


def display_image_history():
    """Return a snapshot of the history as ``(path, label)`` pairs for the gallery."""
    return list(image_history)


def show_prediction_on_history(evt: gr.SelectData):
    """Re-run detection on the gallery image the user selected.

    The selected image also becomes the current image, so a later prompt
    commit operates on it.
    """
    image = _load_rgb(image_history[int(evt.index)][0])
    image.save(IMAGE_PATH)
    return _detect_and_record(image)


def update_mqtt_config(broker, port, topic, username, password):
    """Reconnect with the given settings and return a human-readable status string.

    Connection problems are reported in the returned text instead of being
    raised, so an unreachable broker does not abort start-up or the UI handler.
    """
    try:
        start_mqtt_client(broker, int(port), topic, username, password)
    except (OSError, ValueError) as exc:
        print(f"MQTT connection failed: {exc}")
        return f"Failed to connect to {broker}:{port}: {exc}"
    return f"Connected to {broker}:{port}, subscribed to topic '{topic}'"


def auto_connect():
    """Connect to the broker described by ``MQTT_CONFIG``."""
    update_mqtt_config(
        MQTT_CONFIG["broker"],
        MQTT_CONFIG["port"],
        MQTT_CONFIG["topic"],
        MQTT_CONFIG["username"],
        MQTT_CONFIG["password"],
    )


def _label_from_history_filename(filename):
    """Turn ``2024-01-01_12-00-00.jpg`` back into ``2024-01-01 12:00:00``.

    Only the part after the first underscore has its dashes restored to
    colons, so the date keeps its dashes and the extension is dropped.
    """
    stem = os.path.splitext(filename)[0]
    date_part, sep, time_part = stem.partition("_")
    if not sep:
        return stem
    return f"{date_part} {time_part.replace('_', ' ').replace('-', ':')}"


def history_image_load():
    """Rebuild ``image_history`` from the JPEG files in ``IMAGE_HISTORY_DIR``."""
    global image_history
    image_history = [
        (os.path.join(IMAGE_HISTORY_DIR, filename), _label_from_history_filename(filename))
        for filename in os.listdir(IMAGE_HISTORY_DIR)
        if filename.endswith(".jpg")
    ]
    image_history.sort(key=lambda entry: entry[1])
    manage_history_size()


def get_mqtt_status():
    """Return the current MQTT connection status text."""
    return mqtt_status


def upload_image(filepath):
    """Process a manually uploaded image exactly like one received over MQTT.

    Returns:
        Tuple of (annotated image path, detected-object count, gallery items).
    """
    image = _load_rgb(filepath)
    localtime = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    history_path = _save_incoming(image, localtime)
    annotated_image_path, predicted_objects = _detect_and_record(image, date=localtime)
    image_history.append((history_path, localtime))
    manage_history_size()
    return annotated_image_path, predicted_objects, display_image_history()


def manage_history_size():
    """Trim ``image_history`` to ``MAX_HISTORY_SIZE``, deleting the oldest files."""
    overflow = len(image_history) - MAX_HISTORY_SIZE
    if overflow <= 0:
        return
    stale = image_history[:overflow]
    del image_history[:overflow]  # in place, so other references see the trim
    for path, _label in stale:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone (e.g. two entries shared a timestamp); nothing to do.
            pass


def commit_prompt(prompt):
    """Store *prompt* as the active prompt and re-run detection on the current image.

    An empty prompt is stored as ``"all"``. When no image has been received
    yet, the last known result (possibly ``None, None``) is returned.
    """
    global current_prompt
    print(f"Updating prompt to: {prompt}")
    if prompt == "":
        prompt = "all"
    current_prompt = prompt
    if not os.path.exists(IMAGE_PATH):
        return latest_image_info["path"], latest_image_info["objnum"]
    return _detect_and_record(_load_rgb(IMAGE_PATH))


def update_task(task, prompt):
    """Switch the active task and return a visibility update for the prompt box.

    Args:
        task: Display name chosen in the dropdown; must be a key of ``task_name``.
        prompt: Current content of the prompt textbox.
    """
    # Both names must be declared global; otherwise the prompt assignment
    # below would only create a local variable.
    global current_task, current_prompt
    task = task_name[task]
    current_task = task
    if task == "":
        current_prompt = ""
    else:
        current_prompt = prompt
    print(f"Updating task to: {task}, prompt to: {current_prompt}")
    return gr.update(visible=task != "")


with gr.Blocks(css="footer {visibility: hidden}") as iface:
    gr.Markdown("## MS Computer Vision Demo")
    mqtt_status_output = gr.HTML(value=mqtt_status)

    with gr.Accordion("MQTT Settings", open=False):
        with gr.Row():
            broker_input = gr.Textbox(label="MQTT Broker", value=MQTT_CONFIG["broker"])
            port_input = gr.Textbox(label="MQTT Port", value=str(MQTT_CONFIG["port"]))
            topic_input = gr.Textbox(label="MQTT Topic", value=MQTT_CONFIG["topic"])
        with gr.Row():
            username_input = gr.Textbox(label="MQTT Username", value=MQTT_CONFIG["username"])
            password_input = gr.Textbox(label="MQTT Password", type="password", value=MQTT_CONFIG["password"])
        connect_button = gr.Button("Connect")
        connect_button.click(
            fn=update_mqtt_config,
            inputs=[broker_input, port_input, topic_input, username_input, password_input],
            outputs=[]
        )

    with gr.Row():
        with gr.Column(scale=2):
            image_output = gr.Image(label="Latest Image")
            detected_objects_output = gr.Textbox(label="Detected Objects Count", placeholder="No objects detected", interactive=False)
            task_input = gr.Dropdown(
                label="Task",
                choices=list(task_name.keys()),
                value="detect all objects"
            )
            prompt_input = gr.Textbox(label="Prompt(Optional)", placeholder="what is object want to detect?", visible=False)
            task_input.change(fn=update_task, inputs=[task_input, prompt_input], outputs=[prompt_input])
            commit_button = gr.Button("Commit")
            commit_button.click(fn=commit_prompt, inputs=[prompt_input], outputs=[image_output, detected_objects_output])
        with gr.Column(scale=1):
            history_output = gr.Gallery(label="History Image", columns=3)
            upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
            upload_button.upload(fn=upload_image, inputs=upload_button, outputs=[image_output, detected_objects_output, history_output])

    def refresh_interface():
        """Return the latest image and count for the main panel."""
        return display_image()

    def refresh_history():
        """Return the current gallery items."""
        return display_image_history()

    # A gallery refresh (driven by the polling load below) also refreshes the main panel.
    history_output.change(fn=refresh_interface, outputs=[image_output, detected_objects_output])
    history_image_load()
    iface.load(fn=refresh_history, inputs=[], outputs=history_output, every=0.5)
    auto_connect()
    iface.load(fn=get_mqtt_status, inputs=[], outputs=mqtt_status_output)
    history_output.select(fn=show_prediction_on_history, outputs=[image_output, detected_objects_output])

iface.launch(share=True)