# laion-nllb/app.py
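# Streamlit demo: zero-shot image classification in 201 languages. A CLIP image
# processor and an NLLB tokenizer prepare the inputs for a quantized ONNX model.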
import random
import onnxruntime
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from lang_map import langs
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor
st.set_page_config(layout="wide")
options = list(langs.keys())
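# Lightweight session-state container so the class rows and their values persist
# across Streamlit reruns.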
class SessionState:
    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)


def get_state(**kwargs):
    if "session_state" not in st.session_state:
        st.session_state["session_state"] = SessionState(**kwargs)
    return st.session_state["session_state"]
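# One class entry: a language dropdown next to a free-text field for the class label.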
def add_selectbox_and_input(key):
    col1, col2 = st.columns(2)
    with col1:
        select = st.selectbox("Select a language", options, key=f"{key}_select")
    with col2:
        user_input = st.text_input("Input text", key=f"{key}_text")
    state.inputs[key] = (select, user_input)
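# Shared UI state: the number of class rows and the (language, text) pair in each row.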
state = get_state(count=1, inputs={})
st.title("Zero-shot image classification with CLIP in 201 languages")
col1, col2 = st.columns(2)
image: Image.Image = None
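# Left column: image upload and preview.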
with col1:
    st.subheader("Image")
    uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image.", use_column_width=True)
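# Encode the uploaded image and the entered class labels, run the quantized ONNX
# model, and plot the per-class probabilities in the right column.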
def process():
    session_options = onnxruntime.SessionOptions()
    session_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    onnx_path = "model-quant.onnx"
    ort_session = onnxruntime.InferenceSession(onnx_path, session_options)
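    # CLIP image processor for the pixels, NLLB tokenizer for the multilingual class labels.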
    processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    ).image_processor
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    image_inputs = processor(images=image, return_tensors="pt")
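    # Collect the (language, class label) pairs from the UI and map language names to NLLB codes.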
    classes = []
    languages = []
    for key, value in state.inputs.items():
        languages.append(str(value[0]))
        classes.append(str(value[1]))
    languages = [langs[lang] for lang in languages]
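    # Tokenize each label with its own source-language special tokens, padding to a
    # fixed length so the encodings can be stacked into a single batch.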
    input_ids = []
    attention_mask = []
    for i, _ in enumerate(languages):
        tokenizer.set_src_lang_special_tokens(languages[i])
        input = tokenizer.batch_encode_plus(
            [classes[i]],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
        )
        input_ids.append(input["input_ids"])
        attention_mask.append(input["attention_mask"])
    input_ids = torch.concat(input_ids, dim=0)
    attention_mask = torch.concat(attention_mask, dim=0)
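    # Run the ONNX model on the image and the batched text inputs.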
    ort_inputs = {
        "pixel_values": image_inputs["pixel_values"].numpy(),
        "input_ids": input_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    ort_outputs = ort_session.run(None, ort_inputs)
    logits = torch.tensor(ort_outputs[0])
    probabilities = logits.softmax(dim=-1).squeeze().detach().numpy()
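    # Show the class probabilities as a horizontal bar chart in the right column.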
    chart_data = pd.DataFrame({"Class": classes, "Probability": probabilities})
    chart_data = chart_data.sort_values(by=["Probability"], ascending=True)
    fig = px.bar(chart_data, x="Probability", y="Class", orientation="h")
    with col2:
        st.subheader("Predictions")
        st.write(fig)
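# Right column: class inputs, the "Add class" button, and the "Generate" trigger.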
with col2:
    st.subheader("Classes")
    add_selectbox_and_input("Input 1")
    for i in range(2, state.count + 1):
        add_selectbox_and_input(f"Input {i}")
    if st.button("Add class"):
        state.count += 1
        add_selectbox_and_input(f"Input {state.count}")
    st.markdown("""---""")
    if st.button("Generate"):
        with st.spinner("Processing the data"):
            process()