import io
import os
import shutil
import base64

import gradio as gr
from PIL import Image, ImageDraw
from MobileAgent.text_localization import ocr
from MobileAgent.icon_localization import det
from MobileAgent.local_server import mobile_agent_infer
from modelscope import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

chatbot_css = """
"""

temp_file = "temp"
screenshot = "screenshot"
cache = "cache"

if not os.path.exists(temp_file):
    os.mkdir(temp_file)
if not os.path.exists(screenshot):
    os.mkdir(screenshot)
if not os.path.exists(cache):
    os.mkdir(cache)

# Download and load the perception models: GroundingDINO for icon detection,
# plus the ModelScope OCR pipelines for text detection and recognition.
groundingdino_dir = snapshot_download('AI-ModelScope/GroundingDINO', revision='v1.0.0')
groundingdino_model = pipeline('grounding-dino-task', model=groundingdino_dir)
ocr_detection = pipeline(Tasks.ocr_detection, model='damo/cv_resnet18_ocr-detection-line-level_damo')
ocr_recognition = pipeline(Tasks.ocr_recognition, model='damo/cv_convnextTiny_ocr-recognition-document_damo')


def encode_image(image_path):
    # Read an image file and return its contents as a base64 string.
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def get_all_files_in_folder(folder_path):
    return list(os.listdir(folder_path))


def crop(image, box, i):
    # Crop the region box = [x1, y1, x2, y2] out of the screenshot and save it
    # as ./temp/<i>.png. Boxes less than ~10 px wide or tall are skipped as
    # detection noise.
    image = Image.open(image)
    x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    if x1 >= x2 - 10 or y1 >= y2 - 10:
        return
    cropped_image = image.crop((x1, y1, x2, y2))
    cropped_image.save(f"./temp/{i}.png", format="PNG")


def merge_text_blocks(text_list, coordinates_list):
    # Merge vertically adjacent OCR text lines that share the same left edge
    # and a similar line height into multi-line blocks.
    merged_text_blocks = []
    merged_coordinates = []

    sorted_indices = sorted(range(len(coordinates_list)),
                            key=lambda k: (coordinates_list[k][1], coordinates_list[k][0]))
    sorted_text_list = [text_list[i] for i in sorted_indices]
    sorted_coordinates_list = [coordinates_list[i] for i in sorted_indices]

    num_blocks = len(sorted_text_list)
    merge = [False] * num_blocks

    for i in range(num_blocks):
        if merge[i]:
            continue

        anchor = i
        group_text = [sorted_text_list[anchor]]
        group_coordinates = [sorted_coordinates_list[anchor]]

        for j in range(i + 1, num_blocks):
            if merge[j]:
                continue
            # Merge block j into the group when it is left-aligned with the
            # anchor, starts within [-10, 30) px below the anchor's bottom
            # edge, and has a similar line height.
            if abs(sorted_coordinates_list[anchor][0] - sorted_coordinates_list[j][0]) < 10 and \
               sorted_coordinates_list[j][1] - sorted_coordinates_list[anchor][3] >= -10 and \
               sorted_coordinates_list[j][1] - sorted_coordinates_list[anchor][3] < 30 and \
               abs(sorted_coordinates_list[anchor][3] - sorted_coordinates_list[anchor][1] -
                   (sorted_coordinates_list[j][3] - sorted_coordinates_list[j][1])) < 10:
                group_text.append(sorted_text_list[j])
                group_coordinates.append(sorted_coordinates_list[j])
                merge[anchor] = True
                anchor = j
                merge[anchor] = True

        merged_text = "\n".join(group_text)
        min_x1 = min(group_coordinates, key=lambda x: x[0])[0]
        min_y1 = min(group_coordinates, key=lambda x: x[1])[1]
        max_x2 = max(group_coordinates, key=lambda x: x[2])[2]
        max_y2 = max(group_coordinates, key=lambda x: x[3])[3]

        merged_text_blocks.append(merged_text)
        merged_coordinates.append([min_x1, min_y1, max_x2, max_y2])

    return merged_text_blocks, merged_coordinates

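
# A minimal usage sketch for merge_text_blocks (illustrative only: the sample
# texts and [x1, y1, x2, y2] boxes below are hypothetical, not taken from a
# real screenshot). Two left-aligned lines of equal height, 5 px apart, get
# merged into a single "Hello\nworld" block spanning both boxes.
def _demo_merge_text_blocks():
    texts = ["Hello", "world"]
    boxes = [[10, 10, 60, 30], [10, 35, 60, 55]]
    merged_texts, merged_boxes = merge_text_blocks(texts, boxes)
    # merged_texts == ["Hello\nworld"], merged_boxes == [[10, 10, 60, 55]]
    return merged_texts, merged_boxes
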

def get_perception_infos(screenshot_file):
    width, height = Image.open(screenshot_file).size

    # Text perception: run OCR, then merge adjacent lines into blocks.
    text, coordinates = ocr(screenshot_file, ocr_detection, ocr_recognition)
    text, coordinates = merge_text_blocks(text, coordinates)

    perception_infos = []
    for i in range(len(coordinates)):
        perception_info = {"text": "text: " + text[i], "coordinates": coordinates[i]}
        perception_infos.append(perception_info)

    # Icon perception: detect icon boxes with GroundingDINO.
    coordinates = det(screenshot_file, "icon", groundingdino_model)
    for i in range(len(coordinates)):
        perception_info = {"text": "icon", "coordinates": coordinates[i]}
        perception_infos.append(perception_info)

    # Crop every detected icon into ./temp so it can be captioned.
    image_box = []
    image_id = []
    for i in range(len(perception_infos)):
        if perception_infos[i]['text'] == 'icon':
            image_box.append(perception_infos[i]['coordinates'])
            image_id.append(i)

    for i in range(len(image_box)):
        crop(screenshot_file, image_box[i], image_id[i])

    images = get_all_files_in_folder(temp_file)
    if len(images) > 0:
        images = sorted(images, key=lambda x: int(x.split('/')[-1].split('.')[0]))
        image_id = [int(image.split('/')[-1].split('.')[0]) for image in images]
        icon_map = {}
        prompt = 'This image is an icon from a phone screen. Please briefly describe the shape and color of this icon in one sentence.'

        string_image = []
        for i in range(len(images)):
            image_path = os.path.join(temp_file, images[i])
            string_image.append({"image_name": images[i], "image_file": encode_image(image_path)})
        query_data = {"task": "caption", "images": string_image, "query": prompt}
        response_query = mobile_agent_infer(query_data)
        icon_map = response_query["icon_map"]

        # Attach each caption back to the perception entry it was cropped from.
        for i, j in zip(image_id, range(1, len(image_id) + 1)):
            if icon_map.get(str(j)):
                perception_infos[i]['text'] = "icon: " + icon_map[str(j)]

    # Replace each bounding box with its center point.
    for i in range(len(perception_infos)):
        perception_infos[i]['coordinates'] = [
            int((perception_infos[i]['coordinates'][0] + perception_infos[i]['coordinates'][2]) / 2),
            int((perception_infos[i]['coordinates'][1] + perception_infos[i]['coordinates'][3]) / 2)]

    return perception_infos, width, height


def image_to_base64(image):
    # Serialize a PIL image to base64 and wrap it in an inline <img> tag so it
    # can be rendered inside the chat HTML.
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    img_html = f'<img src="data:image/png;base64,{img_str}">'  # minimal tag; original styling markup not preserved
    return img_html

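
# Protocol of the local inference server, as used in this file: every call to
# mobile_agent_infer takes a dict whose "task" field selects the request type
# ("caption" above, "decision" in chatbot below), and returns a dict with the
# fields read at each call site ("icon_map" for captions; "decision" and
# "memory" for decisions). A hypothetical offline stub, useful only for
# exercising the UI without the server, might look like this:
def _stub_mobile_agent_infer(query_data):
    if query_data["task"] == "caption":
        # One generic caption per submitted icon crop, keyed "1", "2", ...
        return {"icon_map": {str(i + 1): "a placeholder icon"
                             for i in range(len(query_data["images"]))}}
    # "decision" task: report exhausted resources so chatbot exits early.
    return {"decision": "No token", "memory": ""}
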
".format(instruction) else: thought_history = history["thought_history"] summary_history = history["summary_history"] action_history = history["action_history"] summary = history["summary"] action = history["action"] completed_requirements = history["completed_requirements"] memory = history["memory"][0] insight = history["insight"] error_flag = history["error_flag"] user_msg = " ".format("I have uploaded the screenshot. Please continue operating.") images = get_all_files_in_folder(cache) if len(images) > 0 and len(images) <= 100: images = sorted(images, key=lambda x: int(x.split('/')[-1].split('.')[0])) image_id = [int(image.split('/')[-1].split('.')[0]) for image in images] cur_image_id = image_id[-1] + 1 elif len(images) > 100: images = sorted(images, key=lambda x: int(x.split('/')[-1].split('.')[0])) image_id = [int(image.split('/')[-1].split('.')[0]) for image in images] cur_image_id = image_id[-1] + 1 os.remove(os.path.join(cache, str(image_id[0])+".png")) else: cur_image_id = 1 image.save(os.path.join(cache, str(cur_image_id) + ".png"), format="PNG") screenshot_file = os.path.join(cache, str(cur_image_id) + ".png") perception_infos, width, height = get_perception_infos(screenshot_file) shutil.rmtree(temp_file) os.mkdir(temp_file) local_screenshot_file = encode_image(screenshot_file) query_data = { "task": "decision", "screenshot_file": local_screenshot_file, "instruction": instruction, "perception_infos": perception_infos, "width": width, "height": height, "summary_history": summary_history, "action_history": action_history, "summary": summary, "action": action, "add_info": add_info, "error_flag": error_flag, "completed_requirements": completed_requirements, "memory": memory, "memory_switch": True, "insight": insight } response_query = mobile_agent_infer(query_data) output_action = response_query["decision"] output_memory = response_query["memory"] if output_action == "No token": bot_response = [" ".format("Sorry, the resources can be exhausted today.")] chat_html = "