Model Card for Qwen3-4B-Instruct-2507-Segmenter

This is the semantic segmenter introduced in the paper *Towards Generalization of Block Attention via Automatic Segmentation and Block Distillation*. It was trained on SemanticSeg, the semantic segmentation dataset proposed in that paper.

How to use

  1. Insert candidate cut-point markers into the text (for example with `insert_marker` below).

  2. Feed the output from step 1 to the segmenter, and customize the recursion depth and threshold value if needed.

A reference implementation is shown below:

def insert_marker(txt: str, sep_pattern: "str | None" = None, cut_marker: str = "<cut {}>", marker_pos: str = "right") -> str:
    """Insert numbered candidate cut-point markers into ``txt``.

    Every match of ``sep_pattern`` becomes a candidate cut point. The marker
    ``cut_marker.format(k)`` (``k`` = 1, 2, ...) is placed right after the
    match when ``marker_pos == "right"`` or right before it when
    ``marker_pos == "left"``. ``cut_marker.format(0)`` is always prepended.
    A final marker is appended after any text that follows the last match, or
    after the whole text when nothing matches (including when ``sep_pattern``
    is None).

    >>> insert_marker("a;b", ";")
    '<cut 0>a;<cut 1>b<cut 2>'

    Args:
        txt: Text to annotate.
        sep_pattern: Regex whose matches are candidate cut points. None means
            "no separators", so the whole text forms one candidate block.
        cut_marker: Format string for a marker; ``{}`` receives the ordinal.
        marker_pos: Either "left" or "right" (see above).

    Returns:
        ``txt`` with the markers inserted.

    Raises:
        ValueError: if ``marker_pos`` is neither "left" nor "right".
    """
    import re

    if marker_pos not in ("left", "right"):
        raise ValueError(f'marker_pos must be "left" or "right", got {marker_pos!r}')

    pieces = []
    start = 0  # end of the previous match; next piece starts here
    n_matches = 0
    matches = re.finditer(sep_pattern, txt) if sep_pattern is not None else ()
    for n_matches, m in enumerate(matches, start=1):
        if marker_pos == "right":
            pieces.append(txt[start:m.end()] + cut_marker.format(n_matches))
        else:
            pieces.append(txt[start:m.start()] + cut_marker.format(n_matches) + m.group())
        start = m.end()

    tail = txt[start:]
    # Without any match the whole text is one block and still needs a closing
    # marker; otherwise only non-empty trailing text gets one.
    if n_matches == 0 or tail:
        pieces.append(tail + cut_marker.format(n_matches + 1))
    return cut_marker.format(0) + "".join(pieces)

# states - [bsz, q_len, ...]
def shift_value(states, cut_point_mask, fill_value=(1.0, float("-inf"))):
    """Move each candidate's output back onto the previous candidate position.

    The segmenter uses the output at the next candidate point to predict the
    current one. So, within every row, the value at each candidate position
    is replaced by the value of the following candidate in that row, and the
    row's last candidate receives ``fill_value``. Non-candidate positions are
    copied unchanged, and ``states`` itself is not modified.

    Args:
        states: Tensor of shape ``[bsz, q_len, ...]``, e.g. model logits.
        cut_point_mask: Boolean tensor ``[bsz, q_len]`` marking candidate positions.
        fill_value: Value written at each row's last candidate. It must be
            assignable to ``states[i, j, ...]``. The default ``(1, -inf)``
            assumes a trailing dimension of size 2. Softmax then gives index 1
            a probability of 0, so the last candidate is never selected as a cut.

    Returns:
        A new tensor with the same shape as ``states``.
    """
    import torch

    shifted = states.clone()
    last_value = torch.tensor(fill_value, device=states.device)
    for row in range(states.shape[0]):
        cols = torch.nonzero(cut_point_mask[row], as_tuple=True)[0]
        if cols.numel() == 0:
            # No candidates in this row: nothing to shift.
            continue
        # Read from the untouched `states`, so the shift is a pure one-step move.
        shifted[row, cols[:-1]] = states[row, cols[1:]]
        shifted[row, cols[-1]] = last_value
    return shifted


