from typing import Any, Tuple
import string

from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
import spacy
import torch
import gradio as gr


class NER:
    """Prompt-based zero-shot NER: each entity class is injected into a prompt
    that is prepended to every text chunk, and a token-classification pipeline
    tags spans in the combined input."""

    prompt: str = """
Identify entities in the text having the following classes:
{}
Text:
"""

    def __init__(
        self,
        model_name: str,
        sents_batch: int = 10,
        tokens_limit: int = 2048
    ):
        self.sents_batch = sents_batch
        self.tokens_limit = tokens_limit

        # spaCy is used only for sentence splitting.
        self.nlp: spacy.Language = spacy.load(
            'en_core_web_sm',
            disable=['lemmatizer', 'parser', 'tagger', 'ner']
        )
        self.nlp.add_pipe('sentencizer')

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=self.tokenizer,
            aggregation_strategy='first',
            batch_size=12,
            device=device
        )

    def get_last_sentence_id(self, i: int, sentences_len: int) -> int:
        """Index of the last sentence in the batch that starts at sentence `i`."""
        return min(i + self.sents_batch, sentences_len) - 1

    def chunkanize(self, text: str) -> Tuple[list[str], list[int]]:
        """Split the text into chunks of `sents_batch` sentences and return the
        chunks together with their character offsets in the original text."""
        doc = self.nlp(text)
        chunks = []
        starts = []
        sentences = list(doc.sents)

        for i in range(0, len(sentences), self.sents_batch):
            start = sentences[i].start_char
            starts.append(start)

            last_sentence = self.get_last_sentence_id(i, len(sentences))
            end = sentences[last_sentence].end_char

            chunks.append(text[start:end])
        return chunks, starts

    def get_inputs(
        self, chunks: list[str], labels: list[str]
    ) -> Tuple[list[str], list[int]]:
        """Build one pipeline input per (label, chunk) pair by prepending the
        label-specific prompt; prompt lengths are kept to undo the offset shift."""
        inputs = []
        prompts_lens = []
        for label in labels:
            prompt = self.prompt.format(label)
            prompts_lens.append(len(prompt))
            for chunk in chunks:
                inputs.append(prompt + chunk)
        return inputs, prompts_lens

    @classmethod
    def clean_span(
        cls, start: int, end: int, span: str
    ) -> Tuple[int, int, str]:
        """Recursively trim leading and trailing punctuation, keeping the
        character offsets in sync with the trimmed span."""
        if len(span) >= 1:
            if span[0] in string.punctuation:
                return cls.clean_span(start + 1, end, span[1:])
            if span[-1] in string.punctuation:
                return cls.clean_span(start, end - 1, span[:-1])
        return start, end, span.strip()

    def predict(
        self,
        text: str,
        inputs: list[str],
        labels: list[str],
        chunks_starts: list[int],
        prompts_lens: list[int],
        threshold: float
    ) -> list[dict[str, Any]]:
        outputs = []

        for idx, output in enumerate(self.pipeline(inputs)):
            # Inputs are ordered label by label, chunk by chunk, so the label is
            # recovered by integer division and the chunk by the remainder.
            label = labels[idx // len(chunks_starts)]
            # Map offsets from prompt+chunk space back into the original text.
            shift = chunks_starts[idx % len(chunks_starts)] - prompts_lens[idx // len(chunks_starts)]

            for ent in output:
                start = ent['start'] + shift + 1
                end = ent['end'] + shift

                start, end, span = self.clean_span(start, end, text[start:end])

                if not span:
                    continue

                if ent['score'] >= threshold:
                    outputs.append({
                        'span': span,
                        'start': start,
                        'end': end,
                        'entity': label
                    })
        return outputs

    def check_text(self, text: str) -> None:
        if not text:
            raise gr.Error('No text provided. Please provide text.')

    def check_labels(self, labels: list[str]) -> None:
        if not labels:
            raise gr.Error(
                'No labels provided. Please provide labels.'
                ' Multiple labels should be divided by commas.'
                ' See examples below.'
            )

    def check_tokens_limit(self, inputs: list[str]) -> None:
        tokens = 0
        for input_ in inputs:
            tokens += len(self.tokenizer.encode(input_))
        if tokens > self.tokens_limit:
            raise gr.Error(
                'Too many tokens! Please reduce the size of the text or the number of labels.'
                f' Max tokens count is: {self.tokens_limit}.'
            )

    def process(
        self, labels: str, text: str, threshold: float = 0.0
    ) -> dict[str, Any]:
        """Entry point: parse comma-separated labels, validate the inputs, chunk
        the text, run the pipeline, and return the text with detected entities."""
        labels_list = list({
            l for label in labels.split(',')
            if (l := label.strip())
        })
        self.check_labels(labels_list)
        self.check_text(text)

        chunks, chunks_starts = self.chunkanize(text)
        inputs, prompts_lens = self.get_inputs(chunks, labels_list)
        self.check_tokens_limit(inputs)
        outputs = self.predict(
            text, inputs, labels_list, chunks_starts, prompts_lens, threshold
        )

        return {"text": text, "entities": outputs}
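

# Minimal usage sketch, not part of the app wiring: the checkpoint name below is
# a hypothetical placeholder; substitute any token-classification model trained
# for this prompt format. The labels, text, and threshold are made-up demo values.
if __name__ == '__main__':
    ner = NER(model_name='your-org/your-prompt-ner-model')  # hypothetical checkpoint
    result = ner.process(
        labels='person, organization',
        text='Tim Cook announced new products at Apple headquarters.',
        threshold=0.5
    )
    for entity in result['entities']:
        print(entity)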