# virtex-redcaps / app.py
import streamlit as st
from huggingface_hub import snapshot_download, hf_hub_url, cached_download
from PIL import Image
import argparse
import json
import os
from typing import Any, Dict, List
from loguru import logger
import torch
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
import wordsegment as ws
from virtex.config import Config
from virtex.data import ImageDirectoryDataset
from virtex.factories import TokenizerFactory, PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser
CONFIG_PATH = "config.yaml"
MODEL_PATH = "checkpoint_last5.pth"
# x = st.slider("Select a value")
# st.write(x, "squared is", x * x)
class ImageLoader():
    def __init__(self):
        # Standard ImageNet-style preprocessing: resize, center-crop to 224x224, tensorize.
        self.transformer = torchvision.transforms.Compose([
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
        ])

    def load(self, im_path, prompt):
        # Add a batch dimension so the image tensor is shaped (1, C, 224, 224).
        im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
        return {"image": im, "decode_prompt": prompt}
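# For reference (illustrative, not executed; the prompt tensor below is a
# placeholder, not real token ids): for an RGB image the returned "image"
# tensor has shape (1, 3, 224, 224), e.g.
#   batch = ImageLoader().load("./test.jpg", torch.tensor([0]).long())
#   batch["image"].shape  # torch.Size([1, 3, 224, 224])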
class VirTexModel():
    def __init__(self):
        self.config = Config(CONFIG_PATH)
        ws.load()  # initialize wordsegment's corpus data
        self.device = "cpu"
        self.tokenizer = TokenizerFactory.from_config(self.config)
        self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
        CheckpointManager(model=self.model).load(MODEL_PATH)
        self.model.eval()
        self.loader = ImageLoader()
    def predict(self, im_path):
        # Prime the decoder with the start-of-sequence token.
        subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
        predictions: List[Dict[str, Any]] = []

        image = self.loader.load(im_path, subreddit_tokens)  # "image" should be of shape (1, 3, 224, 224)
        output_dict = self.model(image)
        caption = output_dict["predictions"][0]  # only one prediction
        caption = caption.tolist()

        # Swap the first [SEP] token (which is just the 0 index) for "::" so the
        # subreddit can be split off from the rest of the caption after decoding.
        if self.tokenizer.token_to_id("[SEP]") in caption:
            sep_index = caption.index(self.tokenizer.token_to_id("[SEP]"))
            caption[sep_index] = self.tokenizer.token_to_id("::")

        caption = self.tokenizer.decode(caption)

        # Separate out the subreddit from the rest of the caption.
        if "⁇" in caption:  # "⁇" is the token-decode equivalent of "::"
            subreddit, rest_of_caption = caption.split("⁇", 1)
            subreddit = "".join(subreddit.split())
            rest_of_caption = rest_of_caption.strip()
        else:
            subreddit, rest_of_caption = "", caption

        return subreddit, rest_of_caption
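# For reference (the strings below are illustrative, not actual model output):
# a decoded caption such as "i took a picture ⁇ a foggy morning on the lake"
# splits into subreddit "itookapicture" and caption "a foggy morning on the lake".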
def load_models():
    # Download the config and checkpoint from the zamborg/redcaps Hub repo and
    # copy them into the working directory under the expected filenames.
    download_files = [CONFIG_PATH, MODEL_PATH]
    for f in download_files:
        fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
        os.system(f"cp {fp} ./{f}")

# Download model files, then load a VirTex model.
load_models()

# Inference on test.jpg.
virtexModel = VirTexModel()
subreddit, caption = virtexModel.predict("./test.jpg")
print(subreddit)
print(caption)
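
# A minimal Streamlit front-end sketch (an assumption about how this app might
# be wired up; the widget labels and layout are not part of the original code).
# It is defined but never called, so the test block above keeps its behavior.
def run_demo():
    st.title("RedCaps image captioning")
    uploaded = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded is not None:
        st.image(Image.open(uploaded), use_column_width=True)
        uploaded.seek(0)  # rewind so predict() can re-open the file
        sub, cap = virtexModel.predict(uploaded)
        st.markdown(f"**r/{sub}**: {cap}")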