ayh015 committed on
Commit
73df34b
·
1 Parent(s): 6d4aa31

Update modified code

Browse files
data/convsersation.py CHANGED
@@ -37,23 +37,27 @@ class Conversation:
37
  super().__init__()
38
  if system == '':
39
  self.system = f"""
40
- You are an AI assistant. You will be given an image that contains a main human subject.
41
  Task:
42
- Describe the visual evidence in the image that supports the subjects action, with an emphasis on human body parts and their interactions with objects.
43
 
44
  Hints:
45
- You may be given hints about (1) the action and (2) related objects and possible supporting body parts. You can use these hints, but you may also add other relevant evidence you observe.
46
 
47
  Required Constraints:
48
  - Start with ONE sentence that summarizes the main action in natural language.
49
  - When you mention any keypoint or body part, you MUST use names ONLY from: {COCO_KEYPOINT_NAME}.
50
  - Do NOT invent body-part names outside these sets (no synonyms, no paraphrases).
51
  - If you are unsure which name applies, either omit the body-part mention or choose the closest valid name from the lists.
52
- - Write your description in clear, concise sentences grounded in visible evidence.
 
 
53
 
54
  Optional Constraints :
 
 
55
  - Write naturally. Avoid repeating the same sentence pattern.
56
- - Keep each evidence item to one line. No redundant "both left/right do the same" unless necessary.
57
  """
58
  else:
59
  self.system = system
@@ -284,6 +288,134 @@ class Conversation_For_Clean_Evidence:
284
  """
285
  return prompt
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  @dataclasses.dataclass
288
  class Conversation_For_Action_Pharse:
289
  def __init__(self, system='', data_path=''):
@@ -416,4 +548,4 @@ class Conversation_For_COCO_Long_Description:
416
 
417
 
418
  if __name__ == "__main__":
419
- pass
 
37
  super().__init__()
38
  if system == '':
39
  self.system = f"""
40
+ You are an AI assistant for first-pass long-form HICO annotation. You will be given an image that contains a main human subject.
41
  Task:
42
+ Write a detailed long description of the visual evidence in the image that supports the subject's action, with an emphasis on human body parts, posture, spatial configuration, and interactions with objects.
43
 
44
  Hints:
45
+ You may be given hints about (1) the action and (2) related objects and possible supporting body parts. You should use these hints as anchors, and you may add other relevant visible evidence you observe.
46
 
47
  Required Constraints:
48
  - Start with ONE sentence that summarizes the main action in natural language.
49
  - When you mention any keypoint or body part, you MUST use names ONLY from: {COCO_KEYPOINT_NAME}.
50
  - Do NOT invent body-part names outside these sets (no synonyms, no paraphrases).
51
  - If you are unsure which name applies, either omit the body-part mention or choose the closest valid name from the lists.
52
+ - The description must be long and detailed enough to serve as a first-pass annotation for later refinement.
53
+ - Include as many relevant supporting details as are visibly justified, especially about contact, pose, orientation, support, and object interaction.
54
+ - Write your description in clear, natural sentences grounded in visible evidence.
55
 
56
  Optional Constraints :
57
+ - Prefer a rich multi-sentence paragraph rather than a short caption.
58
+ - Cover multiple cues when available, such as limb placement, body balance, joint bending, contact points, and relative position to the object.
59
  - Write naturally. Avoid repeating the same sentence pattern.
60
+ - If both sides contribute differently, describe them separately.
61
  """
62
  else:
63
  self.system = system
 
288
  """
289
  return prompt
290
 
291
@dataclasses.dataclass
class Conversation_examiner:
    """Examiner conversation for HICO action annotations.

    Builds a system prompt that makes the model act as a strict checker and
    final editor of candidate action descriptions, and assembles per-sample
    user prompts from the HOI hint, part-state hints, and candidate texts.

    Fix over the original: ``_humanpart2word`` now resets ``part_state`` for
    every label, so a part with no matching part-state key yields ``None``
    instead of raising ``UnboundLocalError`` (first label) or silently
    reusing the previous label's state (later labels).
    """

    def __init__(self, system='', data_path=''):
        super().__init__()
        if system == '':
            self.system = f"""
You are a strict checker and final editor for HICO action annotations.

You will be given:
- The ground-truth HICO action hint as [VERB, OBJECT].
- Part-state hints derived from annotation labels.
- One or more candidate texts, such as a long description, a short description, or an evidence-only description.

Your task:
- Judge whether the candidate texts are consistent with the target action.
- Check whether the descriptions are grounded in plausible visible body-part evidence.
- Check whether any mentioned body parts use valid COCO keypoint names only: {COCO_KEYPOINT_NAME}.
- Detect unsupported claims, contradictions, object/action mismatches, left/right mistakes, and hallucinated joints or interactions.
- Produce a final checked description after resolving any issues you can fix from the provided candidates and hints.

Important checking rules:
- The target action is defined by the provided HICO hint, not by the candidate text.
- If a candidate text conflicts with the target action, fix the final checked description so it aligns with the target action.
- If a candidate text includes body-part terms outside the allowed keypoint list, replace them with valid names when possible and record the issue.
- If evidence is too vague, missing, or unrelated to the target action, remove unsupported content from the final checked description and record the issue.
- Pay special attention to left/right consistency. If the candidate confuses left and right, or assigns evidence to the wrong side, correct it when the correct side is supported by the provided candidates and hints; otherwise remove the uncertain side-specific claim and record the issue.
- Do not keep any joint claim that is not visible, not inferable from the provided evidence, or appears hallucinated. If a joint or body-part interaction cannot be supported, remove it and record the issue.
- Do not invent new visual evidence that is not supported by the provided candidates and hints.
- The final checked description should be concise, natural, and reliable.
- Prefer the strongest grounded evidence among the provided candidates.
- When side-specific evidence is uncertain, prefer a conservative description over a risky one.

Output format:
Return plain text in exactly this structure.

Verdict: PASS or REVISED
Action alignment: one short sentence
Evidence grounding: one short sentence
Keypoint-name validity: one short sentence
Checked description:
<final checked description>
Issues:
- item 1
- item 2

If there are no issues, write:
Issues:
- None
"""
        else:
            self.system = system

        # Lookup tables: hoi_id -> (object, verb), and part-state key -> list of state phrases.
        self.hoi_reference = read_hoi_file_2_dict(os.path.join(data_path, 'Configs/hico_hoi_list.txt'))
        self.part_state_reference = read_part_state_file_2_dict(os.path.join(data_path, 'Configs/Part_State_76.txt'))

    def _replace_part_names(self, text):
        """Replace colloquial part names (hand/foot) with COCO keypoint names
        (wrist/ankle), preserving plurality and leading capitalization."""
        REPL = {
            "hand": "wrist",
            "hands": "wrists",
            "foot": "ankle",
            "feet": "ankles",
        }
        # \b anchors keep substrings like "football" or "handle" untouched.
        pattern = re.compile(r"\b(" + "|".join(map(re.escape, REPL.keys())) + r")\b", re.IGNORECASE)

        def _sub(m):
            w = m.group(0)
            out = REPL[w.lower()]
            if w[0].isupper():
                out = out.capitalize()
            return out

        return pattern.sub(_sub, text)

    def _humanpart2word(self, action_labels):
        """Convert raw {'human_part': id, 'partstate': id} dicts into
        [part name, state phrase] pairs using the part-state reference.

        Returns ``None`` as the state when no part-state key matches the
        part name (previously this raised or reused a stale value).
        """
        action_labels_in_words = []
        part_state_keys = list(self.part_state_reference.keys())
        for d in action_labels:
            human_part_id = d['human_part']
            part_state_id = d['partstate']

            part_name = PART_ORDER[human_part_id]
            # Reset per label so an unmatched part cannot inherit the
            # previous label's state or trigger UnboundLocalError.
            part_state = None
            for key in part_state_keys:
                if key in part_name:
                    states = self.part_state_reference[key]
                    part_state = states[part_state_id]

            part_name = self._replace_part_names(part_name)
            action_labels_in_words.append([part_name, part_state])
        return action_labels_in_words

    def _actionid2word(self, hoi_id):
        """Return the (object, verb) pair recorded for ``hoi_id``."""
        obj, act = self.hoi_reference[hoi_id]
        return obj, act

    def get_prompt(self, meta):
        """Build the examiner user prompt for one annotation record.

        Candidate texts are normalized with ``_replace_part_names`` before
        being shown; absent candidates are rendered as ``[Missing]``.
        """
        hoi_id = meta['hoi_id']
        obj_in_word, act_in_word = self._actionid2word(hoi_id)
        action_labels = meta['action_labels']
        action_labels_in_words = self._humanpart2word(action_labels)

        long_description = self._replace_part_names(meta.get('description', ''))
        refined_description = self._replace_part_names(meta.get('refined_description', ''))
        short_description = self._replace_part_names(meta.get('short_description', ''))
        action_description = self._replace_part_names(meta.get('action_description', ''))
        evidence_description = self._replace_part_names(meta.get('evidence_description', ''))

        prompt = f"""
Target action hint: [{act_in_word}, {obj_in_word}]
Part-state hints:
{action_labels_in_words}

Candidate long description:
{long_description if long_description else '[Missing]'}

Candidate refined description:
{refined_description if refined_description else '[Missing]'}

Candidate short description:
{short_description if short_description else '[Missing]'}

Candidate action description:
{action_description if action_description else '[Missing]'}

Candidate evidence description:
{evidence_description if evidence_description else '[Missing]'}

Check the candidates against the target action and part-state hints, produce the final checked description, and then follow the required output format exactly.
"""
        return prompt
