# ner_module.py
import torch
import time
from typing import List, Dict, Any, Tuple
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class NERModel:
    """
    A singleton class that manages NER model loading and prediction,
    ensuring the potentially large model is loaded only once.
    """
    _instance = None
    _model = None
    _tokenizer = None
    _pipeline = None
    _model_name = None  # Stores the model name used at first initialization

    @classmethod
    def get_instance(cls, model_name: str = "Davlan/bert-base-multilingual-cased-ner-hrl"):
        """
        Singleton accessor: return the existing instance or create a new one.
        The model_name argument is honored only on first initialization.
        """
        if cls._instance is None:
            logger.info(f"Creating new NERModel instance with model: {model_name}")
            cls._instance = cls(model_name)
        elif cls._model_name != model_name:
            logger.warning(f"NERModel already initialized with {cls._model_name}. Ignoring new model name {model_name}.")
        return cls._instance

    def __init__(self, model_name: str):
        """
        Initialize the model, tokenizer, and pipeline.
        Private constructor - use get_instance() instead.
        """
        if NERModel._instance is not None:
            raise Exception("This class is a singleton! Use get_instance() to get the object.")
        self.model_name = model_name
        NERModel._model_name = model_name  # Store the model name
        self._load_model()
        NERModel._instance = self  # Register the instance

    def _load_model(self):
        """Load the NER model and tokenizer from Hugging Face."""
        logger.info(f"Loading model: {self.model_name}")
        start_time = time.time()
        try:
            # Load tokenizer and model
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._model = AutoModelForTokenClassification.from_pretrained(self.model_name)
            # Put PyTorch models in evaluation mode (important for inference)
            if isinstance(self._model, torch.nn.Module):
                self._model.eval()
            # Create the NER pipeline
            self._pipeline = pipeline(
                "ner",
                model=self._model,
                tokenizer=self._tokenizer,
                # To use the pipeline's built-in grouping instead of
                # TextProcessor.combine_entities, pass aggregation_strategy="simple"
                # (grouped_entities=True is deprecated in recent transformers releases).
            )
            load_time = time.time() - start_time
            logger.info(f"Model '{self.model_name}' loaded successfully in {load_time:.2f} seconds.")
        except Exception as e:
            logger.error(f"Error loading model {self.model_name}: {e}")
            # Clean up partial loads so the instance is not left half-initialized
            self._tokenizer = None
            self._model = None
            self._pipeline = None
            # Re-raise the exception to signal failure
            raise

    def predict(self, text: str) -> List[Dict[str, Any]]:
        """
        Run NER prediction on the input text using the loaded pipeline.

        Args:
            text: The input string to perform NER on.

        Returns:
            A list of dictionaries, one per entity token identified by the
            pipeline (token-level, since entities are not grouped here).
        """
        if self._pipeline is None:
            logger.error("NER pipeline is not initialized. Cannot predict.")
            return []  # Return an empty list rather than raising
        if not text or not isinstance(text, str):
            logger.warning("Prediction called with empty or invalid text.")
            return []
        logger.debug(f"Running prediction on text: '{text[:100]}...'")  # Log a snippet
        try:
            # The pipeline handles tokenization and prediction
            results = self._pipeline(text)
            logger.debug(f"Prediction results: {results}")
            return results
        except Exception as e:
            logger.error(f"Error during NER prediction: {e}")
            return []  # Return an empty list on error
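
# A minimal usage sketch for NERModel (illustrative only: the first call
# downloads the checkpoint from Hugging Face, and the sample sentence and
# raw-output values below are assumed, not real model output):
#
#     model = NERModel.get_instance()
#     raw = model.predict("Angela Merkel lives in Berlin.")
#     # Each raw item is a token-level dict, e.g.:
#     # {'entity': 'B-PER', 'score': 0.999, 'word': 'Angela', 'start': 0, 'end': 6}
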
class TextProcessor:
    """
    Provides static methods for processing text for NER tasks, including
    combining subword entities and handling large texts via chunking.
    """

    @staticmethod
    def combine_entities(ner_results: List[Dict[str, Any]], original_text: str) -> List[Dict[str, Any]]:
        """
        Combine entities that may be split across subword tokens (B-TAG, I-TAG).
        This method assumes the pipeline did *not* group entities itself.

        Args:
            ner_results: The raw output from the NER pipeline (list of token dictionaries).
            original_text: The original input text, used to extract entity words accurately.

        Returns:
            A list of dictionaries, each representing a combined entity with
            'entity_type', 'start', 'end', 'score', and 'word'.
        """
        if not ner_results:
            return []

        combined_entities = []
        current_entity = None
        for token in ner_results:
            # Basic validation of token structure
            if not all(k in token for k in ('entity', 'start', 'end', 'score')):
                logger.warning(f"Skipping malformed token: {token}")
                continue

            # Skip 'O' tags (outside any entity)
            if token['entity'] == 'O':
                # If we were tracking an entity, finalize it before moving on
                if current_entity:
                    combined_entities.append(current_entity)
                    current_entity = None
                continue

            # Extract the entity type (e.g., 'PER', 'LOC') by removing the 'B-'/'I-' prefix
            entity_tag = token['entity']
            if entity_tag.startswith('B-') or entity_tag.startswith('I-'):
                entity_type = entity_tag[2:]
            else:
                # Handle tags without a B-/I- prefix (less common)
                entity_type = entity_tag

            # Start of a new entity ('B-'), or an 'I-' tag that does not continue
            # the entity currently being tracked
            if entity_tag.startswith('B-') or (entity_tag.startswith('I-') and (not current_entity or current_entity['entity_type'] != entity_type)):
                # Finalize the previous entity if it exists
                if current_entity:
                    combined_entities.append(current_entity)
                # Start the new entity
                current_entity = {
                    'entity_type': entity_type,
                    'start': token['start'],
                    'end': token['end'],
                    'score': float(token['score']),
                    'token_count': 1,  # Track token count for score averaging
                }
            # Continuation of the current entity ('I-' with matching type)
            elif entity_tag.startswith('I-') and current_entity and current_entity['entity_type'] == entity_type:
                # Extend the end position
                current_entity['end'] = token['end']
                # Update the running average of the score
                current_entity['score'] = (
                    current_entity['score'] * current_entity['token_count'] + float(token['score'])
                ) / (current_entity['token_count'] + 1)
                current_entity['token_count'] += 1
            # Handle unexpected sequences (e.g., an 'I-' tag with no preceding 'B-')
            else:
                logger.warning(f"Encountered unexpected token sequence at: {token}. Starting new entity.")
                if current_entity:
                    combined_entities.append(current_entity)
                # Try to create a new entity from this token
                current_entity = {
                    'entity_type': entity_type,
                    'start': token['start'],
                    'end': token['end'],
                    'score': float(token['score']),
                    'token_count': 1,
                }

        # Add the last tracked entity if it exists
        if current_entity:
            combined_entities.append(current_entity)

        # Extract the actual text 'word' for each combined entity
        for entity in combined_entities:
            try:
                # Clamp indices to the bounds of the original text
                start = max(0, min(entity['start'], len(original_text)))
                end = max(start, min(entity['end'], len(original_text)))
                entity['word'] = original_text[start:end].strip()
            except Exception as e:
                logger.error(f"Error extracting word for entity: {entity}, error: {e}")
                entity['word'] = "[Error extracting word]"
            finally:
                # Remove the internal helper key even if extraction failed
                entity.pop('token_count', None)

        # Sort entities by start position
        combined_entities.sort(key=lambda x: x['start'])
        logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
        return combined_entities
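
    # Worked example for combine_entities (illustrative values, not real model
    # output): with original_text = "Angela Merkel" and raw tokens
    #     {'entity': 'B-PER', 'start': 0, 'end': 6,  'score': 0.99}
    #     {'entity': 'I-PER', 'start': 7, 'end': 13, 'score': 0.97}
    # the method yields one combined entity with the token scores averaged:
    #     {'entity_type': 'PER', 'start': 0, 'end': 13, 'score': 0.98, 'word': 'Angela Merkel'}
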
    @staticmethod
    def process_large_text(text: str, model: NERModel, chunk_size: int = 512, overlap: int = 50) -> List[Dict[str, Any]]:
        """
        Process large text by splitting it into overlapping chunks, running NER
        on each chunk, and combining the results.

        Args:
            text: The large input text string.
            model: The initialized NERModel instance.
            chunk_size: The maximum size of each text chunk, in characters.
            overlap: The number of characters to overlap between consecutive chunks.

        Returns:
            A list of combined entity dictionaries for the entire text.
        """
        if not text:
            return []

        # Cap chunk_size at the tokenizer's advertised maximum length. Note this
        # is an approximation: chunk_size counts characters, while
        # model_max_length counts tokens, so a character chunk can still exceed
        # the token limit for token-dense text.
        if model._tokenizer and hasattr(model._tokenizer, 'model_max_length'):
            tokenizer_max_len = model._tokenizer.model_max_length
            if chunk_size > tokenizer_max_len:
                logger.warning(f"Requested chunk_size {chunk_size} exceeds model max length {tokenizer_max_len}. Using {tokenizer_max_len}.")
                chunk_size = tokenizer_max_len

        # Ensure the overlap is reasonable relative to the chunk size
        if overlap >= chunk_size // 2:
            logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
            overlap = chunk_size // 4

        logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
        chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
        logger.info(f"Split text into {len(chunks)} chunks.")

        all_raw_results = []
        total_processing_time = 0.0
        for i, (chunk_text, start_pos) in enumerate(chunks):
            logger.debug(f"Processing chunk {i+1}/{len(chunks)} (start_pos: {start_pos}, length: {len(chunk_text)})")
            start_time = time.time()
            # Get raw predictions for the current chunk
            raw_results_chunk = model.predict(chunk_text)
            chunk_processing_time = time.time() - start_time
            total_processing_time += chunk_processing_time
            logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")
            # Shift entity positions so they refer to the original text
            for result in raw_results_chunk:
                if 'start' in result and 'end' in result:
                    result['start'] += start_pos
                    result['end'] += start_pos
                else:
                    logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")
            all_raw_results.extend(raw_results_chunk)

        logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
        logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}")

        # Combine subword tokens from all chunks into whole entities
        combined_entities = TextProcessor.combine_entities(all_raw_results, text)

        # Deduplicate entities produced by overlapping chunk regions.
        # Two entities are considered duplicates if they have the same type and
        # overlap by more than 50% of the shorter entity's length; the one with
        # the higher score is kept.
        unique_entities = []
        for entity in combined_entities:
            is_duplicate = False
            entity_length = entity['end'] - entity['start']
            for existing in unique_entities:
                if existing['entity_type'] == entity['entity_type']:
                    # Measure the overlap between the two spans
                    overlap_start = max(entity['start'], existing['start'])
                    overlap_end = min(entity['end'], existing['end'])
                    if overlap_start < overlap_end:  # The spans overlap
                        overlap_length = overlap_end - overlap_start
                        shorter_length = min(entity_length, existing['end'] - existing['start'])
                        # Significant overlap: more than 50% of the shorter span
                        if overlap_length > 0.5 * shorter_length:
                            is_duplicate = True
                            # Keep whichever entity has the higher score
                            if entity['score'] > existing['score']:
                                # Replace the existing entity with this one
                                unique_entities.remove(existing)
                                is_duplicate = False
                            break
            if not is_duplicate:
                unique_entities.append(entity)

        logger.info(f"Final number of unique combined entities: {len(unique_entities)}")
        return unique_entities
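
    # Overlap-dedup example (illustrative): two PER spans (10, 20) and (15, 22)
    # overlap on (15, 20), i.e. 5 characters; the shorter span is 7 long and
    # 5 > 0.5 * 7, so the lower-scoring of the two spans is dropped.
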
    @staticmethod
    def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]:
        """
        Split text into overlapping chunks, trying to respect word boundaries.

        Args:
            text: The input text string.
            chunk_size: The target maximum size of each chunk, in characters.
            overlap: The desired overlap between consecutive chunks.

        Returns:
            A list of (chunk_text, start_position_in_original_text) tuples.
        """
        if not text:
            return []
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if chunk_size <= overlap:
            raise ValueError("chunk_size must be greater than overlap")

        chunks = []
        start = 0
        text_len = len(text)
        while start < text_len:
            # Determine the ideal end position
            end = min(start + chunk_size, text_len)
            # If we've reached the end of the text, take whatever is left
            if end >= text_len:
                chunks.append((text[start:], start))
                break
            # Search backwards from the end for whitespace so words aren't cut
            split_pos = -1
            for i in range(end - 1, max(start, end - overlap) - 1, -1):
                if text[i].isspace():
                    split_pos = i + 1  # Split just after the whitespace
                    break
            # If no suitable split point was found, use the calculated end
            if split_pos == -1 or split_pos <= start:
                actual_end = end
            else:
                actual_end = split_pos
            # Add the chunk
            chunks.append((text[start:actual_end], start))
            # Advance the start position, ensuring forward progress
            next_start = actual_end - overlap
            if next_start <= start:
                next_start = start + 1
            start = next_start
        return chunks
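

if __name__ == "__main__":
    # Quick manual smoke test: a sketch that assumes torch/transformers are
    # installed and the default checkpoint can be downloaded. The sample text
    # and chunking parameters are illustrative only.
    sample_text = (
        "Angela Merkel met Emmanuel Macron in Berlin. "
        "Later she travelled to Paris for talks with the French president."
    )
    ner_model = NERModel.get_instance()
    entities = TextProcessor.process_large_text(sample_text, ner_model, chunk_size=128, overlap=20)
    for ent in entities:
        print(f"{ent['entity_type']:>5}  {ent['score']:.3f}  {ent['word']}")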