def get_cutpoint_label(input_ids: list,
                       cut_token_ids: list,
                       chunk_bound=None,
                       window_size: int = 1,
                       ):
    """Mark candidate cut points in one token sequence and optionally label them.

    Candidate cut points are:

    - the start of every window of ``window_size`` consecutive
      ``cut_token_ids[0]`` tokens, and
    - when more than one id is given, every occurrence of ``cut_token_ids[-1]``.

    Candidates are numbered 0, 1, 2, ... in text order, which matches the
    ``<cut N>`` numbering produced by ``insert_marker``.

    Args:
        input_ids: Token ids of a single sequence.
        cut_token_ids: Marker token ids (see above); must be non-empty.
        chunk_bound: Optional chunk boundaries. Every entry except the last
            must end with a string; the last integer in that string (e.g. 5 in
            ``"<cut 5>"``) is a candidate number that gets label 1. If None,
            only the candidate mask is produced (all candidates are labelled 0).
        window_size: Length of the ``cut_token_ids[0]`` window; must be >= 1.

    Returns:
        A list the length of ``input_ids``: -100 at non-candidate positions,
        1 at candidates listed in ``chunk_bound``, and 0 at other candidates.

    Raises:
        ValueError: if ``cut_token_ids`` is empty or ``window_size < 1``.
    """
    import re

    if not cut_token_ids:
        raise ValueError("cut_token_ids must contain at least one token id")
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    ids = list(input_ids)
    labels = [-100] * len(ids)

    target = [cut_token_ids[0]] * window_size
    candidates = {
        k for k in range(len(ids) - window_size + 1) if ids[k:k + window_size] == target
    }
    if len(cut_token_ids) > 1:
        last_id = cut_token_ids[-1]
        candidates.update(k for k, tok in enumerate(ids) if tok == last_id)

    # Candidate numbers that close a chunk; the final chunk's bound is not a cut.
    cut_points = set()
    if chunk_bound is not None:
        for points in chunk_bound[:-1]:
            cut_points.add(int(re.findall(r"\d+", points[-1])[-1]))

    # Sorting keeps candidate numbers in text order even when two marker
    # token ids interleave.
    for ordinal, idx in enumerate(sorted(candidates)):
        labels[idx] = 1 if ordinal in cut_points else 0
    return labels


