ram-plus-inference / ram_plus_model.py
hello-universe's picture
Add app, model loader, requirements.txt
84c4b50
raw
history blame
957 Bytes
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
import torch
class RAMPlusModel:
    """Thin inference wrapper around a pretrained image-classification checkpoint.

    The feature extractor and the model are both loaded from the same
    pretrained identifier, and the model is switched to evaluation mode
    because this class is only used for inference.
    """

    DEFAULT_MODEL_ID = "xinyu1205/recognize-anything-plus-model"

    def __init__(self, model_name=DEFAULT_MODEL_ID):
        """Load the feature extractor and the classification model.

        Args:
            model_name: Identifier passed to both ``from_pretrained`` calls.
                Defaults to the checkpoint this module was written for.
        """
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        # Inference only: put the module in evaluation mode once, up front.
        self.model.eval()

    def predict(self, image, top_k: int = 5) -> list:
        """Return the labels with the highest logits for ``image``.

        Args:
            image: Input handed to the feature extractor as ``images=``.
            top_k: Maximum number of labels to return. When the model exposes
                fewer classes than ``top_k``, all of them are returned.

        Returns:
            Label strings looked up in ``self.model.config.id2label``, ordered
            from highest to lowest logit. Only the first item of the batch is
            considered.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

        inputs = self.feature_extractor(images=image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # NOTE(review): the original author flagged that this top-k selection
        # may need adjusting to the model's actual output format.
        # torch.topk raises when k exceeds the size of the last dimension,
        # so clamp the request to the number of available classes.
        k = min(top_k, logits.shape[-1])
        best = torch.topk(logits, k=k)

        id2label = self.model.config.id2label
        return [id2label[idx] for idx in best.indices[0].tolist()]
# ๋ชจ๋ธ ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ
model = RAMPlusModel()