Upload caption_aitw_v2.py with huggingface_hub
caption_aitw_v2.py +95 -0
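For reference, an upload like the one in this commit can be done with the huggingface_hub client; a minimal sketch (the repo id below is a placeholder, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()  # authenticates via huggingface-cli login or the HF_TOKEN environment variable
api.upload_file(
    path_or_fileobj="caption_aitw_v2.py",
    path_in_repo="caption_aitw_v2.py",
    repo_id="your-username/your-repo",  # placeholder, not the actual repository of this commit
    commit_message="Upload caption_aitw_v2.py with huggingface_hub",
)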
caption_aitw_v2.py
ADDED
@@ -0,0 +1,95 @@
import os
import torch
import json
from PIL import Image
import pprint
from tqdm import tqdm
from multiprocessing import Pool, cpu_count


from chat import MiniCPMVChat, img2base64



def read_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data

def write_json(file_path, data):
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)

def preprocess_data(data, path_base):
    """Replace each image path with its base64 encoding to avoid repeated I/O."""
    for item in data:
        img_path = os.path.join(path_base, item['image'])
        item['image_base64'] = img2base64(img_path)
    return data



def chat_minicpm_application(image_path):

    qs = """
    List the names and locations of all interactive applications in the image, as well as their functionality and potential applications.
    """
    # qs = f'''{context}. The green frame in the picture represents the situation of clicking, need to explain why click in the corresponding area.
    # '''
    im_64 = img2base64(image_path)
    msgs = [{"role": "user", "content": qs}]
    inputs = {"image": im_64, "question": json.dumps(msgs)}
    answer = chat_model.chat(inputs)
    return answer


def chat_minicpm_content(image_path):

    qs = """
    Describe the content of this image.
    """

    im_64 = img2base64(image_path)
    msgs = [{"role": "user", "content": qs}]
    inputs = {"image": im_64, "question": json.dumps(msgs)}
    answer = chat_model.chat(inputs)
    return answer

def chat_minicpm_mind(image_path):

    qs = """
    The green frame in the picture represents the situation of clicking, need to explain why click in the corresponding area. Answer template: The green box ....
    """

    im_64 = img2base64(image_path)
    msgs = [{"role": "user", "content": qs}]
    inputs = {"image": im_64, "question": json.dumps(msgs)}
    answer = chat_model.chat(inputs)
    return answer



torch.manual_seed(0)
chat_model = MiniCPMVChat('/code/Model/MiniCPM-Llama3-V-2_5')  # load the vision-language model once; shared by all prompt helpers
path_base = '/code/Auto-GUI/dataset/'


data = read_json("/code/Auto-GUI/dataset/mind/general_blip_train_llava_coco.json")
data = [line for line in data if line['action_type'] == '#DUAL_POINT#'][17370:]  # keep dual-point actions only, skipping the first 17370 entries



for idx, i in enumerate(tqdm(data), 1):  # count from 1 so the periodic-save check below stays simple
    img_path = path_base + i['image']
    # context = data[idx]['conversations'][0]['value']
    i['application'] = chat_minicpm_application(img_path)
    i['content'] = chat_minicpm_content(img_path)
    i['mind'] = chat_minicpm_mind(img_path)

    # Save a checkpoint every 100 items.
    if idx % 100 == 0:
        write_json('/code/MiniCPM-V/general_blip_train_llava_coco_caption_mind2.json', data)

# Final save so items beyond the last multiple of 100 are also written out.
write_json('/code/MiniCPM-V/general_blip_train_llava_coco_caption_mind2.json', data)
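Note that preprocess_data, Pool, and cpu_count are defined or imported above but never used. A minimal sketch of how those unused pieces could be wired together to precompute the base64 images in parallel before the captioning loop (the _encode helper is hypothetical and assumes the same data and path_base variables):

from functools import partial
from multiprocessing import Pool, cpu_count

def _encode(item, path_base):
    # Hypothetical helper: attach the base64-encoded screenshot to each record.
    item['image_base64'] = img2base64(os.path.join(path_base, item['image']))
    return item

if __name__ == '__main__':
    with Pool(processes=cpu_count()) as pool:
        data = pool.map(partial(_encode, path_base=path_base), data)

The hard-coded [17370:] slice and the fixed output path suggest a manual resume after an interruption; deriving the resume offset from the length of the existing checkpoint file would make restarts less error-prone.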