"""Streamlit demo: caption an image from a URL with a ViT encoder-decoder model."""

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel

CHECKPOINT = "g8a9/vit-geppetto-captioning"


@st.cache_resource
def load_model():
    """Load and cache the model, feature extractor and tokenizer.

    Streamlit re-runs the whole script on every widget interaction;
    st.cache_resource ensures the (large) checkpoint is instantiated
    only once per server process instead of on every button click.
    """
    model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
    feature_extractor = AutoFeatureExtractor.from_pretrained(CHECKPOINT)
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
    model.eval()
    return model, feature_extractor, tokenizer


model, feature_extractor, tokenizer = load_model()


def generate_caption(url: str) -> str:
    """Download the image at *url* and return the top beam-search caption.

    Raises:
        requests.RequestException: on network failure, timeout, or HTTP error.
        PIL.UnidentifiedImageError: if the URL does not point to a decodable image.
    """
    # Timeout + raise_for_status: fail fast on dead servers / 4xx-5xx pages
    # instead of handing an HTML error body to PIL.
    response = requests.get(url, stream=True, timeout=10)
    response.raise_for_status()
    image = Image.open(response.raw).convert("RGB")
    inputs = feature_extractor(image, return_tensors="pt")
    # Generation is read-only: inference_mode skips autograd bookkeeping.
    with torch.inference_mode():
        generated_ids = model.generate(
            inputs["pixel_values"],
            max_length=20,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
            # Only the best caption is displayed, so decoding extra
            # sequences (the original asked for 3) was wasted work.
            num_return_sequences=1,
        )
    captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return captions[0]


st.title("Captioning demo")

url = st.text_input(
    "Insert your URL",
    "https://iheartcats.com/wp-content/uploads/2015/08/c84.jpg",
)
st.image(url)

if st.button("Run captioning"):
    with st.spinner("Processing image..."):
        caption = generate_caption(url)
        st.text(caption)