import os

import streamlit as st
import torch
import torchvision
import wordsegment as ws
from huggingface_hub import cached_download, hf_hub_url
from PIL import Image

from virtex.config import Config
from virtex.factories import PretrainingModelFactory, TokenizerFactory
from virtex.utils.checkpointing import CheckpointManager

CONFIG_PATH = "config.yaml"
MODEL_PATH = "checkpoint_last5.pth"


class ImageLoader:
    def __init__(self):
        self.transformer = torchvision.transforms.Compose([
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
        ])

    def load(self, im_path, prompt):
        im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
        return {"image": im, "decode_prompt": prompt}


class VirTexModel:
    def __init__(self):
        self.config = Config(CONFIG_PATH)
        ws.load()
        self.device = "cpu"
        self.tokenizer = TokenizerFactory.from_config(self.config)
        self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
        CheckpointManager(model=self.model).load(MODEL_PATH)
        self.model.eval()
        self.loader = ImageLoader()

    def predict(self, im_path):
        # Decode with only the start-of-sequence token as the prompt.
        subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
        # image tensor should be of shape (1, 3, 224, 224)
        image = self.loader.load(im_path, subreddit_tokens)
        output_dict = self.model(image)
        caption = output_dict["predictions"][0]  # only one prediction
        caption = caption.tolist()

        # Replace the [SEP] token (usually at index 0) with "::" so the
        # subreddit can be split off after decoding.
        if self.tokenizer.token_to_id("[SEP]") in caption:
            sep_index = caption.index(self.tokenizer.token_to_id("[SEP]"))
            caption[sep_index] = self.tokenizer.token_to_id("::")

        caption = self.tokenizer.decode(caption)

        # Separate the subreddit from the rest of the caption.
        # "⁇" is the decoded form of the "::" token.
        if "⁇" in caption:
            subreddit, rest_of_caption = caption.split("⁇")
            subreddit = "".join(subreddit.split())
            rest_of_caption = rest_of_caption.strip()
        else:
            subreddit, rest_of_caption = "", caption

        return subreddit, rest_of_caption


def load_models():
    # Download the config and checkpoint from the Hugging Face Hub and copy
    # them into the working directory.
    download_files = [CONFIG_PATH, MODEL_PATH]
    for f in download_files:
        fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
        os.system(f"cp {fp} ./{f}")


# Download model files, then run a quick inference on test.jpg.
load_models()
virtexModel = VirTexModel()
subreddit, caption = virtexModel.predict("./test.jpg")
print(subreddit)
print(caption)
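
# ---------------------------------------------------------------------------
# Hedged sketch: a minimal Streamlit front end around VirTexModel. The widget
# layout and labels below are assumptions, not part of the original script;
# only standard Streamlit calls are used (title, file_uploader, image, write).
# Re-wrapping the upload in BytesIO lets the existing ImageLoader, which calls
# PIL.Image.open on its input, work unchanged.
# ---------------------------------------------------------------------------
from io import BytesIO

st.title("RedCaps image captioning")
uploaded = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if uploaded is not None:
    data = uploaded.read()
    st.image(Image.open(BytesIO(data)))
    sub, cap = virtexModel.predict(BytesIO(data))
    st.write(f"Predicted subreddit: r/{sub}" if sub else "No subreddit predicted")
    st.write(f"Caption: {cap}")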