NamedEntities / NamedEntity.py
wibberlet's picture
Update NamedEntity.py
c692027
raw
history blame
3.82 kB
"""
| **Abbreviation** | **Description** |
|------------------|-----------------|
| O | Outside of a named entity
| B-MISC | Beginning of a miscellaneous entity right after another miscellaneous entity
| I-MISC | Miscellaneous entity
| B-PER | Beginning of a person’s name right after another person’s name
| I-PER | Person’s name
| B-ORG | Beginning of an organization right after another organization
| I-ORG | Organization
| B-LOC | Beginning of a location right after another location
| I-LOC | Location
"""
import re
from enum import Enum

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
class DictKey(Enum):
    """
    Keys of the dictionaries returned by the token-classification ("ner") pipeline.

    Each member's value is the literal dictionary key, so a result field can be read with
    ``result.get(DictKey.WORD.value)``.
    """
    # Label of a single, un-grouped token. Per the transformers documentation this key is only
    # present when entities are NOT grouped; grouped results use ENTITY_GROUP instead.
    ENTITY = 'entity'
    # Label of a grouped entity (e.g. PER, ORG, LOC, MISC) as reported when the pipeline
    # aggregates tokens into whole entities.
    ENTITY_GROUP = 'entity_group'
    # Confidence score of the prediction.
    SCORE = 'score'
    # Token position in the sentence; per the transformers documentation only present for
    # un-grouped results.
    INDEX = 'index'
    # Text of the token / grouped entity.
    WORD = 'word'
    # Character offset in the analysed text where the entity starts.
    START = 'start'
    # Character offset in the analysed text where the entity ends.
    END = 'end'
class NER:
    """
    Named Entity Recognition over a piece of text using the ``dslim/bert-base-NER`` model.

    Creating an instance loads the tokenizer and model, runs the "ner" pipeline on the supplied
    text and stores the raw results together with the ordered and the de-duplicated entity lists.
    """

    def __init__(self, text_to_analyse):
        """
        The Constructor for the Named Entity Recognition class.
        :param text_to_analyse: The text in which to find named entities.
        :raises ValueError: If the pipeline object could not be created.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
        self.model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
        self.nlp = pipeline("ner", model=self.model, tokenizer=self.tokenizer, grouped_entities=True)
        if self.nlp is None:
            raise ValueError("Unable to load pipeline from DSLIM BERT model")
        self.text_to_analyse = text_to_analyse
        # Raw pipeline output: a list of dictionaries, one per detected entity.
        self.results = self.nlp(text_to_analyse)
        # Every detected entity text, in order of appearance (may contain repeats).
        self.all_entities = self.get_list_of_entities()
        # The attribute keeps its historical name, which hides the unique_entities() method on
        # the instance from here on. The method is therefore looked up on the class so that the
        # call keeps working even if __init__ runs again on the same object.
        self.unique_entities = type(self).unique_entities(self)
        # Both are populated by entity_markdown().
        self.markdown = None
        self.markdown_text = None

    def get_entity_value(self, key: "DictKey", item_index):
        """
        Extracts the value for a specific key (as an Enum) from a specific dictionary item in the list.
        :param key: DictKey Enum representing the key for which the value is required.
        :param item_index: Index of the item in the list to process. Negative indices count from the end.
        :return: Value for the given key in the specified dictionary item, or None if key is not found.
        :raises ValueError: If item_index is outside the bounds of the results list.
        """
        try:
            item = self.results[item_index]
        except IndexError:
            # Covers both ends of the list, so callers only ever see ValueError for a bad index.
            raise ValueError("The supplied list index is out of bounds") from None
        return item.get(key.value)

    def get_list_of_entities(self):
        """
        Returns a list of all entities in the original text, in the order they appear. There may be repeated
        entities in this list.
        :return: A list of all entities in the original text.
        """
        # One entry per result dictionary: the value stored under its "word" key.
        return [item.get(DictKey.WORD.value) for item in self.results]

    def entity_markdown(self):
        """
        Build a markdown rendering of the analysed text in which every entity is coloured red.

        The highlighted text is stored in ``self.markdown``; ``self.markdown_text`` holds the same text with
        markdown hard line breaks in place of plain newlines. Nothing is returned.
        """
        self.markdown = self._highlight(self.text_to_analyse, self.get_list_of_entities())
        # Two spaces at the end of line for markdown new line
        self.markdown_text = self.markdown.replace('\n', '  \n')

    @staticmethod
    def _highlight(text, entities):
        """
        Wrap every occurrence of each entity in ``text`` in a red ``<span>``.
        :param text: The text to decorate.
        :param entities: Entity strings to highlight; repeats, empty strings and None are ignored.
        :return: The decorated text, or ``text`` unchanged when there is nothing to highlight.
        """
        # De-duplicate while dropping empty/None entries, which cannot be matched meaningfully.
        candidates = list(dict.fromkeys(entity for entity in entities if entity))
        if not candidates:
            return text
        # Longest first, so an entity that contains another one (e.g. "New York" / "York") is
        # matched as a whole instead of being split by the shorter alternative.
        candidates.sort(key=len, reverse=True)
        pattern = re.compile("|".join(re.escape(entity) for entity in candidates))
        # A single pass over the original text: inserted markup is never re-scanned, so spans
        # cannot end up nested or corrupted by entities such as "span" or "red".
        return pattern.sub(lambda match: f'<span style = "color:red;">{match.group(0)}</span>', text)

    def unique_entities(self):
        """
        Return a list of all unique entities in the original text.
        :return: A list of unique entities, in order of first appearance.
        """
        # dict preserves insertion order, so this keeps the first occurrence of each entity.
        return list(dict.fromkeys(self.all_entities))