import base64
import io
import os
from typing import Any, Dict, List

import timm
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms


class EndpointHandler:
    """Inference handler that classifies one image with a timm ``rexnet_150`` model.

    The model weights are loaded once at construction time. Each call decodes
    the request image, runs it through the network and returns the top-ranked
    predictions together with the full per-class probability table.
    """

    # Number of ranked predictions returned per request (capped at the number
    # of classes).
    TOP_K = 3

    def __init__(self, path=""):
        """Load the model weights and build the preprocessing pipeline.

        Args:
            path: Directory containing ``pytorch_model.bin``. An empty string
                loads ``pytorch_model.bin`` from the current working directory.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Output index i of the model is reported under class_names[i].
        self.class_names = ['Invalid', 'SDTI', 'Stage_I', 'Stage_II', 'Stage_III', 'Stage_IV', 'Unstageable']

        # Size the classifier head from the label list so the two cannot drift apart.
        self.model = timm.create_model('rexnet_150', pretrained=False, num_classes=len(self.class_names))

        # os.path.join("", name) == name, so an empty path needs no special case.
        model_path = os.path.join(path, "pytorch_model.bin")
        # Load on CPU first: the model is built on CPU and moved afterwards, so
        # mapping straight to the GPU would briefly hold a second copy there.
        # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _load_image(inputs: Any) -> "Image.Image":
        """Convert a supported request payload into an RGB PIL image.

        Accepts a base64 string (optionally with a ``data:...;base64,`` URL
        header), a ``{"image": <base64 string>}`` dict, or an already-decoded
        ``PIL.Image.Image``.

        Raises:
            ValueError: if the payload shape is unsupported or decoding fails.
        """
        if isinstance(inputs, dict) and "image" in inputs:
            image_data = inputs["image"]
        elif isinstance(inputs, str):
            image_data = inputs
        elif isinstance(inputs, Image.Image):
            # Some serving toolkits hand image/* requests over pre-decoded.
            return inputs.convert('RGB')
        else:
            raise ValueError("Invalid input format. Expected {'image': base64_string} or base64_string")

        if isinstance(image_data, str) and image_data.startswith("data:"):
            # Drop a data-URL header; the non-strict base64 decoder would
            # otherwise keep its letters and corrupt the payload.
            image_data = image_data.partition(",")[2]

        try:
            image_bytes = base64.b64decode(image_data)
            return Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception as e:
            raise ValueError(f"Failed to decode image: {str(e)}") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the image contained in ``data``.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``,
                or from ``data`` itself when there is no ``"inputs"`` key.

        Returns:
            A one-element list holding a dict with:
            ``predictions`` (top-k ``{"label", "score"}`` dicts, highest first),
            ``probabilities`` (every class name mapped to its softmax score),
            ``predicted_class`` and ``confidence`` (the top-1 label and score).

        Raises:
            ValueError: if the payload has an unsupported shape or the image
                cannot be decoded.
        """
        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        image = self._load_image(inputs)

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(image_tensor)
            probabilities = torch.nn.functional.softmax(logits, dim=1)

        k = min(self.TOP_K, len(self.class_names))
        top_prob, top_idx = torch.topk(probabilities, k)

        # One host transfer per tensor instead of a device sync per .item().
        top_scores = top_prob[0].tolist()
        top_labels = top_idx[0].tolist()
        class_scores = probabilities[0].tolist()

        predictions = [
            {"label": self.class_names[idx], "score": float(score)}
            for idx, score in zip(top_labels, top_scores)
        ]
        all_probs = {
            name: float(score) for name, score in zip(self.class_names, class_scores)
        }

        return [{
            "predictions": predictions,
            "probabilities": all_probs,
            "predicted_class": predictions[0]["label"],
            "confidence": predictions[0]["score"]
        }]