# Chris4K's picture
# Update ner_tool.py
# 28094fc
# Updated NamedEntityRecognitionTool in ner_tool.py
from transformers import pipeline
from transformers import Tool
class NamedEntityRecognitionTool(Tool):
    """Tool that labels named entities in a text and reports them per word.

    The underlying ``pipeline("ner")`` returns one record per token. Tokens
    whose ``word`` starts with ``"##"`` are continuations of the preceding
    token; they are merged back into that token's entry so that each entry
    in the result describes a whole word.
    """

    name = "ner_tool"
    description = "Identifies and labels various entities in a given text."
    inputs = ["text"]
    outputs = ["text"]

    def __call__(self, text: str):
        """Run named entity recognition on *text*.

        Args:
            text: The input string to analyse.

        Returns:
            A dictionary ``{"entities": [...]}`` where every item is a dict
            with the keys ``"word"``, ``"label"`` and ``"entity_text"``.
        """
        # Building the pipeline is expensive, so create it once per tool
        # instance and reuse it on subsequent calls. ``getattr`` keeps this
        # safe even if the instance was created without running ``__init__``.
        ner_analyzer = getattr(self, "_ner_analyzer", None)
        if ner_analyzer is None:
            ner_analyzer = pipeline("ner")
            self._ner_analyzer = ner_analyzer

        # Perform named entity recognition on the input text
        entities = ner_analyzer(text)
        word_entities = self._merge_sub_tokens(text, entities)

        # Print the identified word-level entities
        print(f"Word-level Entities: {word_entities}")
        return {"entities": word_entities}

    @staticmethod
    def _merge_sub_tokens(text, entities):
        """Collapse token-level NER records into word-level entries.

        Args:
            text: The string the records were produced from; used to slice
                out the surface form via the ``start``/``end`` offsets.
            entities: Iterable of dicts as produced by the NER pipeline. The
                keys ``entity``, ``word``, ``start``, ``end`` and ``index``
                are read; all of them are optional.

        Returns:
            A list of ``{"word", "label", "entity_text"}`` dicts. A merged
            word keeps the label of its first token.
        """
        merged = []
        # Character span of ``merged[-1]`` (None when offsets are unknown)
        # and the token index of the previously seen record.
        span_start = None
        span_end = None
        prev_index = None

        for entity in entities:
            label = entity.get("entity", "UNKNOWN")
            word = entity.get("word", "")
            start = entity.get("start")
            end = entity.get("end")
            index = entity.get("index")

            # Offsets may be absent or None; slicing with None would return
            # the whole text, so only slice when both offsets are usable.
            has_span = (
                start is not None and end is not None and 0 <= start <= end
            )
            entity_text = text[start:end].strip() if has_span else ""

            is_sub_token = word.startswith("##")
            # A continuation belongs to the previous entry only if it
            # directly follows it, either by token index or by touching
            # character spans. Otherwise its head token was not reported.
            follows_previous = bool(merged) and (
                (
                    index is not None
                    and prev_index is not None
                    and index == prev_index + 1
                )
                or (has_span and span_end is not None and start == span_end)
            )

            if is_sub_token and follows_previous:
                last = merged[-1]
                last["word"] += word[2:]
                if has_span and span_start is not None:
                    # Re-slice the full word so inner characters are exact.
                    span_end = end
                    last["entity_text"] = text[span_start:span_end].strip()
                else:
                    last["entity_text"] += entity_text
            else:
                if is_sub_token:
                    # Orphan continuation: report its surface text rather
                    # than the raw "##"-prefixed token.
                    word = entity_text or word[2:]
                merged.append(
                    {"word": word, "label": label, "entity_text": entity_text}
                )
                if has_span:
                    span_start, span_end = start, end
                else:
                    span_start, span_end = None, None

            prev_index = index

        return merged