LinkBERT / handler.py
dejanseo's picture
Update handler.py
9a43936 verified
from transformers import BertForTokenClassification, BertTokenizer, AutoConfig
import torch
from typing import Dict, List, Any
class EndpointHandler:
    """Inference handler that marks predicted tokens in a piece of text.

    The handler runs a BERT token-classification model over the input text and
    returns the text with every token predicted as class ``1`` wrapped in
    ``<u>...</u>``.
    """

    def __init__(self, path: str = "dejanseo/LinkBERT"):
        """Load the model configuration, weights and tokenizer.

        Args:
            path: Model identifier or local directory handed to
                ``from_pretrained`` for both the config and the weights.
        """
        self.config = AutoConfig.from_pretrained(path)
        self.model = BertForTokenClassification.from_pretrained(
            path,
            config=self.config,
        )
        # Inference only: switch off training-time behaviour such as dropout.
        self.model.eval()
        # NOTE(review): the tokenizer comes from "bert-large-cased", not from
        # `path` -- presumably the checkpoint was fine-tuned from that base
        # vocabulary; confirm before changing.
        self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased")

    @staticmethod
    def _chunk_tokens(tokens: List[str], max_length: int) -> List[List[str]]:
        """Split *tokens* into runs of at most *max_length* tokens.

        A cut is moved back to the previous word start whenever the next chunk
        would otherwise begin with a ``##`` continuation piece, so words are
        not torn apart. A single word longer than *max_length* is still cut at
        the hard limit.

        Raises:
            ValueError: If *max_length* is smaller than 1.
        """
        if max_length < 1:
            raise ValueError("max_length must be a positive integer")

        chunks: List[List[str]] = []
        total = len(tokens)
        start = 0
        while start < total:
            end = min(start + max_length, total)
            if end < total:
                # tokens[boundary] would open the next chunk; walk back until
                # it is a word-initial piece.
                boundary = end
                while boundary > start and tokens[boundary].startswith("##"):
                    boundary -= 1
                if boundary > start:
                    end = boundary
                # else: the whole window is one over-long word -> hard cut.
            chunks.append(tokens[start:end])
            start = end
        return chunks

    @staticmethod
    def _render_tokens(tokens: List[str], predictions: List[int]) -> str:
        """Rebuild text from WordPiece tokens, underlining predicted ones.

        Each token is emitted exactly once. ``##`` continuation pieces are
        glued to the preceding word with the prefix removed, and a piece whose
        prediction is ``1`` is wrapped in ``<u>...</u>``.
        """
        words: List[str] = []
        for token, pred in zip(tokens, predictions):
            is_continuation = token.startswith("##")
            piece = token[2:] if is_continuation else token
            if pred == 1:
                piece = "<u>" + piece + "</u>"
            if is_continuation and words:
                words[-1] += piece
            else:
                words.append(piece)
        # Adjacent underlined pieces of the same word become a single span,
        # e.g. "<u>play</u><u>ing</u>" -> "<u>playing</u>".
        return " ".join(words).replace("</u><u>", "")

    def split_into_chunks(self, text: str, max_length: int = 510) -> List[str]:
        """Split *text* into strings of at most *max_length* tokens each.

        Args:
            text: Raw input text.
            max_length: Maximum number of tokenizer tokens per chunk. The
                default of 510 presumably leaves room for the two special
                tokens added at inference time -- confirm against the model's
                position limit.

        Returns:
            The chunk strings in input order; an empty list for empty text.

        Raises:
            ValueError: If *max_length* is smaller than 1.
        """
        tokens = self.tokenizer.tokenize(text)
        return [
            self.tokenizer.convert_tokens_to_string(chunk)
            for chunk in self._chunk_tokens(tokens, max_length)
        ]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run token classification on ``data["inputs"]``.

        Args:
            data: Request payload; the text is read from the ``"inputs"`` key
                and defaults to an empty string when the key is missing.

        Returns:
            A one-element list ``[{"text": ...}]`` holding the reconstructed
            text with predicted tokens wrapped in ``<u>...</u>``.
        """
        inputs = data.get("inputs", "")

        all_results = []
        for chunk in self.split_into_chunks(inputs):
            encoded = self.tokenizer(
                chunk,
                return_tensors="pt",
                add_special_tokens=True,
                # Safety net: re-tokenizing the chunk string may yield more
                # tokens than the chunk was cut from; keep the sequence within
                # the tokenizer's configured maximum length.
                truncation=True,
            )
            input_ids = encoded["input_ids"]
            with torch.no_grad():
                outputs = self.model(input_ids)
            predictions = torch.argmax(outputs.logits, dim=-1)

            # Drop the first and last positions (the special tokens added
            # above) from both the tokens and their predictions.
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])[1:-1]
            labels = predictions[0][1:-1].tolist()
            all_results.append(self._render_tokens(tokens, labels))

        # Chunks are processed independently and joined with a single space.
        return [{"text": " ".join(all_results)}]