# File size: 5,637 Bytes
# 3bbe20b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import re
import sys

# Import-path bootstrap.
# _SCRIPTS_DIR is the directory containing this file; _SUBMIT_ROOT is the
# directory two levels above it. Each is put at the front of sys.path (only
# if not already present) so modules living there can be imported by bare
# name. _SUBMIT_ROOT is inserted last, so when both are added it ends up
# ahead of _SCRIPTS_DIR in the search order.
_SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))
_SUBMIT_ROOT = os.path.dirname(os.path.dirname(_SCRIPTS_DIR))
if _SCRIPTS_DIR not in sys.path:
    sys.path.insert(0, _SCRIPTS_DIR)
if _SUBMIT_ROOT not in sys.path:
    sys.path.insert(0, _SUBMIT_ROOT)

# Matches the "situated to the <direction> the" wording (case-insensitively)
# and captures the direction phrase as group 1.
_DIRECTION_PATTERN = re.compile(
    r"situated to the (left of|right of|in front of|behind|below|above) the",
    re.IGNORECASE,
)

# Direction phrases the pattern can capture but for which should_remap()
# reports that no remapping is needed.
_SKIP_DIRECTIONS = {"below", "above"}

# Captured phrase (lower-cased) -> canonical direction key. should_remap()
# returns this key and _map_direction() uses it to index the inner level of
# the caller-supplied direction map.
_DIRECTION_CANONICAL = {
    "left of":     "left",
    "right of":    "right",
    "in front of": "front",
    "behind":      "behind",
}

# Remapped direction key -> phrase written back into the question by
# _replace_direction_in_question().
# NOTE(review): the key here is "in_front_of" while _DIRECTION_CANONICAL uses
# "front". Whether that is intentional depends on the values stored in the
# external direction map, which this file does not show -- confirm that the
# map emits "in_front_of" (a value of "front" would leave the question
# unchanged).
_DIRECTION_TO_PHRASE = {
    "left":        "left of",
    "right":       "right of",
    "in_front_of": "in front of",
    "behind":      "behind",
}

# Prompt asking which way an object faces; formatted with ``object_name`` in
# run_context_with_remap(). The four listed choices are the answers that
# _parse_facing_direction() is written to recognise.
_FACING_PROMPT_TEMPLATE = (
    "In the image, there is a {object_name}. "
    "Which direction is the {object_name} facing from the camera's perspective? "
    "Choose exactly one and output only that choice: "
    "left, right, toward the camera, away from the camera."
)

# System message used by _extract_object_from_question() when it falls back
# to the classifier language model to name the target object.
_OBJECT_EXTRACTION_SYSTEM = (
    "From the following sentence, identify the target object that is being acted upon, "
    "placed, or referred to as the main subject of interest. "
    "Output only the object name, nothing else."
)


def should_remap(question):
    """Decide whether *question* carries a direction that should be remapped.

    The question is searched for the "situated to the <direction> the"
    wording described by ``_DIRECTION_PATTERN``.

    Args:
        question: The question text to inspect.

    Returns:
        A 3-tuple ``(do_remap, canonical, phrase)``. It is
        ``(False, None, None)`` when the wording is absent or the captured
        direction is listed in ``_SKIP_DIRECTIONS``. Otherwise it is ``True``,
        the ``_DIRECTION_CANONICAL`` entry for the phrase, and the lower-cased
        phrase itself.
    """
    found = _DIRECTION_PATTERN.search(question)
    phrase = found.group(1).lower() if found is not None else None
    if phrase is None or phrase in _SKIP_DIRECTIONS:
        return False, None, None
    return True, _DIRECTION_CANONICAL.get(phrase), phrase


def _parse_facing_direction(raw_output):
    text = raw_output.lower()
    if "away from" in text or "backward" in text:
        return "facing_away_from_camera"
    if "toward the camera" in text or "towards the camera" in text:
        return "facing_toward_camera"
    if "forward" in text:
        return "facing_toward_camera"
    has_left = "left" in text
    has_right = "right" in text
    if has_left and not has_right:
        return "facing_left"
    if has_right and not has_left:
        return "facing_right"
    if "toward" in text or "towards" in text:
        return "facing_toward_camera"
    return None


def _map_direction(orig_canonical, facing_key, direction_map):
    if orig_canonical is None or facing_key is None:
        return None
    facing_entry = direction_map.get(facing_key, {})
    return facing_entry.get(orig_canonical)


