LinkBERT / handler.py
dejanseo's picture
Update handler.py
173d81c verified
raw
history blame
2.14 kB
from transformers import BertForTokenClassification, BertTokenizer, AutoConfig
import torch
from typing import Dict, List, Any
class EndpointHandler:
    """Callable inference handler around a BERT token-classification model.

    The handler tokenizes the request text, runs the model, and returns the
    text rebuilt from its WordPiece tokens with every token predicted as
    class 1 wrapped in ``<u>...</u>``.
    """

    def __init__(self, path: str = "dejanseo/LinkBERT"):
        """Load the configuration, model weights and tokenizer.

        Args:
            path: Model identifier or local directory handed to
                ``from_pretrained`` for both the config and the model.
        """
        # Load the configuration from the saved model.
        self.config = AutoConfig.from_pretrained(path)
        self.model = BertForTokenClassification.from_pretrained(
            path,
            config=self.config,
        )
        self.model.eval()  # Set model to evaluation mode
        # The tokenizer comes from the base "bert-large-cased" checkpoint,
        # not from `path`.
        self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Annotate the request text with the model's per-token predictions.

        Args:
            data: Request payload; the text is read from ``data["inputs"]``
                (empty string when the key is missing). Only the first
                sequence of the encoded batch is processed.
                NOTE(review): assumes ``inputs`` is a single string -- confirm
                against callers.

        Returns:
            A one-element list ``[{"text": annotated_text}]``.
        """
        inputs = data.get("inputs", "")

        # Truncate to the model's position limit so over-long inputs are cut
        # off instead of failing inside the model. When the config exposes no
        # limit, max_length=None lets the tokenizer use its own default.
        encoded = self.tokenizer(
            inputs,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
            max_length=getattr(self.config, "max_position_embeddings", None),
        )
        input_ids = encoded["input_ids"]

        with torch.no_grad():
            outputs = self.model(input_ids)
        predictions = torch.argmax(outputs.logits, dim=-1)

        # Drop the first and last positions ([CLS] and [SEP]).
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())[1:-1]
        labels = predictions[0][1:-1].tolist()

        return [{"text": self._annotate(tokens, labels)}]

    @staticmethod
    def _annotate(tokens: List[str], labels: List[int]) -> str:
        """Join WordPiece tokens into text, underlining tokens labelled 1.

        Continuation pieces (prefixed ``##``) are glued to the preceding
        piece without a space, and the ``##`` marker is stripped *before* the
        ``<u>`` wrapper is applied so it never leaks into the output.
        """
        pieces: List[str] = []
        for token, label in zip(tokens, labels):
            is_continuation = token.startswith("##")
            text = token[2:] if is_continuation else token
            if label == 1:  # Adjust this based on your classification needs
                text = f"<u>{text}</u>"
            # Only word-initial pieces are separated from what came before.
            if pieces and not is_continuation:
                pieces.append(" ")
            pieces.append(text)
        return "".join(pieces)
# Note: Ensure the path "dejanseo/LinkBERT" points to your model's location.
# If the model is saved locally, adjust the path accordingly.