# app.py -- copied from the Hugging Face Hub file viewer.
# Last commit: "Update app.py" (a5e4816) by g8a9; file size 2.35 kB.
import streamlit as st
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer
import requests
from PIL import Image
import torch
# Hugging Face Hub identifier of the checkpoint. The same id is passed to
# from_pretrained() for the model, the feature extractor and the tokenizer.
CHECKPOINT = "g8a9/vit-geppetto-captioning"
@st.cache(allow_output_mutation=True)
def get_model():
    """Load the captioning model from ``CHECKPOINT`` and cache it.

    ``allow_output_mutation=True`` stops ``st.cache`` from hashing the
    returned object on every call to detect mutations. Hashing a large
    torch module is slow, and the model is a shared resource that callers
    are expected to touch (e.g. ``model.eval()``).

    Returns:
        The ``VisionEncoderDecoderModel`` loaded from ``CHECKPOINT``.
    """
    # NOTE(review): newer Streamlit releases offer ``st.cache_resource`` for
    # this use case -- consider switching once the deployed version is known.
    return VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
# Image preprocessor and text tokenizer loaded from the same CHECKPOINT as the
# model. Both are created at module level, outside the cached get_model().
feature_extractor = AutoFeatureExtractor.from_pretrained(CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
st.title("Image Captioning with ViT & GePpeTto 🇮🇹")

# --- Sidebar: settings shared by both decoding modes -----------------------
st.sidebar.markdown("## Generation parameters")
max_length = st.sidebar.number_input("Max length", value=20, min_value=1)
no_repeat_ngram_size = st.sidebar.number_input(
    "no repeat ngrams size", value=2, min_value=1
)
num_return_sequences = st.sidebar.number_input(
    "Generated sequences", value=3, min_value=1
)
gen_mode = st.sidebar.selectbox("Generation mode", ["beam search", "sampling"])

# --- Sidebar: mode-specific settings, collected into ``gen_params`` --------
# ``gen_params`` is later unpacked as keyword arguments of ``model.generate``.
if gen_mode == "beam search":
    # Beam search cannot return more sequences than it keeps beams, so the
    # beam size is never allowed to drop below ``num_return_sequences``.
    num_beams = st.sidebar.number_input(
        "Beam size",
        value=max(5, num_return_sequences),
        min_value=num_return_sequences,
    )
    early_stopping = st.sidebar.checkbox("Early stopping", value=True)
    gen_params = {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
    }
elif gen_mode == "sampling":
    do_sample = True
    # top_k == 0 leaves top-k filtering switched off.
    top_k = st.sidebar.number_input("top_k", value=30, min_value=0)
    # top_p is a probability mass, so it must be a float in [0, 1]; integer
    # bounds would make the widget return ints, which generation rejects.
    # 1.0 leaves nucleus filtering switched off.
    top_p = st.sidebar.number_input(
        "top_p", value=1.0, min_value=0.0, max_value=1.0
    )
    # The temperature has to be strictly positive, hence the non-zero minimum.
    temperature = st.sidebar.number_input(
        "temperature", value=0.7, min_value=0.01
    )
    gen_params = {
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
    }
def generate_caption(url):
    """Download the image at ``url`` and return a generated caption for it.

    Generation settings are read from the module-level sidebar values
    ``max_length``, ``no_repeat_ngram_size``, ``num_return_sequences`` and
    the mode-specific ``gen_params`` dictionary.

    Args:
        url: Address of the image to caption.

    Returns:
        The first decoded caption, as a string.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
    """
    # The context manager releases the streamed connection; ``convert``
    # forces PIL to read the full image before the response is closed.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        # Ask urllib3 to undo gzip/deflate content encoding on the raw stream.
        response.raw.decode_content = True
        image = Image.open(response.raw).convert("RGB")

    inputs = feature_extractor(image, return_tensors="pt")

    model = get_model()
    model.eval()

    # Use the values chosen in the sidebar instead of hard-coded constants,
    # so the widgets actually influence generation.
    generated_ids = model.generate(
        inputs["pixel_values"],
        max_length=max_length,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        **gen_params,
    )
    captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Only the first sequence is returned; the caller displays one caption.
    return captions[0]
# --- Main page: choose an image, preview it, caption it on request ---------
url = st.text_input(
    "Insert your URL",
    value="https://iheartcats.com/wp-content/uploads/2015/08/c84.jpg",
)
st.image(url)

# The button is True only on the rerun triggered by the click itself.
run_requested = st.button("Run captioning")
if run_requested:
    # Show a spinner while the image is fetched and the caption generated.
    with st.spinner("Processing image..."):
        caption = generate_caption(url)
    st.text(caption)