418
+
419
  @dataclasses.dataclass
420
  class Conversation_For_Action_Pharse:
421
  def __init__(self, system='', data_path=''):
 
548
 
549
 
550
  if __name__ == "__main__":
551
+ pass
data/dataset_for_clean_descrip.py CHANGED
@@ -34,6 +34,8 @@ class PoseHICODetDataset(Dataset):
34
  """Dataset for supervised fine-tuning."""
35
  def __init__(self, data_path: str,
36
  multimodal_cfg: dict,
 
 
37
  ):
38
  super(PoseHICODetDataset, self).__init__()
39
  logging.warning("Loading data...")
@@ -43,7 +45,9 @@ class PoseHICODetDataset(Dataset):
43
  self.pixel_std = 200
44
  self.num_joints = 17
45
  self.num_joints_full_body = 136
46
- self.list_data_dict = self._load_json('./outputs/merged_labels.json')
 
 
47
 
48
  json_path = os.path.join(data_path, "Annotation/hico-det-instance-level/hico-det-training-set-instance-level.json")
49
  with open(json_path, "r", encoding="utf-8") as f:
 
34
  """Dataset for supervised fine-tuning."""
35
  def __init__(self, data_path: str,
36
  multimodal_cfg: dict,
37
+ annotation_path: str = './outputs/merged_labels.json',
38
+ max_samples: int = 0,
39
  ):
40
  super(PoseHICODetDataset, self).__init__()
41
  logging.warning("Loading data...")
 
45
  self.pixel_std = 200
46
  self.num_joints = 17
47
  self.num_joints_full_body = 136
48
+ self.list_data_dict = self._load_json(annotation_path)
49
+ if max_samples > 0:
50
+ self.list_data_dict = self.list_data_dict[:max_samples]
51
 
52
  json_path = os.path.join(data_path, "Annotation/hico-det-instance-level/hico-det-training-set-instance-level.json")
53
  with open(json_path, "r", encoding="utf-8") as f:
data/pose_hicodet.py CHANGED
@@ -34,6 +34,7 @@ class PoseHICODetDataset(Dataset):
34
  """Dataset for supervised fine-tuning."""
35
  def __init__(self, data_path: str,
36
  multimodal_cfg: dict,
 
37
  ):
38
  super(PoseHICODetDataset, self).__init__()
39
  logging.warning("Loading data...")
@@ -43,6 +44,7 @@ class PoseHICODetDataset(Dataset):
43
  self.pixel_std = 200
44
  self.num_joints = 17
45
  self.num_joints_full_body = 136
 
46
  self.list_data_dict = self._load_data(data_path)
47
 
48
 
@@ -134,6 +136,11 @@ class PoseHICODetDataset(Dataset):
134
  'hoi_obj': hoi_obj,
135
  })
136
  instance_id += 1
 
 
 
 
 
137
 
138
  logging.warning("The number of training samples is {}".format(len(list_data_dict)))
139
  logging.warning("Formatting inputs...Skip in lazy mode")
 
34
  """Dataset for supervised fine-tuning."""
35
  def __init__(self, data_path: str,
36
  multimodal_cfg: dict,
37
+ max_samples: int = 0,
38
  ):
39
  super(PoseHICODetDataset, self).__init__()
40
  logging.warning("Loading data...")
 
44
  self.pixel_std = 200
45
  self.num_joints = 17
46
  self.num_joints_full_body = 136
47
+ self.max_samples = max_samples
48
  self.list_data_dict = self._load_data(data_path)
49
 
50
 
 
136
  'hoi_obj': hoi_obj,
137
  })
138
  instance_id += 1
139
+ if self.max_samples > 0 and len(list_data_dict) >= self.max_samples:
140
+ logging.warning("Reached max_samples={}, stopping early.".format(self.max_samples))
141
+ logging.warning("The number of training samples is {}".format(len(list_data_dict)))
142
+ logging.warning("Formatting inputs...Skip in lazy mode")
143
+ return list_data_dict
144
 
145
  logging.warning("The number of training samples is {}".format(len(list_data_dict)))
146
  logging.warning("Formatting inputs...Skip in lazy mode")
