|
import streamlit as st |
|
from PIL import Image |
|
from transformers import pipeline |
|
|
|
|
|
@st.cache_resource
def load_classifier():
    """Build the zero-shot image-classification pipeline once and reuse it.

    Streamlit re-executes this whole script on every widget interaction;
    caching the pipeline with ``st.cache_resource`` keeps the model weights
    from being reloaded on each rerun.
    """
    # The task must be "zero-shot-image-classification": the plain
    # "zero-shot-classification" task is the text (NLI) pipeline and cannot
    # accept a PIL image as input.
    return pipeline(
        task="zero-shot-image-classification",
        model="Balajim57/zero-shot-vitb32",
    )


# Kept as a module-level name for compatibility; repeated reruns get the
# cached pipeline object rather than building a new one.
classifier = load_classifier()


def predict(image, prompt1, prompt2, prompt3):
    """Rank the user's prompts as candidate labels for *image*.

    Args:
        image: The PIL image to classify.
        prompt1, prompt2, prompt3: Candidate label texts. Blank or
            whitespace-only prompts are ignored.

    Returns:
        The non-blank prompts (whitespace-stripped), ordered from highest to
        lowest classifier score.

    Raises:
        ValueError: If every prompt is blank.
    """
    labels = [p.strip() for p in (prompt1, prompt2, prompt3) if p and p.strip()]
    if not labels:
        raise ValueError("At least one non-empty prompt is required.")

    # The image pipeline returns a list of {"score": ..., "label": ...} dicts,
    # not the {"labels": [...], "scores": [...]} dict of the text task.
    results = classifier(image, candidate_labels=labels)
    ranked = sorted(results, key=lambda r: r["score"], reverse=True)
    return [r["label"] for r in ranked]
|
|
|
|
|
st.title("Zero-Shot Image Classification")

uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

prompt1 = st.text_input("Prompt 1")

prompt2 = st.text_input("Prompt 2")

prompt3 = st.text_input("Prompt 3")

if uploaded_image is not None:
    try:
        image = Image.open(uploaded_image)
        # Image.open only reads the header; load() forces a full decode so a
        # corrupt or truncated upload is reported here instead of crashing
        # later. Pillow signals unreadable images with OSError subclasses.
        image.load()
    except OSError:
        st.error("Could not read the uploaded file as an image.")
    else:
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; switch once the pinned version allows.
        st.image(image, caption="Uploaded Image", use_column_width=True)

        if st.button("Classify"):
            # Blank prompts are meaningless as class labels; require at least one.
            if not any(p.strip() for p in (prompt1, prompt2, prompt3)):
                st.warning("Enter at least one prompt before classifying.")
            else:
                with st.spinner("Classifying..."):
                    results = predict(image, prompt1, prompt2, prompt3)

                st.write("Classification Results:")

                st.write(results)