File size: 3,735 Bytes
18fa92b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import logging
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class VisionProcessingModel(nn.Module):
    def __init__(self, model_name="resnet50", num_classes=1000, top_k=5):
        super(VisionProcessingModel, self).__init__()
        
        # Load pre-trained model
        self.model = self._load_pretrained_model(model_name, num_classes)
        self.model.eval()  # Set the model to evaluation mode

        # Automatically detect device (GPU/CPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        logger.info(f"Model loaded on device: {self.device}")

        # Number of top predictions to return
        self.top_k = top_k

        # Image transformations
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _load_pretrained_model(self, model_name, num_classes):
        """Helper function to load a pre-trained model."""
        if model_name == "resnet50":
            return models.resnet50(pretrained=True)
        elif model_name == "efficientnet_b0":
            return models.efficientnet_b0(pretrained=True)
        elif model_name == "mobilenet_v2":
            return models.mobilenet_v2(pretrained=True)
        else:
            raise ValueError(f"Unsupported model: {model_name}")

    def forward(self, image):
        """Forward pass through the model."""
        image = image.to(self.device)
        with torch.no_grad():
            outputs = self.model(image)
        return outputs

    def process_image(self, image_path):
        """Process an image and get predictions."""
        try:
            # Load image
            image = Image.open(image_path).convert('RGB')

            # Apply transformations and add batch dimension
            image = self.transform(image).unsqueeze(0)

            # Get predictions from model
            outputs = self.forward(image)

            # Convert raw output to probabilities (Softmax)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

            # Get top-k categories
            top_k_prob, top_k_catid = torch.topk(probabilities, self.top_k)

            return top_k_prob, top_k_catid
        except Exception as e:
            logger.error(f"Error processing image: {e}")
            return None, None

    def get_category_labels(self, category_ids):
        """Map category IDs to human-readable labels."""
        # Load ImageNet class labels
        labels_path = os.getenv("IMAGENET_LABELS_PATH", "path/to/imagenet_labels.txt")
        with open(labels_path, "r") as f:
            labels = [line.strip() for line in f.readlines()]

        # Map category IDs to labels
        return [labels[cat_id] for cat_id in category_ids]

    def enhance_vision_processing(self, image_path):
        """Enhance vision capabilities by extracting top-k predictions."""
        top_k_prob, top_k_catid = self.process_image(image_path)

        if top_k_prob is not None and top_k_catid is not None:
            # Convert tensors to lists
            top_k_prob = top_k_prob.tolist()
            top_k_catid = top_k_catid.tolist()

            # Get human-readable labels
            category_labels = self.get_category_labels(top_k_catid)

            return top_k_prob, top_k_catid, category_labels
        else:
            return None, None, None