scripts/examine_hico.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ IDX=1
2
+ export PYTHONPATH=$PYTHONPATH:./
3
+
4
+ data_path=../datasets/HICO-Det
5
+ model_path=./model_weights/qwen3_8b_vl_instruct
6
+ annotation_path=./outputs/merged_labels.json
7
+ output_dir=outputs/examiner
8
+
9
+ if [ -d ${output_dir} ]; then
10
+ echo "dir already exists"
11
+ else
12
+ mkdir -p ${output_dir}
13
+ fi
14
+
15
+ CUDA_VISIBLE_DEVICES=$IDX OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=1 --master_port=25007 \
16
+ tools/examine_hico.py \
17
+ --model-path ${model_path} \
18
+ --data-path ${data_path} \
19
+ --annotation-path ${annotation_path} \
20
+ --output-dir ${output_dir} \
scripts/pipeline_hico.sh ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Three-stage HICO annotation pipeline:
#   1) first-pass long descriptions -> 2) refinement -> 3) examiner check.
# Each stage writes per-rank shards which are then merged into one JSON.
set -euo pipefail

export PYTHONPATH="${PYTHONPATH:-}:./"

# --- Shared configuration --------------------------------------------------
DATA_PATH=../datasets/HICO-Det

LONG_MODEL_PATH=./model_weights/qwen3_8b_vl_instruct
REFINE_MODEL_PATH=./model_weights/qwen3_8b_vl_instruct
EXAMINE_MODEL_PATH=./model_weights/qwen3_8b_vl_instruct

LONG_GPU_IDS=0
REFINE_GPU_IDS=0
EXAMINE_GPU_IDS=0

LONG_NPROC=1
REFINE_NPROC=1
EXAMINE_NPROC=1

LONG_OUT_DIR=outputs/pipeline/long
REFINE_OUT_DIR=outputs/pipeline/refine
EXAMINE_OUT_DIR=outputs/pipeline/examine

MERGED_LONG_JSON=outputs/pipeline/merged_long.json
MERGED_REFINE_JSON=outputs/pipeline/merged_refine.json
MERGED_EXAMINE_JSON=outputs/pipeline/merged_examine.json

mkdir -p "${LONG_OUT_DIR}" "${REFINE_OUT_DIR}" "${EXAMINE_OUT_DIR}"

# --- Stage 1: first-pass long descriptions ---------------------------------
CUDA_VISIBLE_DEVICES=${LONG_GPU_IDS} OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=${LONG_NPROC} --master_port=25011 \
    tools/annotate_hico.py \
    --model-path "${LONG_MODEL_PATH}" \
    --data-path "${DATA_PATH}" \
    --output-dir "${LONG_OUT_DIR}"

python3 tools/merge_json_outputs.py \
    --input-dir "${LONG_OUT_DIR}" \
    --pattern "labels_*.json" \
    --output-path "${MERGED_LONG_JSON}"

# --- Stage 2: refinement ---------------------------------------------------
CUDA_VISIBLE_DEVICES=${REFINE_GPU_IDS} OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=${REFINE_NPROC} --master_port=25012 \
    tools/refine_hico.py \
    --model-path "${REFINE_MODEL_PATH}" \
    --data-path "${DATA_PATH}" \
    --annotation-path "${MERGED_LONG_JSON}" \
    --output-dir "${REFINE_OUT_DIR}"

python3 tools/merge_json_outputs.py \
    --input-dir "${REFINE_OUT_DIR}" \
    --pattern "refine_labels_*.json" \
    --output-path "${MERGED_REFINE_JSON}"

# --- Stage 3: examiner check ------------------------------------------------
CUDA_VISIBLE_DEVICES=${EXAMINE_GPU_IDS} OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=${EXAMINE_NPROC} --master_port=25013 \
    tools/examine_hico.py \
    --model-path "${EXAMINE_MODEL_PATH}" \
    --data-path "${DATA_PATH}" \
    --annotation-path "${MERGED_REFINE_JSON}" \
    --output-dir "${EXAMINE_OUT_DIR}"

python3 tools/merge_json_outputs.py \
    --input-dir "${EXAMINE_OUT_DIR}" \
    --pattern "examiner_labels_*.json" \
    --output-path "${MERGED_EXAMINE_JSON}"

echo "Pipeline complete."
echo "Long descriptions: ${MERGED_LONG_JSON}"
echo "Refined descriptions: ${MERGED_REFINE_JSON}"
echo "Examiner results: ${MERGED_EXAMINE_JSON}"
scripts/refine_hico.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run the HICO description refinement stage on a single GPU.
IDX=1
export PYTHONPATH=$PYTHONPATH:./

data_path=../datasets/HICO-Det
model_path=./model_weights/qwen3_8b_vl_instruct
annotation_path=./outputs/merged_labels.json
output_dir=outputs/refine

# Quote expansions so paths with spaces don't word-split.
if [ -d "${output_dir}" ]; then
    echo "dir already exists"
else
    mkdir -p "${output_dir}"
fi

# Dangling trailing backslash after the last argument removed; it made the
# command silently consume whatever line followed it.
CUDA_VISIBLE_DEVICES=$IDX OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=1 --master_port=25008 \
    tools/refine_hico.py \
    --model-path "${model_path}" \
    --data-path "${data_path}" \
    --annotation-path "${annotation_path}" \
    --output-dir "${output_dir}"
tools/annotate_hico.py CHANGED
@@ -15,8 +15,7 @@ from data.convsersation import Conversation
15
  import re
16
  from dataclasses import dataclass
17
 
18
- from transformers import Qwen3VLForConditionalGeneration
19
- from transformers import AutoTokenizer, AutoConfig, AutoProcessor
20
 
21
  def disable_torch_init():
