import torch
from format_entity import format_entities
from transformers import DistilBertForTokenClassification, DistilBertTokenizer

DRIVE_BASE_PATH = "model/"
model_path = DRIVE_BASE_PATH
model = DistilBertForTokenClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained(model_path)


def predict_ner(input_text, max_length=128):
    # Split the input text into character-based chunks so long inputs fit
    # within the model's sequence limit. Note: this can split a word across
    # two chunks, which may affect entities on chunk boundaries.
    chunks = [input_text[i:i + max_length] for i in range(0, len(input_text), max_length)]

    # Collect the entities found in every chunk
    entities_all = []

    for chunk in chunks:
        # Tokenize the chunk
        inputs = tokenizer(chunk, return_tensors="pt", padding=True,
                           truncation=True, max_length=max_length)

        # Make predictions without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)

        # Pick the highest-scoring label id per token and map it to its name
        labels = outputs.logits.argmax(dim=2)
        predicted_labels = [model.config.id2label[label_id] for label_id in labels[0].tolist()]

        # Recover the tokens directly from the input ids so they stay aligned
        # with the predictions (decoding and then re-tokenizing, as before,
        # can change the tokenization and silently misalign tokens and labels)
        tokenized_text = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

        # Drop special tokens together with their labels
        token_label_pairs_chunk = [
            (token, label)
            for token, label in zip(tokenized_text, predicted_labels)
            if token not in ("[CLS]", "[SEP]", "[PAD]")
        ]

        # Group the token-level predictions into entity spans
        entities_chunk = [
            (pair["text"], pair["label"])
            for pair in format_entities(
                [pair[0] for pair in token_label_pairs_chunk],
                [pair[1] for pair in token_label_pairs_chunk],
            )
        ]

        # Strip WordPiece continuation markers ("##") from the entity text
        entities_chunk = [
            (entity[0].replace(" ##", "").replace("##", ""), entity[1])
            for entity in entities_chunk
        ]
        entities_all.extend(entities_chunk)

    return entities_all
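

# --- Minimal usage sketch. Assumes "model/" holds a fine-tuned DistilBERT
# token-classification checkpoint and that format_entity.format_entities
# returns dicts with "text" and "label" keys, as the code above expects. ---
if __name__ == "__main__":
    sample = "Barack Obama visited Paris with Angela Merkel last spring."
    for text, label in predict_ner(sample):
        print(f"{label}: {text}")
    # Illustrative output shape (actual labels depend on the checkpoint's
    # fine-tuned label set):
    #   PER: Barack Obama
    #   LOC: Paris
    #   PER: Angela Merkel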