import torch
import clip
import PIL.Image
import skimage.io as io
import streamlit as st
from transformers import GPT2Tokenizer
from model import generate2, ClipCaptionModel
# Model loading: CLIP image encoder, GPT-2 tokenizer, and two caption decoders
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prefix_length = 10

# Checkpoint for the 'Model' option in the UI
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load('model.h5', map_location=torch.device('cpu')))
model = model.eval()

# Checkpoint for the 'COCO Model' option in the UI
coco_model = ClipCaptionModel(prefix_length)
coco_model.load_state_dict(torch.load('COCO_model.h5', map_location=torch.device('cpu')))
coco_model = coco_model.eval()
def ui():
    st.markdown("# Image Captioning")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        # Read the upload, convert it to a PIL image, and apply the CLIP preprocessing
        image = io.imread(uploaded_file)
        pil_image = PIL.Image.fromarray(image)
        image = preprocess(pil_image).unsqueeze(0).to(device)
        option = st.selectbox('Please select the Model', ('Model', 'COCO Model'))
        if option == 'Model':
            with torch.no_grad():
                # Encode the image with CLIP and project the embedding into the GPT-2 prefix space
                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
                prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
                generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + generated_text_prefix)
        elif option == 'COCO Model':
            with torch.no_grad():
                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
                # Use the COCO model's projection so the prefix matches the decoder it is passed to
                prefix_embed = coco_model.clip_project(prefix).reshape(1, prefix_length, -1)
                generated_text_prefix = generate2(coco_model, tokenizer, embed=prefix_embed)
            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + generated_text_prefix)
if __name__ == '__main__':
    ui()
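
# Usage sketch (assumes this file is saved as app.py and that 'model.h5' and
# 'COCO_model.h5' are present in the same directory):
#
#   streamlit run app.py
#
# Streamlit serves the UI locally; upload an image and pick a checkpoint to get a caption.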