import streamlit as st
from huggingface_hub import hf_hub_url, cached_download
from PIL import Image
import os
import json
import glob
import random
import shutil
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torchvision
import wordsegment as ws

from virtex.config import Config
from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
from virtex.utils.checkpointing import CheckpointManager

CONFIG_PATH = "config.yaml"
MODEL_PATH = "checkpoint_last5.pth"
VALID_SUBREDDITS_PATH = "subreddit_list.json"
SAMPLES_PATH = "./samples/*.jpg"


def _load_json(path: str) -> Any:
    """Parse the JSON file at *path*, making sure the file handle is closed."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


class ImageLoader:
    """Turn images into the ``{"image": tensor}`` dict consumed by ``VirTexModel``."""

    def __init__(self):
        # ToTensor -> Resize(256) -> CenterCrop(224) -> normalise with the
        # ImageNet channel mean/std.
        self.image_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
        ])
        # Length in pixels of the longer side of images produced by show_resize.
        self.show_size = 500

    def load(self, im_path) -> Dict[str, torch.Tensor]:
        """Open the image at *im_path* and return it preprocessed, with a batch dim of 1."""
        # The context manager closes the file once the pixels have been read.
        with Image.open(im_path) as image:
            return self.transform(image)

    def raw_load(self, im_path) -> Dict[str, torch.Tensor]:
        """Open the image at *im_path* and return its unnormalised pixels as a float tensor.

        The tensor keeps numpy's layout for the image (height, width[, channels])
        and its raw value range; no resize, crop or normalisation is applied.
        """
        # torch.FloatTensor(<PIL image>) raises TypeError (a PIL image is not a
        # sequence), so go through a numpy array instead.
        with Image.open(im_path) as image:
            pixels = np.array(image)
        return {"image": torch.from_numpy(pixels).float()}

    def transform(self, image) -> Dict[str, torch.Tensor]:
        """Preprocess an already-opened image and add a leading batch dimension of 1."""
        # Normalize expects exactly 3 channels; convert grayscale/RGBA/palette
        # PIL images instead of failing inside the transform.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        im = torch.FloatTensor(self.image_transform(image)).unsqueeze(0)
        return {"image": im}

    def text_transform(self, text: str) -> str:
        """Normalise user text before it is used as a prompt (currently lowercasing only)."""
        return text.lower()

    def show_resize(self, image):
        """Return a PIL copy of *image* scaled so its longer side is ``show_size`` pixels."""
        # Scaling is done by hand. The original note says the deployed
        # torch/torchvision is old (it says "0.8"); presumably its Resize cannot
        # target the longer side directly.
        tensor = torchvision.transforms.functional.to_tensor(image)
        height, width = tensor.shape[-2:]
        ratio = float(self.show_size / max((height, width)))
        tensor = torchvision.transforms.functional.resize(
            tensor, [int(height * ratio), int(width * ratio)]
        )
        return torchvision.transforms.functional.to_pil_image(tensor)


class VirTexModel:
    """Pretrained VirTex model that generates a ``subreddit://caption`` pair for an image."""

    def __init__(self):
        self.config = Config(CONFIG_PATH)
        # Presumably loads wordsegment's data, which ws.segment needs in predict.
        ws.load()
        self.device = 'cpu'
        self.tokenizer = TokenizerFactory.from_config(self.config)
        self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
        CheckpointManager(model=self.model).load(MODEL_PATH)
        self.model.eval()
        # Subreddit names a generated prediction must match to be accepted.
        self.valid_subs = _load_json(VALID_SUBREDDITS_PATH)

    def _build_decode_prompt(self, sub_prompt: Optional[str], prompt: str) -> torch.Tensor:
        """Build the token ids the decoder is conditioned on.

        The ids are ``[sos]``, then optionally the word-segmented subreddit
        followed by ``[SEP]``, then optionally ``[SEP]`` and the caption prompt.
        """
        sep_id = self.tokenizer.token_to_id("[SEP]")

        if sub_prompt is None:
            tokens = torch.tensor([self.model.sos_index], device=self.device).long()
        else:
            # Subreddit names are run-together words; split them the way the
            # model's training text presumably was.
            segmented = " ".join(ws.segment(ws.clean(sub_prompt)))
            tokens = torch.tensor(
                [self.model.sos_index] + self.tokenizer.encode(segmented) + [sep_id],
                device=self.device,
            ).long()

        if prompt:
            cap_tokens = torch.tensor(self.tokenizer.encode(prompt), device=self.device).long()
            if sub_prompt is None:
                # TODO: a caption prompt without a subreddit breaks decoding, so
                # fall back to conditioning on "pics".
                tokens = torch.tensor(
                    [self.model.sos_index] + self.tokenizer.encode("pics") + [sep_id],
                    device=self.device,
                ).long()
            # NOTE(review): `tokens` already ends with [SEP] at this point, so
            # this produces two consecutive [SEP] ids. Kept as-is; confirm intended.
            tokens = torch.cat([
                tokens,
                torch.tensor([sep_id], device=self.device).long(),
                cap_tokens,
            ])
        return tokens

    def _split_caption(self, caption_ids: List[int]) -> Tuple[str, str]:
        """Decode generated ids into ``(subreddit, caption)``.

        The first ``[SEP]`` id is swapped for the ``://`` token before decoding,
        and the decoded text is split on the first ``://``. If there is no
        delimiter, the subreddit is ``""``.
        """
        sep_id = self.tokenizer.token_to_id("[SEP]")
        if sep_id in caption_ids:
            caption_ids[caption_ids.index(sep_id)] = self.tokenizer.token_to_id("://")
        caption = self.tokenizer.decode(caption_ids)
        if "://" not in caption:
            return "", caption
        # maxsplit=1: a caption containing further "://" (e.g. a URL) must not
        # raise ValueError when the result is unpacked.
        subreddit, rest_of_caption = caption.split("://", 1)
        return "".join(subreddit.split()), rest_of_caption.strip()

    def predict(
        self,
        image_dict,
        sub_prompt: Optional[str] = None,
        prompt: str = "",
        max_attempts: Optional[int] = None,
    ) -> Tuple[str, str]:
        """Generate a subreddit and caption for a preprocessed image.

        Args:
            image_dict: Dict produced by ``ImageLoader`` (``{"image": tensor}``).
                It is mutated: ``"decode_prompt"`` is set to the conditioning ids.
            sub_prompt: Optional subreddit to condition generation on.
            prompt: Optional caption prefix to condition generation on.
            max_attempts: If given, stop after this many generations (at least
                one always runs) even if no valid subreddit was produced. The
                default ``None`` keeps retrying until a prediction's subreddit
                is in ``self.valid_subs``.

        Returns:
            ``(subreddit, caption)`` from the last generation.
        """
        image_dict["decode_prompt"] = self._build_decode_prompt(sub_prompt, prompt)

        attempts = 0
        while True:
            attempts += 1
            with torch.no_grad():
                # Presumably index 0 is the single image in a batch of 1.
                caption_ids = self.model(image_dict)["predictions"][0].tolist()
            subreddit, rest_of_caption = self._split_caption(caption_ids)
            if subreddit in self.valid_subs:
                break
            if max_attempts is not None and attempts >= max_attempts:
                break
        return subreddit, rest_of_caption


