Update modeling_bertchunker.py
modeling_bertchunker.py (+83 -9)
CHANGED
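In summary: this update removes the unused `safetensors` and `AutoConfig`/`AutoTokenizer` imports, makes `chunk_text` build its attention mask from the actual window length (`torch.ones(1, ids.shape[1])`), and adds a `chunk_text_fast` method that reshapes the token sequence into fixed-size windows and scores `batchsize` windows per forward pass.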
@@ -3,8 +3,6 @@ from torch import nn
 from transformers.models.bert.configuration_bert import BertConfig
 from transformers.models.bert.modeling_bert import BertModel
 import torch
-import safetensors
-from transformers import AutoConfig, AutoTokenizer
 class BertChunker(PreTrainedModel):
 
     config_class = BertConfig
@@ -14,7 +12,7 @@ class BertChunker(PreTrainedModel):
 
         self.model = BertModel(config)
         self.chunklayer = nn.Linear(384, 2)
-
+
     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
         model_output = self.model(
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
@@ -35,11 +33,11 @@ class BertChunker(PreTrainedModel):
             labels = labels.to(labels.device)
             loss = loss_fct(logits, labels)
             model_output["loss"] = loss
-
+
         return model_output
-
-    def chunk_text(self, text: str, tokenizer, threshold=0) -> list[str]:
 
+    def chunk_text(self, text: str, tokenizer, threshold=0) -> list[str]:
+        # slide context window
         MAX_TOKENS = 255
         tokens = tokenizer(text, return_tensors="pt", truncation=False)
         input_ids = tokens["input_ids"]
@@ -60,8 +58,8 @@ class BertChunker(PreTrainedModel):
             ids = torch.cat((CLS, input_ids[:, windows_start:windows_end], SEP), 1)
 
             ids = ids.to(self.device)
-
-            output = self(input_ids=ids, attention_mask=
+
+            output = self(input_ids=ids, attention_mask=torch.ones(1, ids.shape[1]))
             logits = output['logits'][:, 1:-1, :]
             is_left_greater = ((logits[:, :, 0] + threshold) < logits[:, :, 1])
             greater_rows_indices = torch.where(is_left_greater)[1].tolist()
@@ -69,7 +67,6 @@ class BertChunker(PreTrainedModel):
             # null or not
             if len(greater_rows_indices) > 0 and (not (greater_rows_indices[0] == 0 and len(greater_rows_indices) == 1)):
 
-
                 split_str_pos = [tokens.token_to_chars(sp + windows_start + 1).start for sp in greater_rows_indices]
 
                 split_str_poses += split_str_pos
@@ -82,3 +79,80 @@ class BertChunker(PreTrainedModel):
 
         substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])]
         return substrings
+
+    def chunk_text_fast(
+        self, text: str, tokenizer, batchsize=20, threshold=0
+    ) -> list[str]:
+        # Chunk the text faster with a fixed context window; batchsize is the
+        # number of windows run per forward pass.
+        self.eval()
+
+        split_str_poses = []
+        MAX_TOKENS = 255
+        USEFUL_TOKENS = MAX_TOKENS - 2  # leave room for cls and sep
+        tokens = tokenizer(text, return_tensors="pt", truncation=False)
+        input_ids = tokens["input_ids"]
+
+        CLS = tokenizer.cls_token_id
+        SEP = tokenizer.sep_token_id
+        input_ids = input_ids[:, 1:-1].squeeze().contiguous()  # delete cls and sep
+
+        token_num = input_ids.shape[0]
+        seq_num = input_ids.shape[0] // USEFUL_TOKENS
+        left_token_num = input_ids.shape[0] % USEFUL_TOKENS
+
+        if seq_num > 0:
+            # reshape the tokens into (seq_num, USEFUL_TOKENS) windows
+            reshaped_input_ids = input_ids[: seq_num * USEFUL_TOKENS].view(seq_num, USEFUL_TOKENS)
+
+            # position_id[m, n] is the index of window token n of window m in the
+            # original tokenization; the bias accounts for the leading cls token
+            i = torch.arange(seq_num).unsqueeze(1)
+            j = torch.arange(USEFUL_TOKENS).repeat(seq_num, 1)
+            bias = 1
+            position_id = i * USEFUL_TOKENS + j + bias
+            position_id = position_id.to(self.device)
+
+            # re-attach cls and sep to every window
+            reshaped_input_ids = torch.cat(
+                (
+                    torch.full((reshaped_input_ids.shape[0], 1), CLS),
+                    reshaped_input_ids,
+                    torch.full((reshaped_input_ids.shape[0], 1), SEP),
+                ),
+                1,
+            )
+
+            batch_num = seq_num // batchsize
+            left_seq_num = seq_num % batchsize
+            for i in range(batch_num):
+                # advance by whole batches of windows
+                batch_input = reshaped_input_ids[i * batchsize : (i + 1) * batchsize, :].to(self.device)
+                attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device)
+                output = self(input_ids=batch_input, attention_mask=attention_mask)
+                logits = output['logits'][:, 1:-1, :]  # delete cls and sep
+                is_left_greater = ((logits[:, :, 0] + threshold) < logits[:, :, 1])
+                pos = is_left_greater * position_id[i * batchsize : (i + 1) * batchsize, :]
+                pos = pos[pos > 0].tolist()
+                split_str_poses += [tokens.token_to_chars(p).start for p in pos]
+            if left_seq_num > 0:
+                batch_input = reshaped_input_ids[-left_seq_num:, :].to(self.device)
+                attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device)
+                output = self(input_ids=batch_input, attention_mask=attention_mask)
+                logits = output['logits'][:, 1:-1, :]  # delete cls and sep
+                is_left_greater = ((logits[:, :, 0] + threshold) < logits[:, :, 1])
+                pos = is_left_greater * position_id[-left_seq_num:, :]
+                pos = pos[pos > 0].tolist()
+                split_str_poses += [tokens.token_to_chars(p).start for p in pos]
+
+        if left_token_num > 0:
+            # handle the tail that does not fill a whole window
+            left_input_ids = torch.cat([torch.tensor([CLS]), input_ids[-left_token_num:], torch.tensor([SEP])])
+            left_input_ids = left_input_ids.unsqueeze(0).to(self.device)
+            attention_mask = torch.ones(left_input_ids.shape[0], left_input_ids.shape[1]).to(self.device)
+            output = self(input_ids=left_input_ids, attention_mask=attention_mask)
+            logits = output['logits'][:, 1:-1, :]  # delete cls and sep
+            is_left_greater = ((logits[:, :, 0] + threshold) < logits[:, :, 1])
+            bias = token_num - (left_input_ids.shape[1] - 2) + 1
+            pos = (torch.where(is_left_greater)[1] + bias).tolist()
+            split_str_poses += [tokens.token_to_chars(p).start for p in pos]
+
+        substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])]
+        return substrings
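For context, here is a minimal usage sketch. It is not part of the commit: the Hub repo id, the weight-loading step, and the sample text are assumptions for illustration, since the class itself does not load fine-tuned weights.

# Hypothetical usage sketch; the repo id and sample text are assumptions,
# not part of this commit.
import torch
from transformers import AutoConfig, AutoTokenizer
from modeling_bertchunker import BertChunker

config = AutoConfig.from_pretrained("tim1900/BertChunker")        # assumed repo id
tokenizer = AutoTokenizer.from_pretrained("tim1900/BertChunker")  # assumed repo id
model = BertChunker(config)
# ...load the fine-tuned weights here (e.g. with safetensors) before chunking...

text = (
    "Transformers process tokens in parallel using self-attention. "
    "On an unrelated note, sourdough bread needs a long fermentation."
)
with torch.no_grad():
    chunks = model.chunk_text_fast(text, tokenizer, batchsize=20, threshold=0)
print(chunks)  # list of substrings split at the predicted chunk boundaries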