from transformers.modeling_utils import PreTrainedModel from torch import nn from transformers.models.bert.configuration_bert import BertConfig from transformers.models.bert.modeling_bert import BertModel import torch import torch.nn.functional as F class BertChunker(PreTrainedModel): config_class = BertConfig def __init__(self, config, ): super().__init__(config) self.model = BertModel(config) self.chunklayer = nn.Linear(384, 2) def forward(self, input_ids=None, attention_mask=None,labels=None, **kwargs): model_output = self.model( input_ids=input_ids, attention_mask=attention_mask, **kwargs ) token_embeddings = model_output[0] logits = self.chunklayer(token_embeddings) model_output["logits"]=logits loss = None logits = logits.contiguous() if labels: labels = labels.contiguous() # Flatten the tokens loss_fct = nn.CrossEntropyLoss()#用-100 # loss_fct = nn.CrossEntropyLoss(ignore_index=50257) logits = logits.view(-1, logits.shape[-1]) labels = labels.view(-1) # Enable model parallelism labels = labels.to(labels.device) loss = loss_fct(logits, labels) model_output["loss"]=loss return model_output def chunk_text(self, text:str, tokenizer, prob_threshold=0.5)->list[str]: # slide context window MAX_TOKENS=255 tokens=tokenizer(text, return_tensors="pt",truncation=False) input_ids=tokens['input_ids'].to(self.device) attention_mask=tokens['attention_mask'][:,0:MAX_TOKENS] attention_mask=attention_mask.to(self.device) CLS=input_ids[:,0].unsqueeze(0) SEP=input_ids[:,-1].unsqueeze(0) input_ids=input_ids[:,1:-1] self.eval() split_str_poses=[] windows_start =0 windows_end= 0 while windows_end <= input_ids.shape[1]: windows_end= windows_start + MAX_TOKENS-2 ids=torch.cat((CLS, input_ids[:,windows_start:windows_end],SEP),1) ids=ids.to(self.device) output=self(input_ids=ids,attention_mask=torch.ones(1, ids.shape[1],device=self.device)) logits = output['logits'][:, 1:-1,:] chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1] chunk_decision = (chunk_probabilities>prob_threshold) greater_rows_indices = torch.where(chunk_decision)[1].tolist() # null or not if len(greater_rows_indices)>0 and (not (greater_rows_indices[0] == 0 and len(greater_rows_indices)==1)): split_str_pos=[tokens.token_to_chars(sp + windows_start + 1).start for sp in greater_rows_indices] split_str_poses += split_str_pos windows_start = greater_rows_indices[-1] + windows_start else: windows_start = windows_end substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses+[len(text)])] return substrings def chunk_text_fast( self, text: str, tokenizer, batchsize=20, prob_threshold=0.5 ) -> list[str]: # chunk the text faster with a fixed context window, batchsize is the number of windows run per batch. self.eval() split_str_poses=[] MAX_TOKENS = 255 USEFUL_TOKENS = MAX_TOKENS - 2 # delete cls and sep tokens = tokenizer(text, return_tensors="pt", truncation=False) input_ids = tokens["input_ids"] CLS = tokenizer.cls_token_id SEP = tokenizer.sep_token_id input_ids = input_ids[:, 1:-1].squeeze().contiguous()# delete cls and sep token_num = input_ids.shape[0] seq_num = input_ids.shape[0] // (USEFUL_TOKENS) left_token_num = input_ids.shape[0] % (USEFUL_TOKENS) if seq_num > 0: reshaped_input_ids = input_ids[: seq_num * USEFUL_TOKENS].view( seq_num, USEFUL_TOKENS ) i = torch.arange(seq_num).unsqueeze(1) j = torch.arange(USEFUL_TOKENS).repeat(seq_num, 1) bias = 1 # 1 bias by cls token position_id = i * (USEFUL_TOKENS) + j + bias position_id = position_id.to(self.device) reshaped_input_ids = torch.cat( ( torch.full((reshaped_input_ids.shape[0], 1), CLS), reshaped_input_ids, torch.full((reshaped_input_ids.shape[0], 1), SEP), ), 1, ) batch_num = seq_num // batchsize left_seq_num = seq_num % batchsize for i in range(batch_num): batch_input = reshaped_input_ids[i : i + batchsize, :].to(self.device) attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device) output = self(input_ids=batch_input, attention_mask=attention_mask) logits = output['logits'][:, 1:-1,:]#delete cls and sep # is_left_greater = ((logits[:,:, 0] + 0) < logits[:,:, 1]) chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1] chunk_decision = (chunk_probabilities>prob_threshold) pos = chunk_decision * position_id[i : i + batchsize, :] pos = pos[pos>0].tolist() split_str_poses += [tokens.token_to_chars(p).start for p in pos] if left_seq_num > 0: batch_input = reshaped_input_ids[-left_seq_num:, :].to(self.device) attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device) output = self(input_ids=batch_input, attention_mask=attention_mask) logits = output['logits'][:, 1:-1,:]#delete cls and sep chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1] chunk_decision = (chunk_probabilities>prob_threshold) pos = chunk_decision * position_id[-left_seq_num:, :] pos = pos[pos>0].tolist() split_str_poses += [tokens.token_to_chars(p).start for p in pos] if left_token_num > 0: left_input_ids = torch.cat([torch.tensor([CLS]), input_ids[-left_token_num:], torch.tensor([SEP])]) left_input_ids = left_input_ids.unsqueeze(0).to(self.device) attention_mask = torch.ones(left_input_ids.shape[0], left_input_ids.shape[1]).to(self.device) output = self(input_ids=left_input_ids, attention_mask=attention_mask) logits = output['logits'][:, 1:-1,:]#delete cls and sep chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1] chunk_decision = (chunk_probabilities>prob_threshold) bias = token_num - (left_input_ids.shape[1] - 2) + 1 pos = (torch.where(chunk_decision)[1] + bias).tolist() split_str_poses += [tokens.token_to_chars(p).start for p in pos] substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses+[len(text)])] return substrings