22
  """
@@ -29,6 +28,30 @@ import os, json
29
  import torch
30
  import torch.distributed as dist
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def gather_labels_and_save(labels, output_path):
33
  # Make sure dist is initialized (torchrun / deepspeed / accelerate usually does this)
34
  world_size = dist.get_world_size()
@@ -85,11 +108,11 @@ class DataCollatorForSupervisedDataset(object):
85
  tokenize=False,
86
  add_generation_prompt=True)
87
  for m in messages]
88
- batch_tensors = self.processor(
89
- text=prompts,
 
90
  images=batch_images,
91
- return_tensors="pt",
92
- padding=True
93
  )
94
  return batch_tensors, result_meta
95
 
@@ -104,48 +127,42 @@ def worker(model, processor, dataset, args, output_dir):
104
  sub_dataset = torch.utils.data.Subset(dataset, indices)
105
  batch_size = 1
106
  data_loader = DataLoader(sub_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path))
107
- labels = []
108
-
109
- for batch_tensors, result_meta in tqdm(data_loader):
110
-
111
- input_ids = batch_tensors['input_ids'].cuda()
112
- batch_tensors = {k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)}
113
- with torch.inference_mode():
114
- output_dict = model.generate(do_sample=False,
115
- output_scores=True,
116
- return_dict_in_generate=True,
117
- max_new_tokens=1600,
118
- output_logits=True,
119
- **batch_tensors,)
120
-
121
- output_ids = output_dict['sequences']
122
-
123
- for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
124
- input_token_len = input_id.shape[0]
125
- n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
126
- if n_diff_input_output > 0:
127
- print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
128
- output = processor.tokenizer.batch_decode(output_id[input_token_len:].unsqueeze(0), skip_special_tokens=True)[0]
129
- labels.append({
130
- 'file_name': meta['file_name'],
131
- 'image_id': meta['image_id'],
132
- 'instance_id': meta['instance_id'],
133
- 'keypoints': meta['joints_3d'].reshape(-1).tolist(),
134
- 'vis': meta['joints_3d_vis'].reshape(-1).tolist(),
135
- 'im_height': meta['hoi_obj']['height'],
136
- 'im_width': meta['hoi_obj']['width'],
137
- 'hoi_id': meta['hoi_obj']['hoi_id'],
138
- 'human_bbox': meta['hoi_obj']['human_bbox'],
139
- 'object_bbox': meta['hoi_obj']['object_bbox'],
140
- 'action_labels': meta['hoi_obj']['action_labels'],
141
- 'description': output,
142
- })
143
-
144
-
145
- local_rank = int(os.environ.get("LOCAL_RANK", "0"))
146
- output_path = os.path.join(args.output_dir, f'labels_{local_rank}.json')
147
- with open(output_path, "w", encoding="utf-8") as f:
148
- json.dump(labels, f, ensure_ascii=False, indent=2)
149
 
150
  def eval_model(args):
151
  torch.distributed.init_process_group(backend='nccl')
@@ -156,25 +173,22 @@ def eval_model(args):
156
  torch.cuda.set_device(rank)
157
 
158
  disable_torch_init()
159
- model = Qwen3VLForConditionalGeneration.from_pretrained(
160
- args.model_path,
161
- torch_dtype=torch.bfloat16,
162
- trust_remote_code=True
 
163
  )
 
164
  model = model.cuda()
165
  model.eval()
166
-
167
- processor = AutoProcessor.from_pretrained(
168
- args.model_path,
169
- trust_remote_code=True)
170
- processor.tokenizer.padding_side = "left"
171
- processor.tokenizer.pad_token = processor.tokenizer.eos_token
172
 
173
  dataset = PoseHICODetDataset(
174
  data_path=args.data_path,
175
  multimodal_cfg=dict(image_folder=os.path.join(args.data_path, 'Images/images/train2015'),
176
  data_augmentation=False,
177
- image_size=336,),)
 
178
  worker(model, processor, dataset, args, args.output_dir)
179
 
180
  if __name__ == "__main__":
@@ -182,7 +196,10 @@ if __name__ == "__main__":
182
  parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
183
  parser.add_argument("--data-path", type=str, default="")
184
  parser.add_argument("--output-dir", type=str, default="")
 
 
 
185
  args = parser.parse_args()
186
 
187
  eval_model(args)
188
-
 
15
  import re
16
  from dataclasses import dataclass
17
 
18
+ from tools.vlm_backend import build_batch_tensors, decode_generated_text, load_model_and_processor
 
19
 
20
  def disable_torch_init():
21
  """
 
28
  import torch
29
  import torch.distributed as dist
30
 
31
class StreamingJsonArrayWriter:
    """Context manager that streams items to disk as one JSON array.

    Every item is flushed immediately after it is written, so a run that
    dies part-way still leaves readable (and, once closed, valid) output.
    """

    def __init__(self, output_path):
        self.output_path = output_path
        self.file = None
        self.is_first = True

    def __enter__(self):
        self.file = open(self.output_path, "w", encoding="utf-8")
        self.file.write("[\n")
        self.file.flush()
        return self

    def write(self, item):
        # A comma separates this item from the previous one; the very
        # first item is written directly after the opening bracket.
        if not self.is_first:
            self.file.write(",\n")
        json.dump(item, self.file, ensure_ascii=False, indent=2)
        self.file.flush()
        self.is_first = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.file is not None:
            # Close the array even on error so the file stays parseable.
            self.file.write("\n]\n")
            self.file.close()
54
+
55
  def gather_labels_and_save(labels, output_path):
56
  # Make sure dist is initialized (torchrun / deepspeed / accelerate usually does this)
57
  world_size = dist.get_world_size()
 
108
  tokenize=False,
109
  add_generation_prompt=True)
110
  for m in messages]
111
+ batch_tensors = build_batch_tensors(
112
+ processor=self.processor,
113
+ prompts=batch_prompts,
114
  images=batch_images,
115
+ system_prompt=self.conv.system,
 
116
  )
117
  return batch_tensors, result_meta
118
 
 
127
  sub_dataset = torch.utils.data.Subset(dataset, indices)
128
  batch_size = 1
129
  data_loader = DataLoader(sub_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path))
130
+ output_path = os.path.join(args.output_dir, f'labels_{rank}.json')
131
+
132
+ with StreamingJsonArrayWriter(output_path) as writer:
133
+ for batch_tensors, result_meta in tqdm(data_loader):
134
+ input_ids = batch_tensors['input_ids'].cuda()
135
+ batch_tensors = {k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)}
136
+ with torch.inference_mode():
137
+ output_dict = model.generate(do_sample=False,
138
+ output_scores=True,
139
+ return_dict_in_generate=True,
140
+ max_new_tokens=1600,
141
+ output_logits=True,
142
+ **batch_tensors,)
143
+
144
+ output_ids = output_dict['sequences']
145
+
146
+ for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
147
+ input_token_len = input_id.shape[0]
148
+ n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
149
+ if n_diff_input_output > 0:
150
+ print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
151
+ output = decode_generated_text(processor, output_id, input_id)
152
+ writer.write({
153
+ 'file_name': meta['file_name'],
154
+ 'image_id': meta['image_id'],
155
+ 'instance_id': meta['instance_id'],
156
+ 'keypoints': meta['joints_3d'].reshape(-1).tolist(),
157
+ 'vis': meta['joints_3d_vis'].reshape(-1).tolist(),
158
+ 'im_height': meta['hoi_obj']['height'],
159
+ 'im_width': meta['hoi_obj']['width'],
160
+ 'hoi_id': meta['hoi_obj']['hoi_id'],
161
+ 'human_bbox': meta['hoi_obj']['human_bbox'],
162
+ 'object_bbox': meta['hoi_obj']['object_bbox'],
163
+ 'action_labels': meta['hoi_obj']['action_labels'],
164
+ 'description': output,
165
+ })
 
 
 
 
 
 
166
 
167
  def eval_model(args):
168
  torch.distributed.init_process_group(backend='nccl')
 
173
  torch.cuda.set_device(rank)
174
 
175
  disable_torch_init()
176
+ backend_name, model, processor = load_model_and_processor(
177
+ model_path=args.model_path,
178
+ backend=args.model_backend,
179
+ torch_dtype=args.torch_dtype,
180
+ trust_remote_code=True,
181
  )
182
+ print(f'Using model backend: {backend_name}')
183
  model = model.cuda()
184
  model.eval()
 
 
 
 
 
 
185
 
186
  dataset = PoseHICODetDataset(
187
  data_path=args.data_path,
188
  multimodal_cfg=dict(image_folder=os.path.join(args.data_path, 'Images/images/train2015'),
189
  data_augmentation=False,
190
+ image_size=336,),
191
+ max_samples=args.max_samples,)
192
  worker(model, processor, dataset, args, args.output_dir)
193
 
194
  if __name__ == "__main__":
 
196
  parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
197
  parser.add_argument("--data-path", type=str, default="")
198
  parser.add_argument("--output-dir", type=str, default="")
