from transformers import AutoFeatureExtractor, AutoModelForImageClassification | |
from PIL import Image | |
import torch | |
class RAMPlusModel:
    """Thin wrapper around a pretrained image-classification checkpoint.

    The feature extractor and the model are loaded once in ``__init__`` from
    ``MODEL_ID`` and the model is switched to eval mode. ``predict`` then
    returns the highest-scoring label names for an image.
    """

    # Hugging Face Hub identifier used for both the extractor and the model.
    MODEL_ID = "xinyu1205/recognize-anything-plus-model"

    def __init__(self):
        """Load the feature extractor and model, and put the model in eval mode."""
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.MODEL_ID)
        self.model = AutoModelForImageClassification.from_pretrained(self.MODEL_ID)
        self.model.eval()

    def predict(self, image, top_k=5):
        """Return the labels with the highest logits for ``image``.

        Args:
            image: Image input forwarded to the feature extractor as
                ``images=``.
            top_k: Maximum number of labels to return (default 5). It is
                clamped to the size of the model's last logits dimension so
                that models with fewer classes do not make ``torch.topk``
                fail.

        Returns:
            A list of label strings looked up in ``model.config.id2label``
            for the first item of the batch, ordered from highest to lowest
            logit.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        inputs = self.feature_extractor(images=image, return_tensors="pt")
        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # NOTE(review): the original author flagged that this selection may
        # need adjusting to the model's actual output format -- confirm that
        # ranking raw logits is the intended way to pick tags.
        k = min(top_k, logits.shape[-1])
        best = torch.topk(logits, k=k)

        id2label = self.model.config.id2label
        # Only the first batch row is reported, as in the original code.
        return [id2label[index] for index in best.indices[0].tolist()]
# Create the shared module-level model instance (loads the pretrained
# feature extractor and model when this module is imported).
model = RAMPlusModel()