|
import streamlit as st |
|
import torch |
|
from PIL import Image |
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
|
|
|
|
@st.cache_resource
def load_model(model_id: str = "xinyu1205/recognize-anything-plus-model"):
    """Load the feature extractor and classification model for *model_id*.

    Decorated with ``st.cache_resource`` so Streamlit loads the weights once
    per distinct ``model_id`` and reuses them across script reruns instead of
    reloading them on every interaction.

    Args:
        model_id: Hugging Face Hub repository id (or local directory) passed
            to ``from_pretrained`` for both the feature extractor and the
            model. Defaults to the RAM++ checkpoint this app was written for.

    Returns:
        A ``(feature_extractor, model)`` tuple; the model is in eval mode.
    """
    # NOTE(review): the RAM++ weights are published for the `recognize-anything`
    # (`ram`) package rather than as a transformers image-classification
    # checkpoint; confirm this repo actually loads through
    # AutoModelForImageClassification (it may need a different loader).
    # NOTE(review): recent transformers versions prefer AutoImageProcessor over
    # AutoFeatureExtractor for vision models; consider switching.
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    # Inference only: put dropout / normalization layers into eval behavior.
    model.eval()
    return feature_extractor, model
|
|
|
|
|
def predict(image, feature_extractor, model, top_k: int = 5) -> list[str]:
    """Return the labels of the ``top_k`` highest-scoring classes for *image*.

    Args:
        image: The input image. A PIL image in a non-RGB mode (e.g. RGBA or
            palette PNGs, grayscale) is converted to RGB first; any other
            object is passed to the feature extractor unchanged.
        feature_extractor: Callable invoked as
            ``feature_extractor(images=image, return_tensors="pt")``; its
            result is unpacked as keyword arguments to ``model``.
        model: Callable returning an object with a ``logits`` tensor, and
            exposing ``model.config.id2label`` mapping class index -> label.
        top_k: Maximum number of labels to return. It is clamped to the
            number of classes in the logits, so small label sets do not fail.

    Returns:
        Label strings ordered from highest to lowest logit, for the first
        item of the batch.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # Uploaded PNGs can be RGBA or palette ("P") images and grayscale images
    # have a single channel; convert so the processor always sees 3 channels.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    # torch.topk raises if k exceeds the size of the last dimension, so never
    # request more labels than the model scores.
    k = min(top_k, logits.shape[-1])
    top = torch.topk(logits, k=k)

    # indices[0]: only the first image of the batch is reported.
    return [model.config.id2label[idx.item()] for idx in top.indices[0]]
|
|
|
|
|
st.title("RAM++ Image Tagging")

feature_extractor, model = load_model()

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    # The upload is user-supplied: despite its extension it may be corrupt,
    # truncated, or not an image at all. Image.open only parses the header, so
    # force a full decode with load() to surface those failures here.
    # UnidentifiedImageError and truncated-file errors are OSError subclasses;
    # DecompressionBombError guards against absurdly large images.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except (OSError, Image.DecompressionBombError) as err:
        st.error(f"Could not read the uploaded file as an image: {err}")
        st.stop()

    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favor of use_container_width; switch once the pinned version supports it.
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Get Tags'):
        # Model inference can take a while; give the user visible feedback.
        with st.spinner("Tagging image..."):
            tags = predict(image, feature_extractor, model)
        st.write("Predicted Tags:")
        st.write(", ".join(tags))