yhzheng1031 committed on
Commit
e6f4169
·
verified ·
1 Parent(s): 4da1734

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +60 -64
README.md CHANGED
@@ -22,20 +22,18 @@ pip install transformers==4.57.0
22
  ```
23
 
24
  ```python
25
- import os
26
- import re
27
- import json
28
- import math
29
-
30
  import torch
31
- from PIL import Image, ImageDraw
32
  from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
33
- from playwright.sync_api import sync_playwright
34
 
35
  from prompt_builder import SYSTEM_PROMPT, build_user_prompt
36
  from visual_hint import build_visual_hint
37
- from render_utils import render_html_to_image, save_demo_outputs
 
38
 
 
 
 
39
 
40
  MODEL_NAME = "GD-ML/Code2World"
41
 
@@ -45,34 +43,25 @@ model = Qwen3VLForConditionalGeneration.from_pretrained(
45
  attn_implementation="flash_attention_2",
46
  device_map="auto",
47
  )
48
- processor = AutoProcessor.from_pretrained(MODEL_NAME)
49
-
50
-
51
 
52
- def extract_clean_html(text: str) -> str:
53
- text = text.replace("```html", "").replace("```", "")
54
- start_match = re.search(r"<!DOCTYPE html>", text, re.IGNORECASE)
55
- end_match = re.search(r"</html>", text, re.IGNORECASE)
56
-
57
- if start_match and end_match:
58
- start_idx = start_match.start()
59
- end_idx = end_match.end()
60
- if end_idx > start_idx:
61
- return text[start_idx:end_idx]
62
-
63
- return text.strip()
64
 
65
 
 
 
 
66
 
67
- def build_messages(image: Image.Image, instruction: str, action: dict, semantic_desc=None):
68
  user_prompt = build_user_prompt(
69
  instruction_str=instruction,
70
  action=action,
71
- semantic_desc=semantic_desc,
72
  )
73
 
