from typing import Any, Dict, List

import torch
from transformers import AutoConfig, BertForTokenClassification, BertTokenizer


class EndpointHandler:
    """Inference-endpoint handler for the dejanseo/LinkBERT token-classification model.

    Splits long input text into tokenizer-sized chunks, runs token
    classification on each chunk, and reconstructs readable text from the
    wordpiece tokens.
    """

    def __init__(self, path: str = "dejanseo/LinkBERT"):
        """Load the fine-tuned model and its tokenizer.

        Args:
            path: Hub id or local directory of the fine-tuned checkpoint.
        """
        # Load the configuration saved alongside the fine-tuned weights.
        self.config = AutoConfig.from_pretrained(path)
        self.model = BertForTokenClassification.from_pretrained(path, config=self.config)
        self.model.eval()  # inference only — disables dropout etc.
        # NOTE(review): the tokenizer is loaded from the base checkpoint, not
        # from `path` — assumes the fine-tune kept the bert-large-cased vocab;
        # confirm against the training setup.
        self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased")

    def split_into_chunks(self, text: str, max_length: int = 510) -> List[str]:
        """Split `text` into chunks of at most `max_length` wordpiece tokens.

        510 leaves room for the [CLS]/[SEP] specials added at encode time
        (BERT's 512-token limit).

        Args:
            text: Raw input text.
            max_length: Maximum wordpiece tokens per chunk.

        Returns:
            List of chunk strings (may be empty for empty input).
        """
        tokens = self.tokenizer.tokenize(text)
        return [
            self.tokenizer.convert_tokens_to_string(tokens[i:i + max_length])
            for i in range(0, len(tokens), max_length)
        ]

    @staticmethod
    def _reconstruct(tokens: List[str], predictions: List[int]) -> str:
        """Rebuild readable text from wordpiece tokens, merging "##" continuations.

        BUGFIX: the previous version re-appended the current token whenever its
        prediction was 1 (`text.strip() + "" + token + ""`), so every token
        predicted as class 1 appeared twice in the output. The empty-string
        wrappers look like placeholders for lost markup (presumably link tags
        around predicted anchor text — TODO confirm intended output format);
        with empty wrappers the correct behavior is simply not to duplicate.

        Args:
            tokens: Wordpiece tokens with [CLS]/[SEP] already stripped.
            predictions: Per-token predicted class ids (same length as tokens);
                currently unused pending confirmation of the markup format.

        Returns:
            Space-joined text with subword pieces glued back together.
        """
        words: List[str] = []
        for token, _pred in zip(tokens, predictions):
            if token.startswith("##") and words:
                # Continuation wordpiece: glue onto the previous word, no space.
                words[-1] += token[2:]
            elif token.startswith("##"):
                # Degenerate case: continuation with no preceding token.
                words.append(token[2:])
            else:
                words.append(token)
        return " ".join(words)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run token classification over `data["inputs"]`.

        Args:
            data: Request payload; the text to process is under the
                "inputs" key (defaults to "" when missing).

        Returns:
            A single-element list: [{"text": reconstructed_text}].
        """
        text = data.get("inputs", "")
        chunk_texts: List[str] = []
        for chunk in self.split_into_chunks(text):
            encoded = self.tokenizer(chunk, return_tensors="pt", add_special_tokens=True)
            input_ids = encoded["input_ids"]
            with torch.no_grad():
                outputs = self.model(input_ids)
            predictions = torch.argmax(outputs.logits, dim=-1)
            # Drop [CLS] and [SEP] before reconstructing the surface text.
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])[1:-1]
            preds = predictions[0][1:-1].tolist()
            chunk_texts.append(self._reconstruct(tokens, preds).strip())
        return [{"text": " ".join(chunk_texts)}]