mohitsha's picture
mohitsha HF staff
Update inference.py
81c7679 verified
raw history blame
No virus
1.19 kB
import requests
from PIL import Image
from optimum.amd.ryzenai import RyzenAIModelForImageClassification
from transformers import PretrainedConfig, pipeline
import timm
import torch
# Demo: classify a single COCO image with a quantized timm ResNet-18 ONNX model on Ryzen AI.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Use the response as a context manager so the connection is released. Fail fast on HTTP
# errors instead of handing an error page to PIL, and bound the wait with a timeout.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # With stream=True, .raw is the undecoded body; ask urllib3 to undo any
    # Content-Encoding (e.g. gzip) so PIL receives the actual image bytes.
    response.raw.decode_content = True
    # Image.open is lazy; convert() forces the full read while the stream is still open
    # and yields an RGB image (the timm transforms presumably expect 3 channels).
    image = Image.open(response.raw).convert("RGB")
quantized_model_path = "mohitsha/timm-resnet18-onnx-quantized-ryzen"
# The path and name of the runtime configuration file. A default version of this file can be
# found in the voe-4.0-win_amd64 folder of the Ryzen AI software installation package under
# the name vaip_config.json.
vaip_config = ".\\vaip_config.json"
model = RyzenAIModelForImageClassification.from_pretrained(quantized_model_path, vaip_config=vaip_config)
config = PretrainedConfig.from_pretrained(quantized_model_path)
# Build timm's eval-time preprocessing pipeline from the pretrained_cfg stored in the
# model's config, so inputs are prepared the way the model was configured for.
data_config = timm.data.resolve_data_config(pretrained_cfg=config.pretrained_cfg)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(image).unsqueeze(0)).logits  # unsqueeze single image into batch of 1
# Softmax over dim=1 (the class dimension of the batched logits), scaled to percentages,
# then keep the 5 highest-scoring classes.
top5_probabilities, top5_class_indices = torch.topk(torch.softmax(output, dim=1) * 100, k=5)
print(top5_class_indices)
print(top5_probabilities)