"""Custom Hugging Face Inference Endpoints handler for a BERT regression model."""

import math

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, BertPreTrainedModel


class BertForRegression(BertPreTrainedModel):
    """BERT encoder with a single-output linear head for regression."""

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # One scalar output per example, predicted from the pooled [CLS] representation.
        self.regressor = nn.Linear(config.hidden_size, 1)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        logits = self.regressor(outputs.pooler_output)

        loss = None
        if labels is not None:
            loss_fct = nn.MSELoss()
            # Squeeze only the trailing feature dimension so a batch of one
            # still compares shape (1,) predictions against shape (1,) labels.
            loss = loss_fct(logits.squeeze(-1), labels.float())

        return (loss, logits) if loss is not None else logits


class EndpointHandler:
    def __init__(self, path: str):
        self.model = BertForRegression.from_pretrained(path)
        self.tokenizer = BertTokenizer.from_pretrained(path)

    def __call__(self, data):
        self.model.eval()

        # The endpoint receives a JSON payload of the form {"inputs": "..."};
        # a bare string is also accepted for convenience.
        if isinstance(data, dict):
            text = data.get("inputs", "")
        else:
            text = data
        if not isinstance(text, str):
            raise ValueError("Input text must be a string under the 'inputs' key.")

        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

        with torch.no_grad():
            logits = self.model(**inputs)

        # Truncate (not round) the scalar prediction to two decimal places
        # and return it as a string, matching the endpoint's response format.
        prediction = logits[0].item()
        prediction = math.trunc(prediction * 100) / 100
        return str(prediction)
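

# A minimal local smoke test, assuming the model weights and tokenizer files
# live in the current directory (hypothetical path). In production, Hugging Face
# Inference Endpoints constructs EndpointHandler itself and invokes it per request.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    print(handler({"inputs": "Example sentence to score."}))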