ResNet-50 Embeddings Only
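This is a modified version of the standard ResNet-50 architecture in which the final fully connected (classification) layer has been removed.
The model therefore outputs the 2048-dimensional pooled feature vector that would otherwise feed the classifier, which you can use as an image embedding.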
NB: The output has shape (batch_size, 2048, 1, 1), e.g. (1, 2048, 1, 1) for a single image, so you may want to flatten it to (batch_size, 2048).
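Flattening amounts to keeping the batch axis and collapsing the trailing (2048, 1, 1) axes into one. A minimal NumPy sketch, assuming a zero-filled stand-in array in place of real model output for a batch of two images:

```python
import numpy as np

# Stand-in for the ONNX output: 2 images, 2048 channels, 1x1 spatial map
raw = np.zeros((2, 2048, 1, 1), dtype=np.float32)

# Keep the batch axis, collapse everything after it into a single axis
flat = raw.reshape(raw.shape[0], -1)
print(flat.shape)  # (2, 2048)

# Equivalent for this shape: drop the two singleton spatial axes explicitly
squeezed = raw.squeeze(axis=(2, 3))
print(squeezed.shape)  # (2, 2048)
```

`reshape(batch_size, -1)` is preferable to `reshape(-1)`, since the latter would also merge the batch axis when more than one image is passed.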
Example
import onnxruntime
from PIL import Image
from torchvision import transforms


def load_and_preprocess_image(image_path):
    # Standard ImageNet preprocessing (resize, center crop, normalize), as used in training
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Open the image file and force 3 channels (grayscale or RGBA images would break Normalize)
    img = Image.open(image_path).convert("RGB")
    # Preprocess the image into a (3, 224, 224) float32 tensor
    img_preprocessed = preprocess(img)
    # Add a batch dimension and convert to a NumPy array for ONNX Runtime
    return img_preprocessed.unsqueeze(0).numpy()


onnx_model_path = "resnet50_embeddings.onnx"
# Select the CPU provider explicitly; GPU builds of onnxruntime require a providers list
session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

# Load and preprocess an image (replace with your image path)
image_path = "disco-ball.jpg"
input_data = load_and_preprocess_image(image_path)

# Run inference
outputs = session.run(None, {input_name: input_data})
# The output is a single tensor (the embeddings) of shape (batch_size, 2048, 1, 1)
embeddings = outputs[0]
# Flatten the embeddings to shape (batch_size, 2048)
embeddings = embeddings.reshape(embeddings.shape[0], -1)
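Once flattened, embeddings are commonly compared with cosine similarity, for example to find the most similar image in a collection. A minimal sketch, assuming embeddings of shape (n, 2048) as produced above; the random `query` and `gallery` arrays are hypothetical stand-ins for real model outputs, and `cosine_similarity` is a helper written here, not part of this model or of onnxruntime.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between rows of a (n, d) and rows of b (m, d)."""
    # L2-normalize each row so that a dot product equals the cosine of the angle
    a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_norm @ b_norm.T


rng = np.random.default_rng(0)
# Hypothetical stand-ins: one query embedding and a gallery that contains it
query = rng.standard_normal((1, 2048)).astype(np.float32)
gallery = np.vstack([query, rng.standard_normal((3, 2048)).astype(np.float32)])

scores = cosine_similarity(query, gallery)
print(scores.shape)  # (1, 4)
print(int(scores.argmax()))  # 0, the query matches itself
```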
Model tree for jxtc/resnet-50-embeddings
Base model: microsoft/resnet-50