import torch

from model import SimpleClassifier


class HFModel:
    """Callable inference wrapper around a trained ``SimpleClassifier``.

    The wrapper builds the classifier, loads its weights from disk, puts it
    in evaluation mode, and maps batches of numeric feature rows to
    ``{"label", "score"}`` dictionaries.
    """

    def __init__(self, model_path="simple_model.pt"):
        """Build the classifier and load its saved weights.

        Args:
            model_path: Path of the saved ``state_dict``. Defaults to
                ``"simple_model.pt"``; a relative path is resolved against
                the current working directory.
        """
        self.model = SimpleClassifier()
        # map_location="cpu" lets a checkpoint that was saved from a CUDA
        # device load on a CPU-only host. The model is created on the CPU
        # above, so the resulting parameters are the same either way.
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted checkpoints; consider weights_only=True if the minimum
        # supported torch version provides it.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Evaluation mode: disables training-only layer behavior.
        self.model.eval()

    def __call__(self, inputs):
        """Classify a batch of feature rows.

        Args:
            inputs: A list of feature rows (list of lists of numbers). The
                original code comment states each row is expected to hold
                10 numbers; the required width is ultimately decided by
                ``SimpleClassifier``. A single flat row is also accepted
                and treated as a batch of one.

        Returns:
            One dict per input row, in input order, with keys:
                ``"label"``: ``"positive"`` when the arg-max class index is
                    1, otherwise ``"negative"``.
                ``"score"``: the model output for the chosen class as a
                    Python float. No normalization is applied here, so this
                    is whatever the model emits, not necessarily a
                    probability.
            An empty batch yields an empty list.
        """
        X = torch.tensor(inputs, dtype=torch.float32)

        # An empty batch has nothing to classify; return early instead of
        # pushing a zero-row tensor through the model.
        if X.dim() > 0 and X.size(0) == 0:
            return []

        # Promote a single flat sample to a batch of one so the dim=1
        # reductions below are valid.
        if X.dim() == 1:
            X = X.unsqueeze(0)

        with torch.no_grad():
            logits = self.model(X)

        preds = torch.argmax(logits, dim=1)
        # Pick each row's output at its predicted class in one vectorized
        # call rather than indexing the tensor row by row.
        scores = logits.gather(1, preds.unsqueeze(1)).squeeze(1)

        return [
            {"label": "positive" if p == 1 else "negative", "score": s}
            for p, s in zip(preds.tolist(), scores.tolist())
        ]