|
|
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
# Checkpoint identifier passed to ``from_pretrained`` for both the tokenizer
# and the sequence-classification model.
model_name = "SCANSKY/distilbertTourism-multilingual-rclassifier"

# Shared inference components, populated by load_model_components().
# Both are declared here so the globals exist before loading runs.
tokenizer = None
model = None
|
|
|
|
|
def load_model_components():
    """Initialise the shared ``tokenizer`` and ``model`` globals.

    Both objects are built from the ``model_name`` checkpoint via
    ``from_pretrained``. The classifier is switched to evaluation mode, and
    then a confirmation message is printed. Intended to run once at startup.
    """
    global tokenizer, model

    checkpoint = model_name
    tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
    model = DistilBertForSequenceClassification.from_pretrained(checkpoint)

    # Inference only: disable training-time behaviour such as dropout.
    model.eval()

    print("Model and tokenizer loaded successfully.")
|
|
|
|
|
|
|
|
# Load the tokenizer and model once, when this module is imported, so every
# call to predict_relevance reuses the same global instances.
load_model_components()
|
|
|
|
|
def predict_relevance(text, *, max_length=64):
    """Classify a single piece of text as relevant (class 1) or not.

    Uses the module-level ``tokenizer`` and ``model`` populated by
    ``load_model_components``.

    Args:
        text: The text to classify. ``None`` is treated as empty text.
        max_length: Token length the input is truncated/padded to.

    Returns:
        On success, a dict with these keys:

        - ``"prediction"``: the predicted class index (int).
        - ``"confidence"``: that class's softmax probability as a percentage.
        - ``"text"``: the input text.

        For empty or whitespace-only input, a dict with ``"error"`` and
        ``"text"`` instead.
    """
    if text is None:
        text = ""
    if not text.strip():
        # Echo the input so error results carry the same "text" key as
        # successful ones.
        return {"error": "Empty text provided.", "text": text}

    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Prefer the GPU when available. model.to() runs on every call, but it
    # only moves weights that are not already on `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    model.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # A single text gives a batch of one row. Reduce over the label axis
    # explicitly rather than taking argmax of the flattened tensor.
    probs = torch.softmax(logits, dim=-1)
    predicted_class = int(probs.argmax(dim=-1)[0].item())
    confidence = probs[0, predicted_class].item()

    return {
        "prediction": predicted_class,
        "confidence": float(confidence) * 100,
        "text": text,
    }
|
|
|
|
|
class EndpointHandler:
    """Inference handler that classifies every input line for relevance.

    The request payload is a dict whose ``"inputs"`` entry holds the text to
    classify. Each non-blank line is classified independently by
    ``predict_relevance``. It is reported as ``"RELEVANT"`` (class 1) or
    ``"IRRELEVANT"``, together with a confidence percentage.
    """

    def __init__(self, model_dir=None):
        """Create the handler.

        ``model_dir`` is accepted but unused. The tokenizer and model are
        module-level globals loaded once when this module is imported.
        """

    def preprocess(self, data):
        """Split the payload's ``"inputs"`` into stripped, non-blank lines.

        ``"inputs"`` may be a single string or a list/tuple of strings. Each
        string is split on newlines, and the lines are flattened in order.
        A missing or ``None`` entry (and ``None`` list items) yields no lines.
        """
        raw = data.get("inputs", "")
        if raw is None:
            return []
        chunks = raw if isinstance(raw, (list, tuple)) else [raw]

        lines = []
        for chunk in chunks:
            if chunk is None:
                continue
            lines.extend(line.strip() for line in chunk.split("\n") if line.strip())
        return lines

    def inference(self, lines):
        """Run ``predict_relevance`` on each line, preserving input order."""
        return [predict_relevance(line) for line in lines]

    def postprocess(self, outputs):
        """Convert raw prediction dicts into the response schema."""
        return [self._format_result(output) for output in outputs]

    @staticmethod
    def _format_result(output):
        """Shape one prediction (or error) dict for the response."""
        if "error" in output:
            # Error results report zero confidence and no relevance label.
            return {
                "text": output.get("text", ""),
                "error": output["error"],
                "confidence": 0,
            }
        return {
            "text": output["text"],
            "confidence": output["confidence"],
            "relevance": "RELEVANT" if output["prediction"] == 1 else "IRRELEVANT",
        }

    def __call__(self, data):
        """Handle one request: preprocess, then inference, then postprocess."""
        lines = self.preprocess(data)
        outputs = self.inference(lines)
        return self.postprocess(outputs)