74
- return [
75
- {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
 
 
 
76
  {
77
  "role": "user",
78
  "content": [
@@ -81,15 +70,15 @@ def build_messages(image: Image.Image, instruction: str, action: dict, semantic_
81
  ],
82
  },
83
  ]
 
84
 
85
 
86
  @torch.inference_mode()
87
- def generate_html(image: Image.Image, instruction: str, action: dict, semantic_desc=None, max_new_tokens: int = 8192):
88
  messages = build_messages(
89
  image=image,
90
  instruction=instruction,
91
  action=action,
92
- semantic_desc=semantic_desc,
93
  )
94
 
95
  inputs = processor.apply_chat_template(
@@ -101,9 +90,14 @@ def generate_html(image: Image.Image, instruction: str, action: dict, semantic_d
101
  )
102
  inputs = inputs.to(model.device)
103
 
104
- generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
 
 
 
 
105
  generated_ids_trimmed = [
106
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
 
107
  ]
108
 
109
  output_text = processor.batch_decode(
@@ -112,58 +106,60 @@ def generate_html(image: Image.Image, instruction: str, action: dict, semantic_d
112
  clean_up_tokenization_spaces=False,
113
  )[0]
114
 
115
- return extract_clean_html(output_text)
 
116
 
117
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- def run_demo(
120
- image_path: str,
121
- instruction: str,
122
- action: dict,
123
- step_pam: dict | None = None,
124
- semantic_desc: str | None = None,
125
- use_visual_hint: bool = True,
126
- max_new_tokens: int = 8192,
127
- output_dir: str = "./demo_outputs",
128
- ):
129
  image = Image.open(image_path).convert("RGB")
130
-
131
- if use_visual_hint:
132
- hinted_image = build_visual_hint(image, action, step_pam)
133
- else:
134
- hinted_image = image
135
 
136
  html = generate_html(
137
  image=hinted_image,
138
  instruction=instruction,
139
  action=action,
140
- semantic_desc=semantic_desc,
141
- max_new_tokens=max_new_tokens,
142
  )
143
 
144
  rendered_image = render_html_to_image(html)
145
- save_demo_outputs(output_dir, hinted_image, html, rendered_image)
 
 
 
 
 
 
146
 
147
  return hinted_image, html, rendered_image
148
 
149
 
 
 
 
 
150
  if __name__ == "__main__":
151
- example_action = {
152
- "action_type": "click",
153
- "x": 540,
154
- "y": 320,
155
- }
156
- example_step_pam = {
157
- "coordinate": [540, 320]
 
 
 
158
  }
159
 
160
- run_demo(
161
- image_path="./examples/current.png",
162
- instruction="Tap the search bar to start searching.",
163
- action=example_action,
164
- step_pam=example_step_pam,
165
- output_dir="./demo_outputs",
166
- )
167
  ```
168
 
169
  ## Citation
 
22
  ```
23
 
24
  ```python
 
 
 
 
 
25
  import torch
26
+ from PIL import Image
27
  from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
 
28
 
29
  from prompt_builder import SYSTEM_PROMPT, build_user_prompt
30
  from visual_hint import build_visual_hint
31
+ from render_utils import extract_clean_html, render_html_to_image, save_demo_outputs
32
+
33
 
34
+ # ============================================================
35
+ # 1. Load model
36
+ # ============================================================
37
 
38
  MODEL_NAME = "GD-ML/Code2World"
39
 
 
43
  attn_implementation="flash_attention_2",
44
  device_map="auto",
45
  )
 
 
 
46
 
47
+ processor = AutoProcessor.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
+ # ============================================================
51
+ # 2. Helper functions
52
+ # ============================================================
53
 
54
+ def build_messages(image, instruction, action):
55
  user_prompt = build_user_prompt(
56
  instruction_str=instruction,
57
  action=action,
 
58
  )
59
 
60
+ messages = [
61
+ {
62
+ "role": "system",
63
+ "content": [{"type": "text", "text": SYSTEM_PROMPT}],
64
+ },
65
  {
66
  "role": "user",
67
  "content": [
 
70
  ],
71
  },
72
  ]
73
+ return messages
74
 
75
 
76
  @torch.inference_mode()
77
+ def generate_html(image, instruction, action, max_new_tokens=8192):
78
  messages = build_messages(
79
  image=image,
80
  instruction=instruction,
81
  action=action,
 
82
  )
83
 
84
  inputs = processor.apply_chat_template(
 
90
  )
91
  inputs = inputs.to(model.device)
92
 
93
+ generated_ids = model.generate(
94
+ **inputs,
95
+ max_new_tokens=max_new_tokens,
96
+ )
97
+
98
  generated_ids_trimmed = [
99
+ out_ids[len(in_ids):]
100
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
101
  ]
102
 
103
  output_text = processor.batch_decode(
 
106
  clean_up_tokenization_spaces=False,
107
  )[0]
108
 
109
+ html = extract_clean_html(output_text)
110
+ return html
111
 
112
 
113
+ def run_demo(case_data, output_dir="./demo_outputs"):
114
+ """
115
+ case_data only needs these keys:
116
+ - images[0]
117
+ - instruction
118
+ - action
119
+ """
120
+ image_path = case_data["images"][0]
121
+ instruction = case_data["instruction"]
122
+ action = case_data["action"]
123
 
 
 
 
 
 
 
 
 
 
 
124
  image = Image.open(image_path).convert("RGB")
125
+ hinted_image = build_visual_hint(image, action)
 
 
 
 
126
 
127
  html = generate_html(
128
  image=hinted_image,
129
  instruction=instruction,
130
  action=action,
 
 
131
  )
132
 
133
  rendered_image = render_html_to_image(html)
134
+
135
+ save_demo_outputs(
136
+ output_dir=output_dir,
137
+ hinted_image=hinted_image,
138
+ html=html,
139
+ rendered_image=rendered_image,
140
+ )
141
 
142
  return hinted_image, html, rendered_image
143
 
144
 
145
+ # ============================================================
146
+ # 3. Example case
147
+ # ============================================================
148
+
149
  if __name__ == "__main__":
150
+ case_data = {
151
+ "images": [
152
+ "/mnt/workspace/zyh/wm_ability/android_control/test_images_mini/904_7.png"
153
+ ],
154
+ "instruction": "Click on the Search Omio button.",
155
+ "action": {
156
+ "action_type": "click",
157
+ "x": 540,
158
+ "y": 1470
159
+ }
160
  }
161
 
162
+ run_demo(case_data, output_dir="./demo_outputs")
 
 
 
 
 
 
163
  ```
164
 
165
  ## Citation