| | """ResNet inference service implementation.""" |
| |
|
| | import base64 |
| | import os |
| | from io import BytesIO |
| |
|
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from transformers import AutoImageProcessor, ResNetForImageClassification |
| |
|
| | from app.core.logging import logger |
| | from app.services.base import InferenceService |
| | from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse |
| |
|
| |
|
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification.

    Loads a locally cached Hugging Face ResNet checkpoint (no network
    download at runtime) and serves single-image classification requests,
    returning per-class log probabilities plus a fixed center-region
    localization mask.
    """

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Configure the service; weights are loaded lazily via load_model().

        Args:
            model_name: Hugging Face model identifier; also used as the
                subdirectory name under the local ``models`` directory.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._is_loaded = False
        # Checkpoint is expected on local disk under models/<model_name>.
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and model from the local checkpoint.

        Idempotent: calls after a successful load are no-ops.

        Raises:
            FileNotFoundError: if the model directory or its config.json
                is missing.
            RuntimeError: if loading completed without producing a model
                and processor.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        import warnings
        with warnings.catch_warnings():
            # transformers emits FutureWarnings during checkpoint loading;
            # suppress them for this one-time load only.
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )
        # Explicit check instead of `assert`, which is stripped under -O.
        if self.model is None or self.processor is None:
            raise RuntimeError("Failed to load model or processor")

        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image and return log probabilities.

        Args:
            request: request whose ``image.data`` field holds a
                base64-encoded image payload.

        Returns:
            PredictionResponse with ``logprobs`` (one per entry in
            ``Labels``, renormalized over that subset of the model's
            logits) and a fixed center-third ``localizationMask``.

        Raises:
            RuntimeError: if the model has not been loaded yet.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        # Type-narrowing only; is_loaded already guarantees these are set.
        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))

        # The processor expects 3-channel input; normalize grayscale,
        # palette, and RGBA images.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        inputs = self.processor(image, return_tensors="pt")

        with torch.no_grad():
            logits = self.model(**inputs).logits.squeeze()

        # Keep only the logits for the classes the API exposes, then
        # renormalize over that subset. BUG FIX: `dim` is now passed
        # explicitly — calling log_softmax without it is deprecated in
        # PyTorch (implicit-dim softmax) and raises a UserWarning.
        logprobs = torch.nn.functional.log_softmax(
            logits[:len(Labels)], dim=-1
        ).tolist()

        # Placeholder localization: mark the center third of the image.
        # NOTE(review): for dimensions not divisible by 3 the region is
        # slightly off-center — preserved from the original behavior.
        x = image.width // 3
        y = image.height // 3

        mask = np.zeros((image.height, image.width), dtype=np.uint8)
        mask[y:(2 * y), x:(2 * x)] = 1
        mask_obj = BinaryMask.from_numpy(mask)

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=mask_obj,
        )

    @property
    def is_loaded(self) -> bool:
        """Whether load_model() has completed successfully."""
        return self._is_loaded
| |
|