hello-universe
commited on
Commit
โข
84c4b50
1
Parent(s):
ffb49f0
Add app, model loader, requirements.txt
Browse files- app.py +39 -0
- ram_plus_model.py +24 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
5 |
+
|
6 |
+
# ๋ชจ๋ธ ๋ฐ ์ค์ ๋ก๋
|
7 |
+
@st.cache_resource
|
8 |
+
def load_model():
|
9 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained("xinyu1205/recognize-anything-plus-model")
|
10 |
+
model = AutoModelForImageClassification.from_pretrained("xinyu1205/recognize-anything-plus-model")
|
11 |
+
model.eval()
|
12 |
+
return feature_extractor, model
|
13 |
+
|
14 |
+
# ์์ธก ํจ์
|
15 |
+
def predict(image, feature_extractor, model):
|
16 |
+
inputs = feature_extractor(images=image, return_tensors="pt")
|
17 |
+
with torch.no_grad():
|
18 |
+
outputs = model(**inputs)
|
19 |
+
|
20 |
+
logits = outputs.logits
|
21 |
+
# ์์ 5๊ฐ ํ๊ทธ ๋ฐํ
|
22 |
+
top_5 = torch.topk(logits, k=5)
|
23 |
+
return [model.config.id2label[i.item()] for i in top_5.indices[0]]
|
24 |
+
|
25 |
+
# Streamlit ์ฑ
|
26 |
+
st.title("RAM++ Image Tagging")
|
27 |
+
|
28 |
+
feature_extractor, model = load_model()
|
29 |
+
|
30 |
+
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
|
31 |
+
|
32 |
+
if uploaded_file is not None:
|
33 |
+
image = Image.open(uploaded_file)
|
34 |
+
st.image(image, caption='Uploaded Image', use_column_width=True)
|
35 |
+
|
36 |
+
if st.button('Get Tags'):
|
37 |
+
tags = predict(image, feature_extractor, model)
|
38 |
+
st.write("Predicted Tags:")
|
39 |
+
st.write(", ".join(tags))
|
ram_plus_model.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
2 |
+
from PIL import Image
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class RAMPlusModel:
|
6 |
+
def __init__(self):
|
7 |
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained("xinyu1205/recognize-anything-plus-model")
|
8 |
+
self.model = AutoModelForImageClassification.from_pretrained("xinyu1205/recognize-anything-plus-model")
|
9 |
+
self.model.eval()
|
10 |
+
|
11 |
+
def predict(self, image):
|
12 |
+
inputs = self.feature_extractor(images=image, return_tensors="pt")
|
13 |
+
with torch.no_grad():
|
14 |
+
outputs = self.model(**inputs)
|
15 |
+
|
16 |
+
logits = outputs.logits
|
17 |
+
predicted_classes = logits.argmax(-1)
|
18 |
+
|
19 |
+
# ์์ 5๊ฐ ํ๊ทธ ๋ฐํ (์ด ๋ถ๋ถ์ ๋ชจ๋ธ์ ์ค์ ์ถ๋ ฅ์ ๋ฐ๋ผ ์กฐ์ ํ์)
|
20 |
+
top_5 = torch.topk(logits, k=5)
|
21 |
+
return [self.model.config.id2label[i.item()] for i in top_5.indices[0]]
|
22 |
+
|
23 |
+
# ๋ชจ๋ธ ์ธ์คํด์ค ์์ฑ
|
24 |
+
model = RAMPlusModel()
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
transformers
|
4 |
+
Pillow
|