def model_cut(txt, insert_pattern, model, tokenizer, return_txt=True, depth=1, threshold=None):
    """Segment ``txt`` into semantic blocks with the segmenter model.

    Procedure:

    1. A candidate marker is inserted after every match of ``insert_pattern``
       (via ``insert_marker``), and every numbered marker is collapsed to the
       bare ``<cut>`` token.
    2. The model scores every candidate. The logits are shifted with
       ``shift_value``, and the softmax probability at index 1 is read as the
       cut probability.
    3. Candidates whose probability reaches the level's threshold become block
       boundaries.
    4. With ``depth > 1`` the procedure recurses: each block from level ``d``
       is re-scored on its own at level ``d + 1``. Blocks with at most two
       candidate markers are kept unchanged and retain their previous
       probabilities.

    Args:
        txt: Raw input text.
        insert_pattern: Regex whose matches are candidate cut points.
        model: Segmenter model. ``model(...).logits`` is read at index 1 of
            its last dimension as the cut class.
        tokenizer: Matching tokenizer. Its last additional special token is
            used as the ``<cut>`` marker id.
        return_txt: If True, decode the blocks and fill ``blocks``,
            ``cut_prob`` and ``chunk_id`` in the result.
        depth: Number of recursion levels; must be >= 1.
        threshold: Cut-probability threshold per recursion level. A single
            number applies to every level. A sequence shorter than ``depth``
            reuses its last value for the remaining levels. Defaults to
            ``[0.40]``.

    Returns:
        A dict with the following keys:

        - ``"blocks"``: decoded, non-whitespace block texts (markers removed).
        - ``"cut_prob"``: per block, the cut probabilities of its candidates.
        - ``"chunk_id"``: newline-joined ``"<cut a> --- <cut b>"`` ranges.
        - ``"parallel degree"``: ``len(cut positions) + 1``.
        - ``"cut positions"``: token indices of block ends; the last entry is
          the sequence length.
        - ``"threshold"``: the thresholds used.
        - ``"length"``: the token count.

    Raises:
        ValueError: if ``depth < 1`` or ``threshold`` is empty.
    """
    import re

    import torch
    import torch.nn.functional as F

    # Normalize thresholds without sharing a mutable default with callers.
    if threshold is None:
        threshold = [0.40]
    elif isinstance(threshold, (int, float)):
        threshold = [threshold]
    else:
        threshold = list(threshold)
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    if not threshold:
        raise ValueError("threshold must contain at least one value")

    txt_marker = insert_marker(txt=txt, sep_pattern=insert_pattern, marker_pos="right")
    # The model sees the bare "<cut>" token, not the numbered "<cut N>" markers.
    input_txt = re.sub(r"<cut \d+>", "<cut>", txt_marker)

    # NOTE: truncation=True drops any text beyond the tokenizer's maximum length.
    inputs = tokenizer(input_txt, return_tensors="pt", truncation=True).to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    seq_len = input_ids.shape[-1]

    # Index row 0 instead of squeeze(), so a one-token input still yields a list.
    cut_labels = get_cutpoint_label(input_ids=input_ids[0].tolist(),
                                    cut_token_ids=[tokenizer.additional_special_tokens_ids[-1]],
                                    window_size=1,
                                    )
    cut_point_mask = (torch.tensor(cut_labels, device=model.device) != -100)[None, :]

    model.eval()
    seg_ends = [seq_len]  # block end positions from the previous level
    seg_prob = None       # [1, seq_len] cut probability per token from the previous level
    cut_pos = []
    # The recursion begins
    for d in range(depth):
        thr = threshold[min(d, len(threshold) - 1)]
        pre_pos = 0
        cut_pos = []
        prob_parts = []
        for pos in seg_ends:
            # The window includes both boundary markers (pre_pos and pos).
            window = slice(pre_pos, pos + 1)
            window_mask = cut_point_mask[..., window]
            if window_mask.sum() > 2:
                with torch.no_grad():
                    outputs = model(input_ids=input_ids[..., window],
                                    attention_mask=attention_mask[..., window],
                                    )
                # Shift the outputs of the candidate cut tokens, since the next
                # candidate predicts the current one.
                shifted_logits = shift_value(states=outputs.logits, cut_point_mask=window_mask)
                probs = F.softmax(shifted_logits, dim=-1)[..., 1].float()
                hits = torch.nonzero((probs >= thr) & window_mask, as_tuple=True)[-1]
                # Keep only cuts strictly inside the block. Its own boundaries
                # are already cut positions; re-adding them creates empty blocks.
                cut_pos.extend(p for p in sorted((hits + pre_pos).tolist()) if pre_pos < p < pos)

                if seg_prob is None:
                    prob_parts.append(probs)
                else:
                    # Boundary markers keep the probability from the level that
                    # created them; inner candidates take the new scores.
                    inner = probs[:, 1:-1] if pos < seq_len else probs[:, 1:]
                    prob_parts.append(seg_prob[:, pre_pos].unsqueeze(-1))
                    prob_parts.append(inner)
            elif seg_prob is not None:
                prob_parts.append(seg_prob[..., pre_pos:pos])
            else:
                # First level, too few candidates to run the model: nothing is
                # cut, and each of these tokens is reported with probability 0.
                prob_parts.append(torch.zeros((input_ids.shape[0], pos - pre_pos), device=model.device))
            cut_pos.append(pos)
            pre_pos = pos

        seg_ends = list(cut_pos)
        seg_prob = torch.cat(prob_parts, dim=-1)

    cut_prob = seg_prob[cut_point_mask].tolist()

    p_degree = len(cut_pos) + 1
    blocks = []
    chunk_id = []
    prediction_prob = []
    if return_txt:
        pre_c = 0
        s = 0  # number of the first candidate marker in the current block
        for c in cut_pos:
            block_txt = tokenizer.batch_decode(input_ids[:, pre_c:c])[0]
            n_markers = len(re.findall(r"<cut>", block_txt))
            chunk_id.append("<cut {}> --- <cut {}>".format(s, s + n_markers))
            block_txt = re.sub(r"<cut>", "", block_txt)
            prediction_prob.append(cut_prob[s:s + n_markers])
            # NOTE: whitespace-only blocks are omitted from "blocks" but still
            # get "chunk_id"/"cut_prob" entries, so those lists can be longer.
            if len(block_txt.strip()) > 0:
                blocks.append(block_txt)
            pre_c = c
            s += n_markers

    return {"blocks": blocks,
            "cut_prob": prediction_prob,
            "chunk_id": "\n".join(chunk_id),
            "parallel degree": p_degree,
            "cut positions": cut_pos,
            "threshold": threshold,
            "length": seq_len}