def download_files():
    """Copy the config, checkpoint and subreddit list from the ``zamborg/redcaps`` Hub repo into ``./``."""
    for filename in (CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH):
        cached_path = cached_download(hf_hub_url("zamborg/redcaps", filename=filename))
        # shutil instead of os.system("cp ..."): no shell string built from a
        # path, and a failed copy raises instead of being silently ignored.
        shutil.copy(cached_path, os.path.join(".", filename))


def get_samples() -> List[str]:
    """Return the paths of the bundled sample images matching ``SAMPLES_PATH``."""
    return glob.glob(SAMPLES_PATH)


def get_rand_idx(samples) -> int:
    """Return a random valid index into *samples*; raises ValueError if it is empty."""
    return random.randint(0, len(samples) - 1)


@st.cache
def create_objects():
    """Build the objects the app needs once per session; Streamlit caches the result.

    Returns:
        ``(model, image_loader, sample_image_paths, subreddit_choices)``, where
        ``subreddit_choices`` is the valid subreddit list with ``None``
        (meaning "no subreddit prompt") prepended.
    """
    sample_images = get_samples()
    virtex_model = VirTexModel()
    image_loader = ImageLoader()
    # Reuse the list the model already loaded, copied so that prepending None
    # does not alter the model's validity check.
    valid_subs = [None] + list(virtex_model.valid_subs)
    return virtex_model, image_loader, sample_images, valid_subs