"""Streamlit front end for the image-captioning demo.

The user either picks one of a few sample photos or uploads an image; the
image is then captioned with the model returned by ``model.load_model`` via
``predict.generate_text``.
"""
import io

import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
import requests
from predict import generate_text
from model import load_model
from streamlit_image_select import image_select

# Seconds to wait for a sample image download before giving up.
_DOWNLOAD_TIMEOUT_S = 10

# Sample images offered in the picker (served over https to avoid
# mixed-content blocking when the app itself is served over https).
_SAMPLE_IMAGES = [
    "https://farm5.staticflickr.com/4084/5093294428_2f50d54acb_z.jpg",
    "https://farm8.staticflickr.com/7044/6855243647_cd204d079c_z.jpg",
    "https://farm4.staticflickr.com/3016/2650267987_f478c8d682_z.jpg",
    "https://farm8.staticflickr.com/7249/6913786280_c145ecc433_z.jpg",
]

# Session-state keys under which the loaded model components are stored.
_SESSION_KEYS = ("model", "image_transform", "tokenizer")

# Configure Streamlit page (must be the first Streamlit command executed).
st.set_page_config(page_title="Caption Machine", page_icon="📸")

# Load the model once per session. Streamlit re-executes this whole script on
# every widget interaction, so calling load_model() unconditionally would
# reload it on every rerun even when session_state already holds it.
if any(key not in st.session_state for key in _SESSION_KEYS):
    model, image_transform, tokenizer = load_model()
    st.session_state["model"] = model
    st.session_state["image_transform"] = image_transform
    st.session_state["tokenizer"] = tokenizer

# Force responsive layout for columns also on mobile.
# NOTE(review): the original CSS body was empty (apparently lost); this is the
# common column-width override that the comment describes — confirm it is
# still wanted, as this page does not call st.columns itself.
st.write(
    """<style>
    [data-testid="column"] {
        width: calc(50% - 1rem);
        flex: 1 1 calc(50% - 1rem);
        min-width: calc(50% - 1rem);
    }
    </style>""",
    unsafe_allow_html=True,
)

# Render Streamlit page
st.title("Image Captioner")
st.markdown(
    "This app utilizes OpenAI's [GPT-2](https://openai.com/research/better-language-models) "
    "and [CLIP](https://openai.com/research/clip) models to generate image captions. "
    "The model architecture was inspired by "
    "[ClipCap: CLIP Prefix for Image Captioning](https://arxiv.org/abs/2111.09734), "
    "which uses the CLIP encoding as a prefix and fine-tunes a GPT-2 model to "
    "generate the caption."
)

# Select a sample image or upload one; an upload takes precedence.
select_file = image_select(
    label="Select a photo:",
    images=_SAMPLE_IMAGES,
)
st.markdown("<p style='text-align: center;'>Or</p>", unsafe_allow_html=True)
upload_file = st.file_uploader("Upload an image:", type=["png", "jpg", "jpeg"])


def _download_image(url):
    """Download *url* and return it as a PIL image.

    Raises ``requests.RequestException`` on network/HTTP errors (including a
    non-2xx status) and ``PIL.UnidentifiedImageError`` if the body is not a
    decodable image.
    """
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_S)
    response.raise_for_status()
    # Read the whole (small) body into memory: PIL decodes lazily, and a
    # streamed raw socket is neither content-decoded nor safe to read later.
    return Image.open(io.BytesIO(response.content))


img = None
try:
    if upload_file:
        img = Image.open(upload_file)
    elif select_file:
        img = _download_image(select_file)
except requests.RequestException as err:
    st.error(f"Could not download the selected image: {err}")
except Image.UnidentifiedImageError:
    st.error("The provided file could not be read as an image.")

if img is not None:
    st.image(img)
    with st.spinner("Generating caption..."):
        caption = generate_text(
            st.session_state["model"],
            img,
            st.session_state["tokenizer"],
            st.session_state["image_transform"],
        )
    st.success(f"Result: {caption}")

# Model information
with st.expander("See model architecture"):
    st.markdown(
        """
Steps:

1. Feed the image into the CLIP image encoder to get an image embedding
2. Project the image embedding into the text embedding shape
3. Feed the text into the GPT-2 text embedder to get a text embedding
4. Concatenate the two embeddings and feed them into the GPT-2 attention layers
"""
    )
    st.write(" \nModel Architecture: ")
    model_img = Image.open("./model.png")
    st.image(model_img, width=450)