import torch
from qwen_vl_utils import process_vision_info
from transformers import LogitsProcessor, LogitsProcessorList
from gui_actor.constants import (
    DEFAULT_POINTER_END_TOKEN,
    DEFAULT_POINTER_PAD_TOKEN,
    chat_template,
)


class ForceFollowTokensLogitsProcessor(LogitsProcessor):
    """
    Forces tokens B (e.g. pointer_pad_token) and C (e.g. pointer_end_token) to
    follow token A (the trigger token): whenever `token_a_id` is generated,
    the forced sequence (e.g. [B, C]) is enqueued, and as long as forced
    tokens remain in the queue they are emitted on the next decoding steps.
    """
    def __init__(self, token_a_id, forced_sequence=None):
        super().__init__()
        self.token_a_id = token_a_id
        # `forced_sequence` must hold token *ids* (not token strings), since
        # they are used to index the vocabulary dimension of the logits.
        self.forced_sequence = forced_sequence if forced_sequence is not None else []
        self.force_queue = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        Called at each decoding step to modify `scores`.

        Args:
            input_ids: shape (batch_size, seq_len). The already-decoded tokens.
            scores: shape (batch_size, vocab_size). Model logits for the next token.
        """
        batch_size = input_ids.shape[0]
        if batch_size > 1:
            raise NotImplementedError("Batch size must be 1 for this logits processor.")

        last_token_id = input_ids[0, -1].item()

        # The trigger token was just generated: enqueue the forced sequence.
        if last_token_id == self.token_a_id:
            self.force_queue.extend(self.forced_sequence)

        # While forced tokens remain queued, make the next token deterministic
        # by masking every other logit to -inf.
        if len(self.force_queue) > 0:
            forced_token = self.force_queue.pop(0)
            new_scores = torch.full_like(scores, float("-inf"))
            new_scores[0, forced_token] = 0.0
            return new_scores

        return scores
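
# Usage sketch (illustrative, mirroring the default set up in `inference`
# below): force the pointer end token to follow the pointer pad token during
# generation. Assumes each special token encodes to a single id.
#
#   processor = ForceFollowTokensLogitsProcessor(
#       token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0],
#       forced_sequence=[tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]],
#   )
#   model.generate(**inputs, logits_processor=LogitsProcessorList([processor]))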


def get_prediction_region_point(attn_scores, n_width, n_height, top_n=30,
                                activation_threshold=0.3, return_all_regions=True,
                                rect_center=False):
    """
    Convert patch-level attention scores into a click point:
    1. Select activated patches (score above `activation_threshold` * max score).
    2. Group connected activated patches into regions.
    3. Calculate the average activation value of each region.
    4. Select the region with the highest average activation value.
    5. Return the center point of that region as the final prediction point.

    Note: `top_n` is currently unused; patch selection is threshold-based.
    """
    # 1. Keep only patches whose score clears the relative threshold.
    max_score = attn_scores[0].max().item()
    threshold = max_score * activation_threshold
    mask = attn_scores[0] > threshold
    valid_indices = torch.nonzero(mask).squeeze(-1)
    topk_values = attn_scores[0][valid_indices]
    topk_indices = valid_indices

    # Convert flat patch indices to (row, column) grid coordinates.
    topk_coords = []
    for idx in topk_indices.tolist():
        y = idx // n_width
        x = idx % n_width
        topk_coords.append((y, x, idx))

    # 2. Group 4-connected activated patches into regions (BFS flood fill).
    # Map (y, x) -> (flat index, score) for O(1) neighbor lookups.
    coord_map = {(y, x): (idx, topk_values[i].item())
                 for i, (y, x, idx) in enumerate(topk_coords)}

    regions = []
    visited = set()
    for y, x, idx in topk_coords:
        if idx in visited:
            continue
        visited.add(idx)
        region = [(y, x, idx, coord_map[(y, x)][1])]
        queue = [(y, x)]

        while queue:
            cy, cx = queue.pop(0)
            # Visit the four orthogonal neighbors.
            for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                ny, nx = cy + dy, cx + dx
                if (ny, nx) not in coord_map:
                    continue
                n_idx, n_val = coord_map[(ny, nx)]
                if n_idx not in visited:
                    visited.add(n_idx)
                    region.append((ny, nx, n_idx, n_val))
                    queue.append((ny, nx))

        regions.append(region)

    # 3. Score each region and compute its representative center point.
    region_scores = []
    region_centers = []
    region_points = []

    for region in regions:
        # Average activation over the region's patches.
        avg_score = sum(item[3] for item in region) / len(region)
        region_scores.append(avg_score)

        normalized_centers = []
        weights = []
        y_coords = set()
        x_coords = set()

        for y, x, _, score in region:
            # Patch center in normalized [0, 1] image coordinates.
            center_y = (y + 0.5) / n_height
            center_x = (x + 0.5) / n_width
            normalized_centers.append((center_x, center_y))
            weights.append(score)
            y_coords.add(center_y)
            x_coords.add(center_x)

        region_points.append(normalized_centers)

        if not rect_center:
            # Activation-weighted average of the patch centers.
            total_weight = sum(weights)
            avg_center_x = sum(cx * w for (cx, _), w in zip(normalized_centers, weights)) / total_weight
            avg_center_y = sum(cy * w for (_, cy), w in zip(normalized_centers, weights)) / total_weight
        else:
            # Mean of the distinct patch-center coordinates along each axis.
            avg_center_x = sum(x_coords) / len(x_coords)
            avg_center_y = sum(y_coords) / len(y_coords)
        region_centers.append((avg_center_x, avg_center_y))

    # 4./5. Rank regions by average activation; the best region's center is
    # the predicted point.
    sorted_indices = sorted(range(len(region_scores)), key=lambda i: region_scores[i], reverse=True)
    sorted_scores = [region_scores[i] for i in sorted_indices]
    sorted_centers = [region_centers[i] for i in sorted_indices]
    sorted_points = [region_points[i] for i in sorted_indices]
    best_point = sorted_centers[0]

    if return_all_regions:
        return best_point, sorted_centers, sorted_scores, sorted_points
    else:
        return best_point
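
# Worked example (sketch): on a toy 4x4 patch grid, two adjacent activated
# patches form one region whose activation-weighted center is returned in
# normalized [0, 1] coordinates. The scores below are illustrative only.
#
#   scores = torch.zeros(1, 16)
#   scores[0, 5] = 1.0   # patch (y=1, x=1)
#   scores[0, 6] = 0.8   # patch (y=1, x=2), 4-connected to the first
#   point = get_prediction_region_point(scores, n_width=4, n_height=4,
#                                       return_all_regions=False)
#   # point ~= (0.486, 0.375): weighted mean of patch centers (0.375, 0.375)
#   # and (0.625, 0.375) with weights 1.0 and 0.8.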


def inference(conversation, model, tokenizer, data_processor, logits_processor=None, use_placeholder=False, topk=5):
    """
    Expected conversation format:

    conversation = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": grounding_system_message,
                }
            ]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": example["image"],  # PIL.Image.Image or path string
                    # alternatively "image_url": e.g. "https://xxxxx.png",
                    # "file://xxxxx.png", or "data:image/png;base64,xxxxxxxx"
                    # (data URLs are split on "base64,")
                },
                {
                    "type": "text",
                    "text": example["instruction"]
                },
            ],
        },
    ]
    """
    # Default: once the pointer pad token is generated, force the pointer end
    # token to follow it.
    if logits_processor is None:
        logits_processor = ForceFollowTokensLogitsProcessor(
            token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0],
            forced_sequence=[
                tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
            ]
        )

    # In placeholder mode, the assistant turn is pre-filled with a pointer
    # expression, so the pointer tokens sit in the prompt and only a single
    # generation step is needed.
    assistant_starter = "" if not use_placeholder else "<|im_start|>assistant<|recipient|>os\npyautogui.click(<|pointer_start|><|pointer_pad|><|pointer_end|>)"

    # Prediction record returned to the caller.
    pred = {
        "output_text": None,
        "n_width": None,
        "n_height": None,
        "attn_scores": None,
        "topk_points": None,
        "topk_values": None,
        "topk_points_all": None,
    }

    # Render the conversation into the model's chat format, then append the
    # (possibly empty) pre-filled assistant turn.
    text = data_processor.apply_chat_template(conversation,
                                              tokenize=False,
                                              add_generation_prompt=False,
                                              chat_template=chat_template)
    text += assistant_starter

    # Prepare the multimodal model inputs.
    image_inputs, video_inputs = process_vision_info(conversation)
    inputs = data_processor(text=[text],
                            images=image_inputs,
                            videos=video_inputs,
                            padding=True,
                            return_tensors="pt")
    inputs = inputs.to(model.device)

    # Generate. In placeholder mode a single step suffices, since the pointer
    # tokens are already in the prompt; otherwise decode until EOS.
    results = model.generate(**inputs,
                             max_new_tokens=2048 if not use_placeholder else 1,
                             logits_processor=LogitsProcessorList([logits_processor]),
                             return_dict_in_generate=True,
                             output_hidden_states=True)

    # Decode the newly generated tokens (special tokens included).
    input_ids = inputs["input_ids"][0]
    generated_ids = results.sequences[0][len(input_ids):]
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
    pred["output_text"] = output_text

    # Locate the pointer pad token(s): in the prompt when using the
    # placeholder, otherwise among the generated tokens. The last generated
    # token is excluded because no later decoding step processes it, so it has
    # no hidden state in `results.hidden_states` below.
    if use_placeholder:
        pointer_pad_mask = (inputs["input_ids"][0] == model.config.pointer_pad_token_id)
    else:
        pointer_pad_mask = (generated_ids[:-1] == model.config.pointer_pad_token_id)

    # No pointer pad token found: return with the text output only.
    if pointer_pad_mask.sum() == 0:
        return pred

    # Collect the last-layer hidden state at each pointer pad position.
    if use_placeholder:
        # Prefill step: hidden_states[0] covers the whole prompt.
        decoder_hidden_states = results.hidden_states[0][-1][0]
    else:
        # Decoding steps i >= 1 process generated token i-1; concatenating
        # their last-layer states aligns them with `generated_ids[:-1]`.
        decoder_hidden_states = [step_hidden_states[-1][0] for step_hidden_states in results.hidden_states[1:]]
        decoder_hidden_states = torch.cat(decoder_hidden_states, dim=0)
    decoder_hidden_states = decoder_hidden_states[pointer_pad_mask]

    # Recover the image-token embeddings from the prefill step's embedding
    # layer output (layer 0 of generation step 0).
    image_mask = (inputs["input_ids"][0] == tokenizer.encode("<|image_pad|>")[0])
    image_embeds = results.hidden_states[0][0][0][image_mask]

    # Attention of each pointer hidden state over the image patches.
    attn_scores, _ = model.multi_patch_pointer_head(image_embeds, decoder_hidden_states)
    pred["attn_scores"] = attn_scores.tolist()

    # Patch-grid size after spatial merging.
    _, n_height, n_width = (inputs["image_grid_thw"][0] // model.visual.spatial_merge_size).tolist()
    pred["n_width"] = n_width
    pred["n_height"] = n_height

    # Turn the attention map into ranked candidate click points and keep the
    # top-k regions (slicing safely handles fewer than k regions).
    best_point, region_centers, region_scores, region_points_all = get_prediction_region_point(attn_scores, n_width, n_height, return_all_regions=True, rect_center=False)
    pred["topk_points"] = region_centers[:topk]
    pred["topk_values"] = region_scores[:topk]
    pred["topk_points_all"] = region_points_all[:topk]

    return pred
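
# Usage sketch (illustrative): end-to-end grounding call, assuming `model`,
# `tokenizer`, and `data_processor` have been loaded elsewhere for a
# GUI-Actor-style checkpoint; the system prompt and file name below are
# placeholders.
#
#   conversation = [
#       {"role": "system", "content": [{"type": "text", "text": "You are a GUI grounding agent."}]},
#       {"role": "user", "content": [
#           {"type": "image", "image": "screenshot.png"},
#           {"type": "text", "text": "Click the search button."},
#       ]},
#   ]
#   pred = inference(conversation, model, tokenizer, data_processor, use_placeholder=True)
#   x, y = pred["topk_points"][0]  # normalized [0, 1] coordinates of the best region center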