# NOTE(review): removed non-source scrape artifacts ("Spaces:" / "Runtime error" lines)
""" | |
# Copyright (c) 2022, salesforce.com, inc. | |
# All rights reserved. | |
# SPDX-License-Identifier: BSD-3-Clause | |
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import streamlit as st | |
from app import device, load_demo_image | |
from app.utils import load_model_cache | |
from lavis.processors import load_processor | |
from PIL import Image | |
def app():
    """Streamlit page: generate a textual description for an image with BLIP.

    The sidebar selects the model size (BLIP_base / BLIP_large) and the
    decoding strategy (beam search vs. nucleus sampling). The main pane
    shows the image on the left and, after pressing "Generate", the
    caption(s) on the right.
    """
    # ===== layout =====
    model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
    sampling_method = st.sidebar.selectbox(
        "Sampling method:", ["Beam search", "Nucleus sampling"]
    )

    st.markdown(
        "<h1 style='text-align: center;'>Image Description Generation</h1>",
        unsafe_allow_html=True,
    )

    instructions = """Try the provided image or upload your own:"""
    file = st.file_uploader(instructions)

    use_beam = sampling_method == "Beam search"

    col1, col2 = st.columns(2)

    # Fall back to the bundled demo image when nothing was uploaded.
    if file:
        raw_img = Image.open(file).convert("RGB")
    else:
        raw_img = load_demo_image()

    col1.header("Image")

    # Scale the preview to a fixed 720px width, preserving aspect ratio.
    w, h = raw_img.size
    scaling_factor = 720 / w
    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))

    col1.image(resized_image, use_column_width=True)
    col2.header("Description")

    cap_button = st.button("Generate")

    # ==== event ====
    # BLIP evaluation preprocessing (resize/normalize to 384x384).
    vis_processor = load_processor("blip_image_eval").build(image_size=384)

    if cap_button:
        if model_type.startswith("BLIP"):
            blip_type = model_type.split("_")[1].lower()  # "base" or "large"
            # Cached model load keyed on (name, type, device).
            model = load_model_cache(
                "blip_caption",
                model_type=f"{blip_type}_coco",
                is_eval=True,
                device=device,
            )

            img = vis_processor(raw_img).unsqueeze(0).to(device)
            captions = generate_caption(
                model=model, image=img, use_nucleus_sampling=not use_beam
            )

            # BUGFIX: st.write() does not accept a `use_column_width` kwarg
            # (that option belongs to st.image); passing it errors on recent
            # Streamlit versions. Dropped the spurious keyword.
            col2.write("\n\n".join(captions))
def generate_caption(
    model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5
):
    """Produce caption strings for a preprocessed image batch.

    With beam search (default) a single caption is returned; with nucleus
    sampling the model is queried five times (top_p=0.9) and all five
    samples are returned. The result is always a list of strings.
    """
    samples = {"image": image}

    if not use_nucleus_sampling:
        # Deterministic decoding: one beam-search caption.
        decoded = model.generate(
            samples,
            use_nucleus_sampling=False,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
        )
        return [decoded[0]]

    # Stochastic decoding: draw five independent samples.
    return [
        model.generate(
            samples,
            use_nucleus_sampling=True,
            max_length=max_length,
            min_length=min_length,
            top_p=0.9,
        )[0]
        for _ in range(5)
    ]