import os

import numpy as np
import onnx
import onnxruntime
import torch
import torchvision
import torchvision.models.feature_extraction

from .config_extractor import MODEL_CONFIG

# Allow multiple OpenMP runtimes to coexist (avoids duplicate-libiomp load errors
# when PyTorch and ONNX Runtime are used in the same process).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

class FeatureExtractor:
    """Class for extracting features from images using a pre-trained model."""

    def __init__(self, base_model, onnx_path=None):
        """Initialize the extractor, loading an existing ONNX model or exporting one.

        Args:
            base_model: str, the name of the base model (a key of MODEL_CONFIG)
            onnx_path: str, optional path to the ONNX model; defaults to
                "model/{base_model}_feature_extractor.onnx"
        """
        # set the base model
        self.base_model = base_model
        # get the number of features
        self.feat_dims = MODEL_CONFIG[base_model]["feat_dims"]
        # get the feature layer name
        self.feat_layer = MODEL_CONFIG[base_model]["feat_layer"]
        # set the default ONNX path if not provided
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"
        self.onnx_path = onnx_path
        self.onnx_session = None
        # initialize the transforms (needed for both ONNX and PyTorch inference)
        _, self.transforms = self.init_model(base_model)
        # load the ONNX model if it already exists, otherwise export it first
        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        else:
            print(
                f"ONNX model not found at {onnx_path}. "
                "Initializing PyTorch model and converting to ONNX..."
            )
            # initialize the PyTorch model
            self.model, _ = self.init_model(base_model)
            self.model.eval()
            self.device = torch.device("cpu")
            self.model.to(self.device)
            # create the target directory if it doesn't exist
            # (fall back to "." when the path has no directory component)
            os.makedirs(os.path.dirname(onnx_path) or ".", exist_ok=True)
            # convert to ONNX and load the newly created model
            self.convert_to_onnx(onnx_path)
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
            print(f"Successfully created and loaded ONNX model from {onnx_path}")

    def init_model(self, base_model):
        """Initialize the model for feature extraction.

        Args:
            base_model: str, the name of the base model

        Returns:
            model: torch.nn.Module, the feature extraction model
            transforms: torchvision.transforms.Compose, the image transformations
        """
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")
        # build the backbone and truncate it at the configured feature layer
        weights = MODEL_CONFIG[base_model]["weights"]
        model = torchvision.models.feature_extraction.create_feature_extractor(
            MODEL_CONFIG[base_model]["model"](weights=weights),
            [MODEL_CONFIG[base_model]["feat_layer"]],
        )
        # get the image transformations associated with the pre-trained weights
        transforms = weights.transforms()
        return model, transforms

    def extract_features(self, img):
        """Extract features from an image.

        Args:
            img: PIL.Image, the input image

        Returns:
            output: torch.Tensor, the extracted features
        """
        # apply the preprocessing transforms
        x = self.transforms(img)
        # add a batch dimension
        x = x.unsqueeze(0)
        # convert to numpy for ONNX Runtime
        x_numpy = x.numpy()
        # run inference with ONNX Runtime
        print("Running inference with ONNX Runtime")
        output = self.onnx_session.run(None, {"input": x_numpy})[0]
        # convert back to a torch tensor
        output = torch.from_numpy(output)
        return output

    def convert_to_onnx(self, save_path):
        """Convert the model to ONNX format and save it.

        Args:
            save_path: str, the path to save the ONNX model

        Returns:
            None
        """
        # create a dummy input tensor matching the expected input shape
        dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
        # export the model with a dynamic batch dimension
        torch.onnx.export(
            self.model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size"},
                "output": {0: "batch_size"},
            },
        )
        # verify the exported model
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        print(f"ONNX model saved to {save_path}")