|
import streamlit as st |
|
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification |
|
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer |
|
from PIL import Image |
|
import torch |
|
from sklearn.preprocessing import MultiLabelBinarizer |
|
import numpy as np |
|
|
|
|
|
@st.cache_resource
def _load_models():
    """Load the image captioner and the multi-label text classifier.

    Streamlit re-executes this whole script on every user interaction.
    Caching the returned objects with ``st.cache_resource`` keeps one loaded
    copy of each model for the server process instead of reloading the
    weights from disk on every rerun.

    Returns:
        tuple: ``(captioning pipeline, tokenizer, classification model)``.
    """
    captioner = pipeline(
        "image-to-text", model="Salesforce/blip-image-captioning-large"
    )
    # NOTE(review): the tokenizer is taken from the base DistilBERT checkpoint
    # rather than the fine-tuned repo; presumably the fine-tuned model kept the
    # base vocabulary — confirm.
    text_tokenizer = AutoTokenizer.from_pretrained(
        "distilbert/distilbert-base-uncased"
    )
    classifier = AutoModelForSequenceClassification.from_pretrained(
        "swangfr/distilbert-multi-label-amazon"
    )
    return captioner, text_tokenizer, classifier


# Module-level names used by the inference code further down.
generator, tokenizer, model = _load_models()
|
|
|
|
|
# Binarizer that maps the classifier's 0/1 output vector back to category
# names; it is fitted with the category list later in the script.
multilabel = MultiLabelBinarizer()

# Page header.
st.title("Amazon Product Image classifier")
st.write("Classification for 24 categories")

# Stays None until the user uploads a file (checked before inference runs).
file_name = st.file_uploader(label="Upload a product image file")
|
|
|
# Sigmoid probability at or above which a category is reported.
PROBABILITY_THRESHOLD = 0.3

# Category names the classifier predicts (24 entries). MultiLabelBinarizer
# sorts its classes, so this assumes the model's output columns follow that
# sorted order — TODO confirm against the training code. Trailing spaces and
# the 'Hobbies' / 'Hobbies ' pair are kept verbatim: as distinct strings they
# are distinct classes.
CATEGORY_LABELS = [
    "Arts & Crafts",
    "Baby & Toddler Toys",
    "Building Toys",
    "Collectible Toys",
    "Dolls & Accessories",
    "Furniture",
    "Games",
    "Games & Accessories",
    "Hobbies",
    "Hobbies ",
    "Kids' Electronics",
    "Learning & Education",
    "Men",
    "Novelty & Gag Toys",
    "Party Supplies",
    "Play Vehicles",
    "PlayStation 4 ",
    "Rockets",
    "Sports & Outdoor Play",
    "Xbox One",
    "Clothing, Shoes & Jewelry ",
    "Home & Kitchen ",
    # The comma after this entry was previously missing, which silently
    # concatenated it with the next one and produced only 23 classes.
    "Video Games ",
    "Toys & Games ",
]

if file_name is not None:
    col1, col2 = st.columns(2)

    image = Image.open(file_name)
    col1.image(image, use_column_width=True)

    # Step 1: caption the uploaded image.
    text = generator(image)[0]["generated_text"]
    st.write(text)

    # Step 2: classify the caption. This is inference only, so skip autograd
    # bookkeeping. Pass the full tokenizer output (ids and attention mask)
    # rather than re-wrapping input_ids, which is already a tensor.
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Multi-label: an independent sigmoid per category, not a softmax.
    probs = torch.sigmoid(outputs.logits[0]).cpu()
    prediction = (probs >= PROBABILITY_THRESHOLD).numpy().astype(int)

    # Map the 0/1 vector back to category names.
    multilabel.fit([CATEGORY_LABELS])
    categories = multilabel.inverse_transform(prediction.reshape(1, -1))

    if categories[0]:
        st.write(categories)
    else:
        st.write(
            f"No category reached the {PROBABILITY_THRESHOLD} probability threshold."
        )