Model Card for Qwen3-4B-Instruct-2507-Segmenter
This is the semantic segmenter introduced in the paper "Towards Generalization of Block Attention via Automatic Segmentation and Block Distillation". It was trained on SemanticSeg, the semantic segmentation dataset proposed in that paper.
How to use
1. Insert the candidate cut points into the text.
2. Feed the output of step 1 to the segmenter, customizing the recursion depth and threshold values if needed.

An example implementation is shown below:
def insert_marker(txt: str,
                  sep_pattern: str = None,
                  cut_marker: str = "<cut {}>",
                  marker_pos: Literal["left", "right"] = "right") -> str:
    """Insert numbered candidate cut markers into ``txt``.

    A marker ``cut_marker.format(k)`` is inserted at every match of
    ``sep_pattern``: after the match when ``marker_pos == "right"``, before it
    otherwise. Marker ``0`` is always prepended, and one more marker is
    appended after any text that follows the last separator, so the result
    always starts and ends with a marker.

    Args:
        txt: Text to annotate.
        sep_pattern: Regular expression whose matches are candidate cut
            points. ``None`` means "no separators": the whole text becomes a
            single candidate block.
        cut_marker: Format string for a marker; receives the marker index.
        marker_pos: ``"right"`` places each marker after its separator; any
            other value places it before the separator.

    Returns:
        The annotated text, with markers numbered ``0, 1, 2, ...`` in order.

    Example:
        >>> insert_marker("a,b", ",")
        '<cut 0>a,<cut 1>b<cut 2>'
    """
    matches = list(re.finditer(sep_pattern, txt)) if sep_pattern is not None else []

    candidate_blocks = []
    start = 0  # end offset of the previous separator match
    for number, m in enumerate(matches, start=1):
        marker = cut_marker.format(number)
        if marker_pos == "right":
            candidate_blocks.append(txt[start:m.end()] + marker)
        else:
            candidate_blocks.append(txt[start:m.start()] + marker + txt[m.start():m.end()])
        start = m.end()

    # Close the text with a final marker when text follows the last separator
    # (or when there were no separators at all, so the whole text is one block).
    tail = txt[start:]
    if not matches or tail:
        candidate_blocks.append(tail + cut_marker.format(len(matches) + 1))

    candidate_blocks[0] = cut_marker.format(0) + candidate_blocks[0]
    return "".join(candidate_blocks)
# states - [bsz, q_len, ...]
def shift_value(states, cut_point_mask):
    """Give every candidate cut point the value of the next candidate in its row.

    The segmenter is trained to use the *next* candidate cut point to predict
    the current one. For each row ``i``, the value at each candidate position
    (where ``cut_point_mask[i]`` is true) is replaced by the value at the
    following candidate position of that row. The last candidate has no
    successor. It is filled with ``1`` for class 0 and ``-inf`` for every other
    class, so after a softmax over the last dimension its probability is 1 for
    class 0 and 0 for all other classes. Non-candidate positions, and rows
    without any candidate, are copied unchanged.

    Args:
        states: Tensor indexed as ``states[row, position, ..., class]``,
            e.g. logits of shape ``[bsz, q_len, num_classes]``.
        cut_point_mask: Boolean tensor ``[bsz, q_len]`` marking candidates.

    Returns:
        A shifted copy of ``states``; the input tensor is not modified.
    """
    rows, cols = torch.nonzero(cut_point_mask, as_tuple=True)
    shifted = states.clone()

    # Filler for the last candidate of a row: class 0 certain after softmax.
    last_fill = torch.full(states.shape[-1:], float("-inf"),
                           dtype=states.dtype, device=states.device)
    last_fill[0] = 1

    for i in range(states.shape[0]):
        positions = cols[rows == i]  # ascending candidate positions of row i
        if positions.numel() == 0:
            continue  # nothing to shift in this row
        shifted[i, positions[:-1], ...] = states[i, positions[1:], ...]
        shifted[i, positions[-1], ...] = last_fill
    return shifted