Use the `depth` and `threshold` arguments of `model_cut` to control the final segmentation granularity. Deeper recursion and lower thresholds both produce more, finer-grained blocks.

The segmenter was trained with a threshold of 0.5, but we find that it also works well with thresholds between 0.2 and 0.5. We recommend supplying one threshold value per recursion level: `threshold[d]` is used at recursion level `d`.

Typical combinations:

Recursion depth 1 - threshold value [0.4] (Example: LongbenchSeg, LoCoMoSeg);

Recursion depth 2 - threshold value [0.2, 0.4] (Example: ChatQA2Seg) or [0.4, 0.4].

Note: remember to shift the model outputs (logits) of the candidate cut points, as `shift_value` does. The segmenter is trained to use the output at the next candidate point to predict the current one.

An example:

from transformers import AutoTokenizer, AutoModelForCausalLM

# Sample input text: a fragment of a Python test module. It is only tokenized
# and segmented by model_cut, never executed, so its irregular indentation does
# not matter.
txt = '''
import numpy as np
from numpy.testing import assert_array_equal, assert_array_almost_equal
import scipy.stats.distributions as distrs
from scipy.stats.kde import gaussian_kde
from scipy.integrate import quad
import pytest


def augment_grid(x, n_inner_points):
    test_arr = [
        np.linspace(x[i], x[i + 1], n_inner_points + 1, endpoint=False)
        for i in np.arange(len(x) - 1)
    ]
    test_arr.append([x[-1]])
    return np.concatenate(test_arr)

def circle_fun(x, low, high):
    x = np.array(x)
    center = 0.5 * (high + low)
    radius = 0.5 * (high - low)

    res = np.zeros_like(x)

    center_dist = np.abs(x - center)
    is_in = center_dist <= radius
    res[is_in] = np.sqrt(radius ** 2 - center_dist[is_in] ** 2)

    return res

class TestCont:
"""Regression tests for `Cont` class"""

    def test_init_errors(self):
        def check_one_input(def_args, var):
                with pytest.raises(TypeError, match=f"`{var}`.*numpy array"):
                    def_args[var] = {"a": None}
                    Cont(**def_args)
                with pytest.raises(TypeError, match=f"`{var}`.*float"):
                    def_args[var] = ["a", "a"]
                    Cont(**def_args)
                with pytest.raises(TypeError, match=f"`{var}`.*finite values"):
                    def_args[var] = [0, np.nan]
                    Cont(**def_args)
                with pytest.raises(TypeError, match=f"`{var}`.*finite values"):
                    def_args[var] = [0, np.inf]
                    Cont(**def_args)
                with pytest.raises(ValueError, match=f"`{var}`.*1d array"):
                    def_args[var] = [[0, 1]]
                    Cont(**def_args)

            check_one_input({"y": [1, 1]}, "x")
            check_one_input({"x": [0, 1]}, "y")

            with pytest.raises(ValueError, match="[Ll]engths.*match"):
                Cont([0, 1], [1, 1, 1])

            with pytest.raises(ValueError, match="two"):
                Cont([1], [1])

            with pytest.warns(UserWarning, match="`x`.*not sorted.*`x` and `y`"):
                rv = Cont([1, 0], [0, 2])
                rv_ref = Cont([0, 1], [2, 0])
                _test_equal_rand(rv, rv_ref)

            with pytest.raises(ValueError, match="`y`.*negative"):
                Cont([0, 1], [1, -1])

            with pytest.raises(ValueError, match="`y`.*no positive"):
                Cont([0, 1], [0, 0])

    def test_init(self):
        x_ref = np.array([0, 1, 2])
        y_ref = np.array([0, 1, 0])
        rv_ref = Cont(x_ref, y_ref)

class TestFromRVAccuracy:
    """Accuracy of `Cont.from_rv()`"""

    # Output of `from_rv()` should have CDF that differs from original CDF by
    # no more than `thres`
    @pytest.mark.slow
    @pytest.mark.parametrize(
        "distr_dict,thres",
        [
            (DISTRIBUTIONS_COMMON, 1e-4),
            (DISTRIBUTIONS_INF_DENSITY, 1e-3),
            (DISTRIBUTIONS_HEAVY_TAILS, 5e-3),
        ],
    )
    def test_cdf_maxerror(self, distr_dict, thres):
        test_passed = {
            name: TestFromRVAccuracy.from_rv_cdf_maxerror(distr) <= thres
            for name, distr in distr_dict.items()
        }

        assert all(test_passed.values())

class TestFromSampleAccuracy:
    """Accuracy of `Cont.from_sample()`"""

    # Output of `from_sample()` should differ from original density estimate by   
    # no more than `thres` (with default density estimator)
    @pytest.mark.slow
    @pytest.mark.parametrize(
        "distr_dict,thres",
        [
            (DISTRIBUTIONS_COMMON, 1e-4),
            (DISTRIBUTIONS_INF_DENSITY, 1.5e-4),
            (DISTRIBUTIONS_HEAVY_TAILS, 1e-4),
        ],
    )
    def test_close_cdf(self, distr_dict, thres):
        rng = np.random.default_rng(101)
        test_passed = {
            name: TestFromSampleAccuracy.simulated_cdf_error(distr, rng) <= thres 
            for name, distr in distr_dict.items()
        }
'''

import torch  # needed for torch.bfloat16 below

# Load the segmenter and its tokenizer from the Hugging Face Hub.
# (Older transformers releases name the `dtype` argument `torch_dtype`.)
model = AutoModelForCausalLM.from_pretrained("Syon-Li/Qwen3-4B-Instruct-2507-Segmenter", dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    "Syon-Li/Qwen3-4B-Instruct-2507-Segmenter",
)

# One recursion level with threshold 0.40; every run of newlines is a candidate cut point.
results = model_cut(txt=txt, model=model, tokenizer=tokenizer, insert_pattern=r"\n{1,}", depth=1, threshold=[0.40])

print(results)

If you find this useful, please cite:


Downloads last month
173
Safetensors
Model size
4B params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for Syon-Li/Qwen3-4B-Instruct-2507-Segmenter

Finetuned
(1703)
this model

Dataset used to train Syon-Li/Qwen3-4B-Instruct-2507-Segmenter

Collection including Syon-Li/Qwen3-4B-Instruct-2507-Segmenter

Paper for Syon-Li/Qwen3-4B-Instruct-2507-Segmenter