199
+ parser.add_argument("--max-samples", type=int, default=0)
200
+ parser.add_argument("--model-backend", type=str, default="auto")
201
+ parser.add_argument("--torch-dtype", type=str, default="bfloat16")
202
  args = parser.parse_args()
203
 
204
  eval_model(args)
205
+
tools/clean_initial_annotation.py CHANGED
@@ -15,8 +15,7 @@ from data.convsersation import Conversation_For_Action_Pharse as Conversation
15
  import re
16
  from dataclasses import dataclass
17
 
18
- from transformers import Qwen3VLForConditionalGeneration
19
- from transformers import AutoTokenizer, AutoConfig, AutoProcessor
20
 
21
  def disable_torch_init():
22
  """
@@ -28,6 +27,30 @@ def disable_torch_init():
28
  import os, json
29
  import torch
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  @dataclass
32
  class DataCollatorForSupervisedDataset(object):
33
  def __init__(self, processor, data_path):
@@ -62,15 +85,11 @@ class DataCollatorForSupervisedDataset(object):
62
  "text": prompt},]},
63
  ])
64
 
65
- prompts = [self.processor.apply_chat_template(m,
66
- tokenize=False,
67
- add_generation_prompt=True)
68
- for m in messages]
69
- batch_tensors = self.processor(
70
- text=prompts,
71
  images=batch_images,
72
- return_tensors="pt",
73
- padding=True
74
  )
75
  return batch_tensors, result_meta
76
 
@@ -85,39 +104,31 @@ def worker(model, processor, dataset, args, output_dir):
85
  sub_dataset = torch.utils.data.Subset(dataset, indices)
86
  batch_size = 16
87
  data_loader = DataLoader(sub_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path))
88
- labels = []
89
 
90
- for batch_tensors, result_meta in tqdm(data_loader):
91
-
92
- input_ids = batch_tensors['input_ids'].cuda()
93
- batch_tensors = {k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)}
94
- with torch.inference_mode():
95
- output_dict = model.generate(do_sample=False,
96
- output_scores=True,
97
- return_dict_in_generate=True,
98
- max_new_tokens=1600,
99
- output_logits=True,
100
- **batch_tensors,)
101
 
102
- output_ids = output_dict['sequences']
103
-
104
- for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
105
- input_token_len = input_id.shape[0]
106
- n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
107
- if n_diff_input_output > 0:
108
- print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
109
- #input_text = processor.tokenizer.batch_decode(output_id[:input_token_len].unsqueeze(0), skip_special_tokens=True)[0]
110
- output = processor.tokenizer.batch_decode(output_id[input_token_len:].unsqueeze(0), skip_special_tokens=True)[0]
111
- # print(output)
112
- # import pdb;pdb.set_trace()
113
- meta['action_description'] = output
114
- #import pdb;pdb.set_trace()
115
- labels.append(meta)
116
-
117
- local_rank = int(os.environ.get("LOCAL_RANK", "0"))
118
- output_path = os.path.join(args.output_dir, f'labels_{local_rank}.json')
119
- with open(output_path, "w", encoding="utf-8") as f:
120
- json.dump(labels, f, ensure_ascii=False, indent=2)
 
121
 
122
  def eval_model(args):
123
  torch.distributed.init_process_group(backend='nccl')
@@ -128,19 +139,15 @@ def eval_model(args):
128
  torch.cuda.set_device(rank)
129
 
130
  disable_torch_init()
131
- model = Qwen3VLForConditionalGeneration.from_pretrained(
132
- args.model_path,
133
- torch_dtype=torch.bfloat16,
134
- trust_remote_code=True
 
135
  )
 
136
  model = model.cuda()
137
  model.eval()
138
-
139
- processor = AutoProcessor.from_pretrained(
140
- args.model_path,
141
- trust_remote_code=True)
142
- processor.tokenizer.padding_side = "left"
143
- processor.tokenizer.pad_token = processor.tokenizer.eos_token
144
 
145
  dataset = PoseHICODetDataset(
146
  data_path=args.data_path,
@@ -154,7 +161,9 @@ if __name__ == "__main__":
154
  parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
155
  parser.add_argument("--data-path", type=str, default="")
156
  parser.add_argument("--output-dir", type=str, default="")
 
 
157
  args = parser.parse_args()
158
 
159
  eval_model(args)
160
-
 
15
  import re
16
  from dataclasses import dataclass
17
 
18
+ from tools.vlm_backend import build_batch_tensors, decode_generated_text, load_model_and_processor
 
19
 
20
  def disable_torch_init():
21
  """
 
27
  import os, json
28
  import torch
29
 
30
class StreamingJsonArrayWriter:
    """Incremental JSON-array writer used as a context manager.

    Output is flushed item by item so partial results survive a crash;
    ``__exit__`` appends the closing bracket to keep the file parseable.
    """

    def __init__(self, output_path):
        self.output_path = output_path
        self.file = None
        self.is_first = True

    def __enter__(self):
        self.file = open(self.output_path, "w", encoding="utf-8")
        self.file.write("[\n")
        self.file.flush()
        return self

    def write(self, item):
        if self.is_first:
            self.is_first = False
        else:
            # Every item after the first is preceded by a comma.
            self.file.write(",\n")
        json.dump(item, self.file, ensure_ascii=False, indent=2)
        self.file.flush()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.file is None:
            return
        self.file.write("\n]\n")
        self.file.close()
53
+
54
  @dataclass
55
  class DataCollatorForSupervisedDataset(object):
56
  def __init__(self, processor, data_path):
 
85
  "text": prompt},]},
86
  ])
87
 
88
+ batch_tensors = build_batch_tensors(
89
+ processor=self.processor,
90
+ prompts=batch_prompts,
 
 
 
91
  images=batch_images,
92
+ system_prompt=self.conv.system,
 
93
  )
94
  return batch_tensors, result_meta
95
 
 
104
  sub_dataset = torch.utils.data.Subset(dataset, indices)
105
  batch_size = 16
106
  data_loader = DataLoader(sub_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path))
107
+ output_path = os.path.join(args.output_dir, f'labels_{rank}.json')
108
 
109
+ with StreamingJsonArrayWriter(output_path) as writer:
110
+ for batch_tensors, result_meta in tqdm(data_loader):
 
 
 
 
 
 
 
 
 
111
 
112
+ input_ids = batch_tensors['input_ids'].cuda()
113
+ batch_tensors = {k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)}
114
+ with torch.inference_mode():
115
+ output_dict = model.generate(do_sample=False,
116
+ output_scores=True,
117
+ return_dict_in_generate=True,
118
+ max_new_tokens=1600,
119
+ output_logits=True,
120
+ **batch_tensors,)
121
+
122
+ output_ids = output_dict['sequences']
123
+
124
+ for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
125
+ input_token_len = input_id.shape[0]
126
+ n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
127
+ if n_diff_input_output > 0:
128
+ print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
129
+ output = decode_generated_text(processor, output_id, input_id)
130
+ meta['action_description'] = output
131
+ writer.write(meta)
132
 
133
  def eval_model(args):
134
  torch.distributed.init_process_group(backend='nccl')
 
139
  torch.cuda.set_device(rank)
140
 
141
  disable_torch_init()