def get_cutpoint_label(input_ids: list,
                       cut_token_ids: List[int],
                       chunk_bound=None,
                       window_size: int = 1,
                       ):
    """Build per-token labels for the candidate cut points of one sequence.

    Candidate cut points are:

    * the start of every window of ``window_size`` consecutive tokens that
      all equal ``cut_token_ids[0]``, and
    * when more than one id is given, every occurrence of ``cut_token_ids[-1]``.

    Candidates are numbered ``0, 1, 2, ...`` in order of position, which is
    the same order as the numbered markers inserted into the text.

    Args:
        input_ids: Token ids of one sequence.
        cut_token_ids: Token id(s) marking a candidate cut point.
        chunk_bound: Optional boundaries of the chunks. For every entry except
            the last, the last integer found in ``entry[-1]`` is the number of
            a candidate that is a true cut. NOTE(review): presumably entries
            end with a marker string such as ``"<cut 3>"``; confirm.
        window_size: Number of consecutive ``cut_token_ids[0]`` tokens that
            form one candidate; must be at least 1.

    Returns:
        A list with one label per token: ``-100`` for non-candidates, ``1``
        for candidates selected by ``chunk_bound``, ``0`` for the rest.
        Without ``chunk_bound``, every candidate is ``0``, so ``label != -100``
        is the candidate mask.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window size must be at least 1")

    labels = [-100] * len(input_ids)

    primary = (cut_token_ids[0],) * window_size
    candidates = {
        start for start in range(len(input_ids) - window_size + 1)
        if tuple(input_ids[start:start + window_size]) == primary
    }
    if len(cut_token_ids) > 1:
        candidates.update(idx for idx, tok in enumerate(input_ids) if tok == cut_token_ids[-1])

    # Numbers of the candidates that are true cuts (the last bound is the end
    # of the text, not a cut).
    true_cuts = set()
    if chunk_bound is not None:
        for bound in chunk_bound[:-1]:
            true_cuts.add(int(re.findall(r"\d+", bound[-1])[-1]))

    # Sorting keeps candidate numbers aligned with positional marker numbers
    # even when two different cut-token ids are used.
    for number, idx in enumerate(sorted(candidates)):
        labels[idx] = 1 if number in true_cuts else 0
    return labels
def model_cut(txt, insert_pattern, model, tokenizer, return_txt=True, depth=1, threshold=None):
    """Segment ``txt`` into semantic blocks with the segmenter model.

    Steps:
      1. Insert candidate cut markers at every match of ``insert_pattern``
         and collapse them to the bare ``"<cut>"`` marker.
      2. Tokenize the text and locate the candidate cut tokens.
      3. Run the segmenter for ``depth`` levels. Level 0 scores the whole
         sequence. Each further level re-scores every segment delimited by the
         previous level's cuts, but only segments with more than two candidate
         points. Previous cuts are always kept.
      4. If ``return_txt``, decode the token spans between the final cuts.

    Args:
        txt: Text to segment.
        insert_pattern: Regex whose matches become candidate cut points.
        model: The segmenter. It is called with ``input_ids`` and
            ``attention_mask``; index 1 of the last dimension of
            ``outputs.logits`` is treated as the "cut" class. NOTE(review):
            assumed to be a 2-class head; confirm.
        tokenizer: Tokenizer whose last additional special token is assumed
            to be ``"<cut>"``.
        return_txt: If true, fill ``blocks``, ``cut_prob`` and ``chunk_id``.
        depth: Number of recursion levels (at least 1).
        threshold: Cut-probability thresholds, one per recursion level
            (default ``[0.40]``).

    Returns:
        dict with keys:
          ``"blocks"``: decoded segment texts with markers removed
          (whitespace-only segments skipped);
          ``"cut_prob"``: per segment, the cut probabilities of its
          candidate points;
          ``"chunk_id"``: newline-joined ``"<cut a> --- <cut b>"`` ranges;
          ``"parallel degree"``: ``len(cut positions) + 1``;
          ``"cut positions"``: token indices of the cuts, ending with the
          sequence length;
          ``"threshold"``; and
          ``"length"``: the token count.

    Raises:
        ValueError: if ``depth < 1`` or fewer than ``depth`` thresholds are given.
    """
    if threshold is None:
        threshold = [0.40]
    if depth < 1:
        raise ValueError("depth must be at least 1")
    if len(threshold) < depth:
        raise ValueError(
            f"threshold has {len(threshold)} value(s) but depth={depth}; "
            "provide one threshold per recursion level"
        )

    txt_marker = insert_marker(txt=txt, sep_pattern=insert_pattern, marker_pos="right")
    # The model only knows one generic "<cut>" token: drop the marker numbers.
    input_txt = re.sub(r"<cut \d+>", "<cut>", txt_marker)
    inputs = tokenizer(input_txt, return_tensors="pt", truncation=True).to(model.device)
    input_ids = inputs["input_ids"]
    seq_len = input_ids.shape[-1]

    cut_labels = get_cutpoint_label(input_ids=input_ids[0].tolist(),
                                    cut_token_ids=[tokenizer.additional_special_tokens_ids[-1]],
                                    window_size=1,
                                    )
    # [1, seq_len] boolean mask of the candidate cut tokens.
    cut_point_mask = (torch.tensor(cut_labels, device=model.device) != -100)[None, :]
    model.eval()

    tmp_cut_pos = [seq_len]  # segment ends from the previous level
    tmp_cut_prob = None      # [1, seq_len] cut probabilities from the previous level
    for d in range(depth):
        pre_pos = 0
        cut_pos = []
        cut_prob = torch.tensor([], device=model.device)
        for pos in tmp_cut_pos:
            seg = slice(pre_pos, pos + 1)  # includes both delimiting cut tokens
            seg_mask = cut_point_mask[..., seg]
            if seg_mask.sum() > 2:
                with torch.no_grad():
                    outputs = model(input_ids=input_ids[..., seg],
                                    attention_mask=inputs["attention_mask"][..., seg],
                                    )
                # Shift the outputs of the candidate cut tokens, since the
                # segmenter uses the next candidate to predict the current one.
                shifted_logits = shift_value(states=outputs.logits, cut_point_mask=seg_mask)
                seg_prob = F.softmax(shifted_logits, dim=-1)[..., 1]
                prediction = (seg_prob >= threshold[d]) & seg_mask
                c_pos = torch.nonzero(prediction, as_tuple=True)[-1]
                c_pos = c_pos.sort(descending=False, stable=True, dim=-1)[0]
                cut_pos.extend((c_pos + pre_pos).tolist())
                if tmp_cut_prob is None:
                    cut_prob = torch.concat(tensors=[cut_prob, seg_prob], dim=-1)
                else:
                    # Keep the previous level's probability for the leading cut
                    # token; the trailing one (if any) starts the next segment.
                    inner = seg_prob[:, 1:-1] if pos < seq_len else seg_prob[:, 1:]
                    cut_prob = torch.concat(tensors=[cut_prob,
                                                     tmp_cut_prob[:, pre_pos].unsqueeze(-1),
                                                     inner],
                                            dim=-1)
            elif tmp_cut_prob is not None:
                # Too few candidates to re-segment: carry the previous level over.
                cut_prob = torch.concat(tensors=[cut_prob, tmp_cut_prob[..., pre_pos:pos]], dim=-1)
            else:
                # First level with too few candidates: the model is not run and
                # no cut is predicted. Record probability 0 so every candidate
                # still has an entry (previously cut_prob stayed empty and the
                # mask lookup below raised IndexError).
                cut_prob = torch.concat(
                    tensors=[cut_prob, torch.zeros((1, pos - pre_pos), device=model.device)],
                    dim=-1,
                )
            # The segment end is always kept as a cut.
            cut_pos.append(pos)
            pre_pos = pos
        tmp_cut_pos = list(cut_pos)
        tmp_cut_prob = cut_prob.clone()

    cut_prob = cut_prob[cut_point_mask].tolist()  # one probability per candidate
    p_degree = len(cut_pos) + 1

    blocks = []
    chunk_id = []
    block_probs = []
    if return_txt:
        pre_c = 0
        s = 0  # number of the first candidate marker in the current block
        for c in cut_pos:
            block_txt = tokenizer.batch_decode(input_ids[:, pre_c:c])[0]
            n_cuts = len(re.findall(r"<cut>", block_txt))
            chunk_id.append("<cut {}> --- <cut {}>".format(s, s + n_cuts))
            block_probs.append(cut_prob[s:s + n_cuts])
            block_txt = re.sub(r"<cut>", "", block_txt)
            if len(block_txt.strip()) > 0:
                blocks.append(block_txt)
            pre_c = c
            s += n_cuts
    return {"blocks": blocks,
            "cut_prob": block_probs,
            "chunk_id": "\n".join(chunk_id),
            "parallel degree": p_degree,
            "cut positions": cut_pos,
            "threshold": threshold,
            "length": seq_len}
You can customize the recursion depth (`depth`) and the threshold values (`threshold`) of the `model_cut` function to control the granularity of the final segmentation.
The segmenter was trained with a threshold of 0.5, but we find that it also works well with thresholds in the range 0.2–0.5. Provide one threshold value per recursion level.
Typical combinations:
- Recursion depth 1, threshold `[0.4]` (e.g., LongbenchSeg, LoCoMoSeg);
- Recursion depth 2, threshold `[0.2, 0.4]` (e.g., ChatQA2Seg) or `[0.4, 0.4]`.
Note: remember to shift the model outputs (logits) at the candidate cut points, as `shift_value` does, because the segmenter is trained to use the next candidate point to predict the current one.
An example:
from transformers import AutoTokenizer, AutoModelForCausalLM
# Sample document to segment: the source of a Python test module, used as
# plain text input for the segmenter.
txt = '''
import numpy as np
from numpy.testing import assert_array_equal, assert_array_almost_equal
import scipy.stats.distributions as distrs
from scipy.stats.kde import gaussian_kde
from scipy.integrate import quad
import pytest
def augment_grid(x, n_inner_points):
test_arr = [
np.linspace(x[i], x[i + 1], n_inner_points + 1, endpoint=False)
for i in np.arange(len(x) - 1)
]
test_arr.append([x[-1]])
return np.concatenate(test_arr)
def circle_fun(x, low, high):
x = np.array(x)
center = 0.5 * (high + low)
radius = 0.5 * (high - low)
res = np.zeros_like(x)
center_dist = np.abs(x - center)
is_in = center_dist <= radius
res[is_in] = np.sqrt(radius ** 2 - center_dist[is_in] ** 2)
return res
class TestCont:
"""Regression tests for `Cont` class"""
def test_init_errors(self):
def check_one_input(def_args, var):
with pytest.raises(TypeError, match=f"`{var}`.*numpy array"):
def_args[var] = {"a": None}
Cont(**def_args)
with pytest.raises(TypeError, match=f"`{var}`.*float"):
def_args[var] = ["a", "a"]
Cont(**def_args)
with pytest.raises(TypeError, match=f"`{var}`.*finite values"):
def_args[var] = [0, np.nan]
Cont(**def_args)
with pytest.raises(TypeError, match=f"`{var}`.*finite values"):
def_args[var] = [0, np.inf]
Cont(**def_args)
with pytest.raises(ValueError, match=f"`{var}`.*1d array"):
def_args[var] = [[0, 1]]
Cont(**def_args)
check_one_input({"y": [1, 1]}, "x")
check_one_input({"x": [0, 1]}, "y")
with pytest.raises(ValueError, match="[Ll]engths.*match"):
Cont([0, 1], [1, 1, 1])
with pytest.raises(ValueError, match="two"):
Cont([1], [1])
with pytest.warns(UserWarning, match="`x`.*not sorted.*`x` and `y`"):
rv = Cont([1, 0], [0, 2])
rv_ref = Cont([0, 1], [2, 0])
_test_equal_rand(rv, rv_ref)
with pytest.raises(ValueError, match="`y`.*negative"):
Cont([0, 1], [1, -1])
with pytest.raises(ValueError, match="`y`.*no positive"):
Cont([0, 1], [0, 0])
def test_init(self):
x_ref = np.array([0, 1, 2])
y_ref = np.array([0, 1, 0])
rv_ref = Cont(x_ref, y_ref)
class TestFromRVAccuracy:
"""Accuracy of `Cont.from_rv()`"""
# Output of `from_rv()` should have CDF that differs from original CDF by
# no more than `thres`
@pytest.mark.slow
@pytest.mark.parametrize(
"distr_dict,thres",
[
(DISTRIBUTIONS_COMMON, 1e-4),
(DISTRIBUTIONS_INF_DENSITY, 1e-3),
(DISTRIBUTIONS_HEAVY_TAILS, 5e-3),
],
)
def test_cdf_maxerror(self, distr_dict, thres):
test_passed = {
name: TestFromRVAccuracy.from_rv_cdf_maxerror(distr) <= thres
for name, distr in distr_dict.items()
}
assert all(test_passed.values())
class TestFromSampleAccuracy:
"""Accuracy of `Cont.from_sample()`"""
# Output of `from_sample()` should differ from original density estimate by
# no more than `thres` (with default density estimator)
@pytest.mark.slow
@pytest.mark.parametrize(
"distr_dict,thres",
[
(DISTRIBUTIONS_COMMON, 1e-4),
(DISTRIBUTIONS_INF_DENSITY, 1.5e-4),
(DISTRIBUTIONS_HEAVY_TAILS, 1e-4),
],
)
def test_close_cdf(self, distr_dict, thres):
rng = np.random.default_rng(101)
test_passed = {
name: TestFromSampleAccuracy.simulated_cdf_error(distr, rng) <= thres
for name, distr in distr_dict.items()
}
'''
# `torch` is required for `torch.bfloat16` below (and by `model_cut`).
import torch

# Load the segmenter in bfloat16. trust_remote_code=True lets the repository's
# custom model code run. NOTE(review): older transformers releases name the
# dtype argument `torch_dtype` instead of `dtype`.
model = AutoModelForCausalLM.from_pretrained("Syon-Li/Qwen3-4B-Instruct-2507-Segmenter", dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    "Syon-Li/Qwen3-4B-Instruct-2507-Segmenter",
)
# One recursion level; candidate cut points after every run of newlines.
results = model_cut(txt=txt, model=model, tokenizer=tokenizer, insert_pattern=r"\n{1,}", depth=1, threshold=[0.40])
print(results)
If you find this useful, please cite:
- Downloads last month
- 173
Model tree for Syon-Li/Qwen3-4B-Instruct-2507-Segmenter
Base model
Qwen/Qwen3-4B-Instruct-2507