# (Text captured from the file's web page header; commented out so the file parses.)
# cha0smagick's picture
# Update app.py
# 5b568c0
# raw history blame
# No virus
# 1.36 kB
import io

import numpy as np  # noqa: F401  # no longer used below; kept for any other users
import requests
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor

# Hugging Face checkpoint used for captioning. It must be a vision
# encoder-decoder model that ships an image processor, a tokenizer and a text
# decoder; a plain image backbone cannot generate text.
CAPTION_MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"


@st.cache_resource
def _load_captioning_components(model_name=CAPTION_MODEL_NAME):
    """Load and return ``(image_processor, tokenizer, model)`` for *model_name*.

    Streamlit re-executes this script on every widget interaction;
    ``st.cache_resource`` keeps one loaded copy across those reruns instead of
    re-initialising the model each time.
    """
    processor = ViTImageProcessor.from_pretrained(model_name)
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)
    caption_model = VisionEncoderDecoderModel.from_pretrained(model_name)
    return processor, text_tokenizer, caption_model


# Initialize the image processor, tokenizer and model. The module-level names
# `tokenizer` and `model` are kept so existing references keep working.
image_processor, tokenizer, model = _load_captioning_components()


def generate_caption(image_url, *, timeout=10.0, max_length=20):
    """Download the image at *image_url* and return a generated caption for it.

    Args:
        image_url: HTTP(S) URL of the image to caption.
        timeout: Seconds to wait for the server (passed to ``requests.get``),
            so a stalled host cannot hang the app indefinitely.
        max_length: Upper bound on the generated token sequence length
            (``generate``'s ``max_length``).

    Returns:
        The decoded caption with surrounding whitespace removed.

    Raises:
        requests.RequestException: On connection errors, timeouts, malformed
            URLs, or a non-2xx HTTP status.
        PIL.UnidentifiedImageError: If the response body is not a readable image.
    """
    # NOTE(review): the URL is user-supplied and fetched server-side; restrict
    # allowed hosts here if the app is exposed to untrusted users.
    response = requests.get(image_url, timeout=timeout)
    # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()

    # Decode from the fully read (and content-decoded) body. convert("RGB")
    # forces PIL's lazy load and normalises grayscale/RGBA/palette images to
    # the three channels the vision encoder expects.
    with Image.open(io.BytesIO(response.content)) as raw_image:
        image = raw_image.convert("RGB")

    # Condition generation on the image pixels, preprocessed exactly as the
    # checkpoint was trained (resize + normalisation done by the processor).
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values, max_length=max_length)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
def main():
    """Render the Streamlit UI: a sidebar URL input and a button that captions it.

    Empty or whitespace-only input shows an error instead of attempting a
    download. Network and image-decoding failures are reported on the page
    rather than surfacing as a raw traceback.
    """
    # Create a sidebar for the user to input the image URL
    st.sidebar.header("Image Caption Generator")
    image_url = st.sidebar.text_input("Enter the URL of an image:").strip()

    # The button is truthy only on the rerun triggered by the click.
    if not st.sidebar.button("Generate Caption"):
        return
    if not image_url:
        st.error("Please enter a valid image URL.")
        return

    try:
        caption = generate_caption(image_url)
    except (requests.RequestException, OSError) as exc:
        # RequestException: timeouts, connection/URL errors, HTTP error status.
        # OSError also covers PIL.UnidentifiedImageError for non-image bodies.
        st.error(f"Could not generate a caption for that URL: {exc}")
        return
    st.success(f"Caption: {caption}")


# Run the main function when executed as a script
if __name__ == "__main__":
    main()