Charm_10 / vision_processing.py
GeminiFan207's picture
Upload 12 files
18fa92b verified
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import logging
import os
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class VisionProcessingModel(nn.Module):
def __init__(self, model_name="resnet50", num_classes=1000, top_k=5):
super(VisionProcessingModel, self).__init__()
# Load pre-trained model
self.model = self._load_pretrained_model(model_name, num_classes)
self.model.eval() # Set the model to evaluation mode
# Automatically detect device (GPU/CPU)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
logger.info(f"Model loaded on device: {self.device}")
# Number of top predictions to return
self.top_k = top_k
# Image transformations
self.transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def _load_pretrained_model(self, model_name, num_classes):
"""Helper function to load a pre-trained model."""
if model_name == "resnet50":
return models.resnet50(pretrained=True)
elif model_name == "efficientnet_b0":
return models.efficientnet_b0(pretrained=True)
elif model_name == "mobilenet_v2":
return models.mobilenet_v2(pretrained=True)
else:
raise ValueError(f"Unsupported model: {model_name}")
def forward(self, image):
"""Forward pass through the model."""
image = image.to(self.device)
with torch.no_grad():
outputs = self.model(image)
return outputs
def process_image(self, image_path):
"""Process an image and get predictions."""
try:
# Load image
image = Image.open(image_path).convert('RGB')
# Apply transformations and add batch dimension
image = self.transform(image).unsqueeze(0)
# Get predictions from model
outputs = self.forward(image)
# Convert raw output to probabilities (Softmax)
probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
# Get top-k categories
top_k_prob, top_k_catid = torch.topk(probabilities, self.top_k)
return top_k_prob, top_k_catid
except Exception as e:
logger.error(f"Error processing image: {e}")
return None, None
def get_category_labels(self, category_ids):
"""Map category IDs to human-readable labels."""
# Load ImageNet class labels
labels_path = os.getenv("IMAGENET_LABELS_PATH", "path/to/imagenet_labels.txt")
with open(labels_path, "r") as f:
labels = [line.strip() for line in f.readlines()]
# Map category IDs to labels
return [labels[cat_id] for cat_id in category_ids]
def enhance_vision_processing(self, image_path):
"""Enhance vision capabilities by extracting top-k predictions."""
top_k_prob, top_k_catid = self.process_image(image_path)
if top_k_prob is not None and top_k_catid is not None:
# Convert tensors to lists
top_k_prob = top_k_prob.tolist()
top_k_catid = top_k_catid.tolist()
# Get human-readable labels
category_labels = self.get_category_labels(top_k_catid)
return top_k_prob, top_k_catid, category_labels
else:
return None, None, None