142
+ backend_name, model, processor = load_model_and_processor(
143
+ model_path=args.model_path,
144
+ backend=args.model_backend,
145
+ torch_dtype=args.torch_dtype,
146
+ trust_remote_code=True,
147
  )
148
+ print(f'Using model backend: {backend_name}')
149
  model = model.cuda()
150
  model.eval()
 
 
 
 
 
 
151
 
152
  dataset = PoseHICODetDataset(
153
  data_path=args.data_path,
 
161
  parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
162
  parser.add_argument("--data-path", type=str, default="")
163
  parser.add_argument("--output-dir", type=str, default="")
164
+ parser.add_argument("--model-backend", type=str, default="auto")
165
+ parser.add_argument("--torch-dtype", type=str, default="bfloat16")
166
  args = parser.parse_args()
167
 
168
  eval_model(args)
169
+
tools/examine_hico.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+ import re
5
+ from tqdm import tqdm
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.utils.data import DataLoader
10
+
11
+ from data.dataset_for_clean_descrip import PoseHICODetDataset
12
+ from data.convsersation import Conversation_examiner as Conversation
13
+
14
+ from dataclasses import dataclass
15
+
16
+ from tools.vlm_backend import build_batch_tensors, decode_generated_text, load_model_and_processor
17
+
18
class StreamingJsonArrayWriter:
    """Context manager that emits a JSON array to disk one element at a time.

    Every element is flushed as soon as it is written, so an interrupted
    run still leaves a readable prefix of the output on disk.
    """

    def __init__(self, output_path):
        self.output_path = output_path
        self.file = None
        self.is_first = True

    def __enter__(self):
        # Open the target and emit the opening bracket immediately.
        self.file = open(self.output_path, "w", encoding="utf-8")
        self.file.write("[\n")
        self.file.flush()
        return self

    def write(self, item):
        """Append one JSON-serializable item to the array."""
        separator = "" if self.is_first else ",\n"
        self.file.write(separator)
        json.dump(item, self.file, ensure_ascii=False, indent=2)
        self.file.flush()
        self.is_first = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the array even when the body raised, keeping the file valid JSON.
        if self.file is not None:
            self.file.write("\n]\n")
            self.file.close()
41
+
42
+
43
def disable_torch_init():
    """Skip torch's default weight initialisation to speed up model creation.

    Pretrained checkpoints overwrite the weights right after construction,
    so the random init done by ``reset_parameters`` is wasted work.
    """
    torch.nn.Linear.reset_parameters = lambda self: None
    torch.nn.LayerNorm.reset_parameters = lambda self: None
49
+
50
+
51
def extract_checked_description(text):
    """Pull the revised description out of an examiner response.

    The examiner is prompted to answer in the form::

        Checked description: <description> Issues: <notes>

    Returns the text between the two markers, stripped. If the trailing
    ``Issues:`` marker is missing, falls back to everything after
    ``Checked description:`` (backward-compatible robustness); returns an
    empty string when the leading marker is absent entirely.
    """
    match = re.search(
        r"Checked description:\s*(.*?)\s*Issues:\s*",
        text,
        flags=re.DOTALL,
    )
    if match is None:
        # Robustness: accept responses that omit the "Issues:" section.
        match = re.search(r"Checked description:\s*(.*)$", text, flags=re.DOTALL)
    if match:
        return match.group(1).strip()
    return ""
60
+
61
+
62
class DataCollatorForSupervisedDataset(object):
    """Collate dataset samples into model-ready tensors plus their metadata.

    Fix: the former ``@dataclass`` decorator was a no-op here — the class
    defines ``__init__`` explicitly and declares no annotated fields — so it
    has been removed to avoid suggesting generated dataclass behavior.
    """

    def __init__(self, processor, data_path):
        self.processor = processor
        # Conversation renders the per-sample prompt from dataset metadata.
        self.conv = Conversation(
            system='',
            data_path=data_path
        )

    def __call__(self, data_dicts):
        """Turn a list of ``{'image', 'meta'}`` dicts into ``(batch_tensors, metas)``."""
        batch_prompts = []
        batch_images = []
        result_meta = []

        for data_dict in data_dicts:
            batch_images.append(data_dict['image'])
            batch_prompts.append(self.conv.get_prompt(data_dict['meta']))
            result_meta.append(data_dict['meta'])

        batch_tensors = build_batch_tensors(
            processor=self.processor,
            prompts=batch_prompts,
            images=batch_images,
            system_prompt=self.conv.system,
        )
        return batch_tensors, result_meta
88
+
89
+
90
@torch.no_grad()
def worker(model, processor, dataset, args):
    """Run examiner inference on this rank's shard and stream results to JSON.

    Each rank handles every ``world_size``-th sample and writes its own
    ``examiner_labels_<rank>.json``; shards are merged in a later step.
    """
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Round-robin shard assignment across ranks.
    indices = list(range(rank, len(dataset), world_size))
    print("==>" + " Worker {} Started, responsible for {} images".format(rank, len(indices)))

    sub_dataset = torch.utils.data.Subset(dataset, indices)
    batch_size = args.batch_size
    data_loader = DataLoader(
        sub_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path)
    )
    output_path = os.path.join(args.output_dir, f'examiner_labels_{rank}.json')

    with StreamingJsonArrayWriter(output_path) as writer:
        for batch_tensors, result_meta in tqdm(data_loader):
            # Keep the prompt ids so the generated suffix can be split off later.
            input_ids = batch_tensors['input_ids'].cuda()
            batch_tensors = {k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)}
            with torch.inference_mode():
                output_dict = model.generate(
                    do_sample=False,
                    output_scores=True,
                    return_dict_in_generate=True,
                    max_new_tokens=args.max_new_tokens,
                    output_logits=True,
                    **batch_tensors,
                )
            output_ids = output_dict['sequences']

            for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
                input_token_len = input_id.shape[0]
                # Sanity check: generated sequences should begin with the prompt tokens.
                n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
                if n_diff_input_output > 0:
                    print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
                output = decode_generated_text(processor, output_id, input_id)
                meta['examiner_result'] = output
                # Parse the "Checked description: ... Issues: ..." template.
                meta['final_description'] = extract_checked_description(output)
                writer.write(meta)
132
+
133
+
134
def eval_model(args):
    """Entry point: set up distributed state, load the VLM, and run ``worker``."""
    dist.init_process_group(backend='nccl')
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print('Init process group: world_size: {}, rank: {}'.format(world_size, rank))
    # One GPU per local rank.
    torch.cuda.set_device(rank)

    disable_torch_init()
    backend_name, model, processor = load_model_and_processor(
        model_path=args.model_path,
        backend=args.model_backend,
        torch_dtype=args.torch_dtype,
        trust_remote_code=True,
    )
    print(f'Using model backend: {backend_name}')
    model = model.cuda()
    model.eval()

    dataset = PoseHICODetDataset(
        data_path=args.data_path,
        multimodal_cfg=dict(
            # NOTE(review): HICO-DET train split layout assumed here — confirm.
            image_folder=os.path.join(args.data_path, 'Images/images/train2015'),
            data_augmentation=False,
            image_size=336,
        ),
        annotation_path=args.annotation_path,
        max_samples=args.max_samples,
    )
    worker(model, processor, dataset, args)
