import torch
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification

# Define label mappings (ensure these match the mappings used during training).
# NOTE: the two empty-string keys below appear to be special-token labels whose
# text was stripped (e.g. angle-bracketed markers removed as HTML). As written,
# the duplicate '' keys collide in the dict; restore the original token strings
# from your training setup.
label2id = {'': 0, 'other': 2, '': 1}
id2label = {v: k for k, v in label2id.items()}


def prepare_input(tokens, tokenizer, max_length=128):
    """Tokenize a pre-split word list into model-ready tensors."""
    encoding = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True
    )
    return encoding


def split_sentence(sentence):
    # Special tokens to preserve as single units.
    # NOTE: these strings also appear to have been stripped during extraction;
    # they should match the special tokens the model was trained with.
    special_tokens = ['', '']

    # Punctuation marks and symbols treated as standalone tokens
    punctuation = ',.?!;:()[]{}""\'`@#$%^&*+=|\\/<>-—–'

    result = []
    current_word = ''
    i = 0
    while i < len(sentence):
        # Check for special tokens first so they are kept intact.
        # The `token and` guard skips the empty placeholder strings above:
        # ''.startswith('') is always True and would otherwise loop forever.
        found_special = False
        for token in special_tokens:
            if token and sentence[i:].startswith(token):
                # Flush the word built so far
                if current_word:
                    result.append(current_word)
                    current_word = ''
                # Add the special token as its own unit
                result.append(token)
                i += len(token)
                found_special = True
                break
        if found_special:
            continue

        if sentence[i] in punctuation:
            # Flush the word built so far, then emit the punctuation mark
            if current_word:
                result.append(current_word)
                current_word = ''
            result.append(sentence[i])
        elif sentence[i].isspace():
            # Whitespace terminates the current word
            if current_word:
                result.append(current_word)
                current_word = ''
        else:
            # Accumulate characters of a regular word
            current_word += sentence[i]
        i += 1

    # Flush the final word, if any
    if current_word:
        result.append(current_word)

    return result


def predict(tokens, model, tokenizer, device, max_length=128):
    # `tokens` is a raw sentence string: normalize whitespace, lowercase,
    # and split it into word-level tokens.
    tokens = split_sentence(' '.join(tokens.lower().split()))

    # Prepare the input
    encoding = prepare_input(tokens, tokenizer, max_length=max_length)
    word_ids = encoding.word_ids(batch_index=0)  # subword -> word index map

    # Move tensors to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1).cpu().numpy()[0]

    # Decode subword tokens and map prediction ids to label strings
    tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids.cpu().numpy()[0])
    labels = [id2label.get(pred, 'O') for pred in predictions]

    # Align subword predictions with the original word-level tokens:
    # keep only the first subword's label for each word
    aligned_predictions = []
    previous_word_idx = None
    for token, label, word_idx in zip(tokens_decoded, labels, word_ids):
        if word_idx is None:  # [CLS], [SEP], and padding carry no word id
            continue
        if word_idx != previous_word_idx:
            aligned_predictions.append((tokens[word_idx], label))
        previous_word_idx = word_idx

    return aligned_predictions


def load_token_classifier(pretrained_token_classifier_path, device):
    # Load tokenizer and model from a fine-tuned checkpoint
    tokenizer = DistilBertTokenizerFast.from_pretrained(pretrained_token_classifier_path)
    token_classifier = DistilBertForTokenClassification.from_pretrained(pretrained_token_classifier_path)
    token_classifier.to(device)
    token_classifier.eval()  # disable dropout for deterministic inference
    return token_classifier, tokenizer
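

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original script. The checkpoint
    # path and example sentence are hypothetical placeholders; point the path
    # at your own fine-tuned DistilBERT token-classification checkpoint.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = load_token_classifier("path/to/fine-tuned-checkpoint", device)

    sentence = "Hello world, how are you today?"
    # predict() returns (word, label) pairs aligned to the split words
    for word, label in predict(sentence, model, tokenizer, device):
        print(f"{word}\t{label}")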