|
|
| import torch |
| import json |
| import os |
| from transformers import AutoTokenizer, AutoModel |
| import torch.nn as nn |
|
|
class EmailClassificationModel(nn.Module):
    """Dual-head email classifier built on a pretrained transformer backbone.

    Two independent linear heads (category and subcategory) read the
    embedding of the first token ([CLS]) produced by the backbone.
    """

    def __init__(self, model_name, num_categories, num_subcategories, dropout=0.1):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden = self.backbone.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.category_head = nn.Linear(hidden, num_categories)
        self.subcategory_head = nn.Linear(hidden, num_subcategories)

    def forward(self, input_ids, attention_mask):
        """Return {'category_logits', 'subcategory_logits'} for a batch."""
        encoder_out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # First-token ([CLS]) embedding serves as the pooled sequence representation.
        cls_embedding = self.dropout(encoder_out.last_hidden_state[:, 0])
        return {
            'category_logits': self.category_head(cls_embedding),
            'subcategory_logits': self.subcategory_head(cls_embedding),
        }
|
|
class EndpointHandler:
    """Inference-endpoint handler for hierarchical email classification.

    Loads a trained EmailClassificationModel, its tokenizer, and the
    category/subcategory label vocabularies from ``path``, and serves
    predictions via ``__call__``.
    """

    def __init__(self, path=""):
        """Load checkpoint, tokenizer, weights and label maps from ``path``.

        Expects ``model_checkpoint.pt`` (with keys ``model_config``,
        ``model_state_dict``, ``categories``, ``subcategories``) and a
        ``tokenizer/`` directory under ``path``.
        """
        # SECURITY: torch.load unpickles arbitrary Python objects. The
        # checkpoint carries config dicts and label lists, so weights_only=True
        # is not usable here — only deploy checkpoints from a trusted source.
        checkpoint = torch.load(os.path.join(path, "model_checkpoint.pt"), map_location="cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(path, "tokenizer"))

        config = checkpoint['model_config']
        self.model = EmailClassificationModel(
            model_name=config['model_name'],
            num_categories=config['num_categories'],
            num_subcategories=config['num_subcategories']
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()  # disable dropout for inference

        self.categories = checkpoint['categories']        # index -> category label
        self.subcategories = checkpoint['subcategories']  # index -> subcategory label
        self.max_length = config['max_length']            # tokenizer truncation length

    def __call__(self, data):
        """Classify one or more email texts.

        Args:
            data: Request payload, ``{"inputs": str | list[str]}``.

        Returns:
            One result dict per input text — a bare dict for a single input,
            a list of dicts for a batch. Each result has ``text``,
            ``category`` and ``subcategory`` entries, the latter two carrying
            ``label`` and a 4-decimal ``confidence`` (softmax probability).
            On any failure an ``{"error": ...}`` dict is returned instead.
        """
        try:
            inputs = data.get("inputs", "")
            if isinstance(inputs, str):
                inputs = [inputs]

            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt"
            )

            with torch.no_grad():
                # BUGFIX: pass only the tensors forward() accepts. BERT-style
                # tokenizers also emit token_type_ids, so unpacking the whole
                # encoding (self.model(**encoded)) raised TypeError for them.
                outputs = self.model(
                    input_ids=encoded["input_ids"],
                    attention_mask=encoded["attention_mask"],
                )

            # torch.max returns (values, indices): the winning probability and
            # class per head in one pass (equivalent to separate argmax + max).
            category_probs = torch.softmax(outputs['category_logits'], dim=1)
            subcategory_probs = torch.softmax(outputs['subcategory_logits'], dim=1)
            category_confidence, category_preds = torch.max(category_probs, dim=1)
            subcategory_confidence, subcategory_preds = torch.max(subcategory_probs, dim=1)

            results = [
                {
                    "text": text,
                    "category": {
                        "label": self.categories[category_preds[i].item()],
                        "confidence": round(category_confidence[i].item(), 4)
                    },
                    "subcategory": {
                        "label": self.subcategories[subcategory_preds[i].item()],
                        "confidence": round(subcategory_confidence[i].item(), 4)
                    }
                }
                for i, text in enumerate(inputs)
            ]

            return results[0] if len(results) == 1 else results

        except Exception as e:
            # Boundary catch: the endpoint must always answer with JSON,
            # never propagate an exception to the serving layer.
            return {"error": f"Prediction failed: {str(e)}"}
|
|