def _replace_direction_in_question(question, orig_phrase, new_canonical):
    """Rewrite the direction phrase of *question* to the remapped direction.

    Every "to the <orig_phrase> the" occurrence has its direction phrase
    swapped for the phrase registered for *new_canonical* in
    ``_DIRECTION_TO_PHRASE``. Matching ignores case, in line with the
    case-insensitive detection in ``should_remap`` (which hands back a
    lower-cased phrase), and the surrounding "to the" / "the" words keep
    their original casing.

    Args:
        question: The original question text.
        orig_phrase: Direction phrase currently in the question, e.g.
            ``"left of"``.
        new_canonical: Key of the replacement direction in
            ``_DIRECTION_TO_PHRASE``.

    Returns:
        The rewritten question, or *question* unchanged when *new_canonical*
        has no registered phrase, *orig_phrase* is empty, or the wording does
        not occur.
    """
    new_phrase = _DIRECTION_TO_PHRASE.get(new_canonical)
    if new_phrase is None or not orig_phrase:
        return question
    wording = re.compile(
        rf"(to the ){re.escape(orig_phrase)}( the)", re.IGNORECASE
    )
    # A callable replacement keeps new_phrase literal (no backslash/group
    # expansion) and re-emits the matched prefix and suffix untouched.
    return wording.sub(
        lambda found: f"{found.group(1)}{new_phrase}{found.group(2)}", question
    )


def _extract_object_from_question(question, clf_kwargs):
    m = re.search(r"there is a (.+?)\.", question, re.IGNORECASE)
    if m:
        return m.group(1).strip()

    import torch
    from lm_classifier import _apply_chat_template

    first_sentence = (question.split(".")[0] + ".") if "." in question else question

    clf_model = clf_kwargs["model"]
    clf_tokenizer = clf_kwargs["tokenizer"]
    first_device = next(clf_model.parameters()).device

    messages = [
        {"role": "system", "content": _OBJECT_EXTRACTION_SYSTEM},
        {"role": "user", "content": f"Sentences: {first_sentence}"},
    ]
    text = _apply_chat_template(clf_tokenizer, messages, enable_thinking=False)
    inputs = clf_tokenizer([text], return_tensors="pt").to(first_device)

    with torch.no_grad():
        generated_ids = clf_model.generate(**inputs, max_new_tokens=16, do_sample=False)
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
    raw = clf_tokenizer.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0].strip()

    if "</think>" in raw:
        raw = raw.split("</think>", 1)[1].strip()
    if not raw or raw == "[]":
        raw = "object"
    return raw


def run_context_with_remap(question, image_path, depth_path, model_kwargs, clf_kwargs, direction_map):
    """Answer *question*, adding a second answer for a remapped direction.

    The question is always answered once as given. When ``should_remap``
    finds a remappable direction, the model is also asked which way the
    target object faces; that facing and the original direction are looked up
    in *direction_map*, and the question is asked again with the direction
    phrase replaced.

    Args:
        question: The question text.
        image_path: Passed through to ``run_robobrain``.
        depth_path: Passed through to ``run_robobrain``.
        model_kwargs: Passed through to ``run_robobrain``.
        clf_kwargs: Passed to ``_extract_object_from_question``.
        direction_map: Passed to ``_map_direction``.

    Returns:
        The plain first answer when no remap applies. When no remapped
        direction is found, the first extracted point formatted as
        ``"(x, y)"``, or the first answer if no point was extracted.
        Otherwise the points extracted from both answers joined by a space,
        falling back to the first (then second) raw answer when neither
        answer yields a point.
    """
    import torch
    from robobrain_runner import run_robobrain
    from evaluation import _extract_first_point

    def _ask_in_context(prompt):
        # Every question-style query goes through the same call shape.
        return run_robobrain(prompt, image_path, depth_path, model_kwargs, LM_classify="context")

    def _point_text(answer):
        # First point of an answer as "(x, y)", or "" when none was extracted.
        point, _ = _extract_first_point(answer)
        return f"({point[0]}, {point[1]})" if point else ""

    needs_remap, canonical_dir, dir_phrase = should_remap(question)

    # The unmodified question is asked in both branches.
    first_answer = _ask_in_context(question)
    if not needs_remap:
        return first_answer
    first_point = _point_text(first_answer)

    object_name = _extract_object_from_question(question, clf_kwargs)
    torch.cuda.empty_cache()

    label = object_name.strip() or "object"
    facing_answer = run_robobrain(
        _FACING_PROMPT_TEMPLATE.format(object_name=label),
        image_path,
        depth_path,
        model_kwargs,
        add_think_override=False,
    )
    remapped_dir = _map_direction(
        canonical_dir, _parse_facing_direction(facing_answer), direction_map
    )
    if remapped_dir is None:
        return first_point or first_answer

    second_answer = _ask_in_context(
        _replace_direction_in_question(question, dir_phrase, remapped_dir)
    )
    second_point = _point_text(second_answer)

    joined = " ".join(filter(None, (first_point, second_point)))
    return joined or first_answer or second_answer