164
+
165
+
166
+ if __name__ == "__main__":
167
+ parser = argparse.ArgumentParser()
168
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
169
+ parser.add_argument("--data-path", type=str, default="")
170
+ parser.add_argument("--annotation-path", type=str, default="./outputs/merged_labels.json")
171
+ parser.add_argument("--output-dir", type=str, default="")
172
+ parser.add_argument("--batch-size", type=int, default=8)
173
+ parser.add_argument("--max-new-tokens", type=int, default=512)
174
+ parser.add_argument("--max-samples", type=int, default=0)
175
+ parser.add_argument("--model-backend", type=str, default="auto")
176
+ parser.add_argument("--torch-dtype", type=str, default="bfloat16")
177
+ args = parser.parse_args()
178
+
179
+ eval_model(args)
tools/merge_json_outputs.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import glob
4
+ import argparse
5
+
6
+
7
def main():
    """Concatenate per-rank JSON array shards into a single JSON list file."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dir", type=str, required=True)
    parser.add_argument("--pattern", type=str, required=True)
    parser.add_argument("--output-path", type=str, required=True)
    args = parser.parse_args()

    # Sort matched shard paths for deterministic merge order.
    input_pattern = os.path.join(args.input_dir, args.pattern)
    input_paths = sorted(glob.glob(input_pattern))
    if not input_paths:
        raise FileNotFoundError(f"No files matched pattern: {input_pattern}")

    merged = []
    for path in input_paths:
        with open(path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        # Each shard must itself be a JSON array.
        if not isinstance(data, list):
            raise ValueError(f"{path} is not a JSON list, got {type(data)}")
        merged += data

    # Create the destination directory lazily (dirname is empty for cwd paths).
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(args.output_path, "w", encoding="utf-8") as handle:
        json.dump(merged, handle, ensure_ascii=False, indent=2)

    print(f"Merged {len(input_paths)} files into {args.output_path}")
    print(f"Total items: {len(merged)}")


if __name__ == "__main__":
    main()
tools/refine_hico.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+ from tqdm import tqdm
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ from torch.utils.data import DataLoader
9
+
10
+ from data.dataset_for_clean_descrip import PoseHICODetDataset
11
+ from data.convsersation import Conversation_For_Clean_Descrption as Conversation
12
+
13
+ from dataclasses import dataclass
14
+
15
+ from tools.vlm_backend import build_batch_tensors, decode_generated_text, load_model_and_processor
16
+
17
class StreamingJsonArrayWriter:
    """Write a JSON array element-by-element, flushing after each item."""

    def __init__(self, output_path):
        self.output_path = output_path
        self.file = None
        self.is_first = True

    def __enter__(self):
        # Emit the opening bracket up front so the file is always a valid JSON prefix.
        self.file = open(self.output_path, "w", encoding="utf-8")
        self.file.write("[\n")
        self.file.flush()
        return self

    def write(self, item):
        """Serialize one item into the array."""
        if self.is_first:
            self.is_first = False
        else:
            self.file.write(",\n")
        json.dump(item, self.file, ensure_ascii=False, indent=2)
        self.file.flush()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always terminate the array, even if the body raised.
        if self.file is not None:
            self.file.write("\n]\n")
            self.file.close()
40
+
41
+
42
def disable_torch_init():
    """Make torch layer initialisation a no-op to speed up model creation.

    Safe because pretrained weights are loaded immediately afterwards.
    """
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
48
+
49
+
50
class DataCollatorForSupervisedDataset(object):
    """Collate dataset samples into model-ready tensors plus their metadata.

    Fix: the former ``@dataclass`` decorator was a no-op — the class defines
    ``__init__`` explicitly and declares no annotated fields — so it has been
    removed to avoid implying generated dataclass behavior.
    """

    def __init__(self, processor, data_path):
        self.processor = processor
        # Conversation renders the per-sample prompt from dataset metadata.
        self.conv = Conversation(
            system='',
            data_path=data_path
        )

    def __call__(self, data_dicts):
        """Turn a list of ``{'image', 'meta'}`` dicts into ``(batch_tensors, metas)``."""
        batch_prompts = []
        batch_images = []
        result_meta = []

        for data_dict in data_dicts:
            batch_images.append(data_dict['image'])
            batch_prompts.append(self.conv.get_prompt(data_dict['meta']))
            result_meta.append(data_dict['meta'])

        batch_tensors = build_batch_tensors(
            processor=self.processor,
            prompts=batch_prompts,
            images=batch_images,
            system_prompt=self.conv.system,
        )
        return batch_tensors, result_meta
76
+
77
+
78
@torch.no_grad()
def worker(model, processor, dataset, args):
    """Generate refined descriptions for this rank's shard and stream them to JSON.

    Each rank handles every ``world_size``-th sample and writes its own
    ``refine_labels_<rank>.json``; shards are merged in a later step.
    """
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Round-robin shard assignment across ranks.
    indices = list(range(rank, len(dataset), world_size))
    print("==>" + " Worker {} Started, responsible for {} images".format(rank, len(indices)))

    sub_dataset = torch.utils.data.Subset(dataset, indices)
    data_loader = DataLoader(
        sub_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path)
    )
    output_path = os.path.join(args.output_dir, f'refine_labels_{rank}.json')

    with StreamingJsonArrayWriter(output_path) as writer:
        for batch_tensors, result_meta in tqdm(data_loader):
            # Keep the prompt ids so the generated suffix can be split off later.
            input_ids = batch_tensors['input_ids'].cuda()
            batch_tensors = {k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)}
            with torch.inference_mode():
                output_dict = model.generate(
                    do_sample=False,
                    output_scores=True,
                    return_dict_in_generate=True,
                    max_new_tokens=args.max_new_tokens,
                    output_logits=True,
                    **batch_tensors,
                )
            output_ids = output_dict['sequences']

            for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
                input_token_len = input_id.shape[0]
                # Sanity check: generated sequences should begin with the prompt tokens.
                n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
                if n_diff_input_output > 0:
                    print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
                output = decode_generated_text(processor, output_id, input_id)
                meta['refined_description'] = output
                writer.write(meta)
118
+
119
+
120
def eval_model(args):
    """Entry point: set up distributed state, load the VLM, and run ``worker``."""
    dist.init_process_group(backend='nccl')
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print('Init process group: world_size: {}, rank: {}'.format(world_size, rank))
    # One GPU per local rank.
    torch.cuda.set_device(rank)

    disable_torch_init()
    backend_name, model, processor = load_model_and_processor(
        model_path=args.model_path,
        backend=args.model_backend,
        torch_dtype=args.torch_dtype,
        trust_remote_code=True,
    )
    print(f'Using model backend: {backend_name}')
    model = model.cuda()
    model.eval()

    dataset = PoseHICODetDataset(
        data_path=args.data_path,
        multimodal_cfg=dict(
            # NOTE(review): HICO-DET train split layout assumed here — confirm.
            image_folder=os.path.join(args.data_path, 'Images/images/train2015'),
            data_augmentation=False,
            image_size=336,
        ),
        annotation_path=args.annotation_path,
        max_samples=args.max_samples,
    )
    worker(model, processor, dataset, args)
150
+
151
+
152
+ if __name__ == "__main__":
153
+ parser = argparse.ArgumentParser()
154
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
155
+ parser.add_argument("--data-path", type=str, default="")
156
+ parser.add_argument("--annotation-path", type=str, default="./outputs/merged_labels.json")
157
+ parser.add_argument("--output-dir", type=str, default="")
158
+ parser.add_argument("--batch-size", type=int, default=8)
159
+ parser.add_argument("--max-new-tokens", type=int, default=512)
160
+ parser.add_argument("--max-samples", type=int, default=0)
161
+ parser.add_argument("--model-backend", type=str, default="auto")
162
+ parser.add_argument("--torch-dtype", type=str, default="bfloat16")
163
+ args = parser.parse_args()
164
+
165
+ eval_model(args)
tools/vlm_backend.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+
4
def _get_transformers():
    """Import transformers lazily so this module can be imported without it installed."""
    import transformers
    return transformers
7
+
8
+
9
def resolve_torch_dtype(dtype_name):
    """Map a dtype name (e.g. "bfloat16") to its ``torch.dtype``, or pass "auto" through.

    Fix: the previous ``hasattr(torch, dtype_name)`` check accepted any torch
    attribute, so e.g. "nn" would return the ``torch.nn`` module instead of
    raising. Now the resolved value must actually be a ``torch.dtype``.

    Raises:
        ValueError: if ``dtype_name`` does not name a ``torch.dtype``.
    """
    import torch

    if dtype_name == "auto":
        return "auto"
    dtype = getattr(torch, dtype_name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported torch dtype: {dtype_name}")
    return dtype
17
+
18
+
19
def infer_model_backend(model_path, backend="auto", trust_remote_code=True):
    """Pick a backend name for ``model_path``.

    An explicitly requested ``backend`` is returned untouched; "auto"
    inspects the checkpoint config (architectures + model_type) and guesses
    the model family.
    """
    if backend != "auto":
        return backend

    transformers = _get_transformers()
    config = transformers.AutoConfig.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code
    )
    arch_names = getattr(config, "architectures", None) or []
    arch_blob = " ".join(name.lower() for name in arch_names)
    type_name = str(getattr(config, "model_type", "")).lower()

    # Most specific families first; fall through to the generic HF class.
    if "qwen3vlmoe" in arch_blob or ("qwen" in type_name and "moe" in arch_blob):
        return "qwen3_vl_moe"
    if "qwen3vl" in arch_blob or ("qwen" in type_name and "vl" in type_name):
        return "qwen3_vl"
    if "llava" in arch_blob or "llava" in type_name:
        return "llava"
    if any(tag in arch_blob or tag in type_name for tag in ("deepseek", "janus")):
        return "deepseek_vl"
    return "hf_vision2seq"
41
+
42
+
43
def load_model_and_processor(
    model_path,
    backend="auto",
    torch_dtype="bfloat16",
    trust_remote_code=True,
):
    """Load a VLM checkpoint plus its processor, choosing the model class by backend.

    Returns:
        ``(backend_name, model, processor)`` — ``backend_name`` is the resolved
        backend string (useful when the caller passed "auto").

    Raises:
        ValueError: for an unknown backend name or unsupported torch dtype.
    """
    transformers = _get_transformers()
    backend = infer_model_backend(
        model_path=model_path,
        backend=backend,
        trust_remote_code=trust_remote_code,
    )
    dtype = resolve_torch_dtype(torch_dtype)

    if backend == "qwen3_vl":
        model_cls = transformers.Qwen3VLForConditionalGeneration
    elif backend == "qwen3_vl_moe":
        model_cls = transformers.Qwen3VLMoeForConditionalGeneration
    elif backend == "llava":
        # Older transformers releases may not ship a dedicated Llava class.
        model_cls = getattr(transformers, "LlavaForConditionalGeneration", None)
        if model_cls is None:
            model_cls = transformers.AutoModelForVision2Seq
    elif backend == "deepseek_vl":
        # DeepSeek multimodal checkpoints often rely on trust_remote_code and may expose
        # custom causal-LM style classes instead of Vision2Seq classes.
        model_cls = transformers.AutoModelForCausalLM
    elif backend == "hf_vision2seq":
        model_cls = transformers.AutoModelForVision2Seq
    elif backend == "hf_causal_vlm":
        model_cls = transformers.AutoModelForCausalLM
    else:
        raise ValueError(f"Unsupported model backend: {backend}")

    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=trust_remote_code,
    )
    processor = transformers.AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
    )
    # Normalise tokenizer padding so batched generation works out of the box.
    _configure_processor(processor)
    return backend, model, processor
87
+
88
+
89
+ def _configure_processor(processor):
90
+ tokenizer = getattr(processor, "tokenizer", None)
91
+ if tokenizer is None:
92
+ return
93
+ if getattr(tokenizer, "padding_side", None) is not None:
94
+ tokenizer.padding_side = "left"
95
+ if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None) is not None:
96
+ tokenizer.pad_token = tokenizer.eos_token
97
+
98
+
99
def build_batch_tensors(processor, prompts: List[str], images, system_prompt=""):
    """Render chat-formatted prompts and tokenize them together with images.

    Builds one system+user message pair per prompt, renders it through the
    processor's chat template (or the tokenizer's, as a fallback), and returns
    the processor's batched tensor dict.

    Args:
        processor: a HF processor (or tokenizer-like object); duck-typed.
        prompts: one user prompt string per sample.
        images: per-sample image inputs, passed straight to the processor.
        system_prompt: shared system message text for every sample.
    """
    messages = []
    for prompt in prompts:
        messages.append([
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": system_prompt},
                ],
            },
            {
                "role": "user",
                "content": [
                    # Placeholder; the actual image tensor is supplied below.
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            },
        ])

    rendered_prompts = []
    if hasattr(processor, "apply_chat_template"):
        rendered_prompts = [
            processor.apply_chat_template(
                message,
                tokenize=False,
                add_generation_prompt=True,
            )
            for message in messages
        ]
    else:
        # Fallback chain: tokenizer's chat template, then the raw prompt strings.
        tokenizer = getattr(processor, "tokenizer", None)
        if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
            rendered_prompts = [
                tokenizer.apply_chat_template(
                    message,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                for message in messages
            ]
        else:
            rendered_prompts = prompts

    try:
        return processor(
            text=rendered_prompts,
            images=images,
            return_tensors="pt",
            padding=True,
        )
    except TypeError:
        # Some processors do not accept a ``padding`` kwarg; retry without it.
        return processor(
            text=rendered_prompts,
            images=images,
            return_tensors="pt",
        )
155
+
156
+
157
def decode_generated_text(processor, output_ids, prompt_input_ids):
    """Decode only the newly generated tokens, dropping the echoed prompt prefix."""
    tokenizer = getattr(processor, "tokenizer", processor)
    prompt_len = prompt_input_ids.shape[0]
    # generate() returns prompt + continuation; keep the continuation only.
    generated = output_ids[prompt_len:].unsqueeze(0)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]