zamborg commited on
Commit
49c0315
β€’
1 Parent(s): a5f8a35
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +96 -2
  3. samples/test.jpg +0 -0
.gitignore CHANGED
@@ -1 +1,4 @@
1
  .ipynb_checkpoints/*
 
 
 
1
  .ipynb_checkpoints/*
2
+ *.pth
3
+ *.yaml
4
+ *ipynb_checkpoints
app.py CHANGED
@@ -1,4 +1,98 @@
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider("Select a value")
4
- st.write(x, "squared is", x * x)
1
  import streamlit as st
2
+ from huggingface_hub import snapshot_download
3
+ from PIL import Image
4
+
5
+ import argparse
6
+ import json
7
+ import os
8
+ from typing import Any, Dict, List
9
+
10
+ from loguru import logger
11
+ import torch
12
+ import torchvision
13
+ from torch.utils.data import DataLoader
14
+ from tqdm import tqdm
15
+
16
+ import wordsegment as ws
17
+
18
+ from virtex.config import Config
19
+ from virtex.data import ImageDirectoryDataset
20
+ from virtex.factories import TokenizerFactory, PretrainingModelFactory
21
+ from virtex.utils.checkpointing import CheckpointManager
22
+ from virtex.utils.common import common_parser
23
+
24
+ CONFIG_PATH = "config.yaml"
25
+ MODEL_PATH = "checkpoint_last5.pth"
26
+
27
+ # x = st.slider("Select a value")
28
+ # st.write(x, "squared is", x * x)
29
+
30
+
31
+
32
+ class ImageLoader():
33
+ def __init__(self):
34
+ self.transformer = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
35
+ torchvision.transforms.CenterCrop(224),
36
+ torchvision.transforms.ToTensor()])
37
+ def load(self, im_path, prompt):
38
+ im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
39
+ return {"image": im, "decode_prompt": prompt}
40
+
41
+ class VirTexModel():
42
+ def __init__(self):
43
+ self.config = Config(CONFIG_PATH)
44
+ ws.load()
45
+ self.device = 'cpu'
46
+ self.tokenizer = TokenizerFactory.from_config(self.config)
47
+ self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
48
+ CheckpointManager(model=self.model).load("./checkpoint_last5.pth")
49
+ self.model.eval()
50
+ self.loader = ImageLoader()
51
+
52
+ def predict(self, im_path):
53
+ subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
54
+ predictions: List[Dict[str, Any]] = []
55
+ image = self.loader.load(im_path, subreddit_tokens) # should be of shape 1, 3, 224, 224
56
+ output_dict = self.model(image)
57
+ caption = output_dict["predictions"][0] #only one prediction
58
+ caption = caption.tolist()
59
+ if self.tokenizer.token_to_id("[SEP]") in caption: # this is just the 0 index actually
60
+ sos_index = caption.index(self.tokenizer.token_to_id("[SEP]"))
61
+ caption[sos_index] = self.tokenizer.token_to_id("::")
62
+
63
+ caption = self.tokenizer.decode(caption)
64
+
65
+ # Separate out subreddit from the rest of caption.
66
+ if "⁇" in caption: # "⁇" is the token decode equivalent of "::"
67
+ subreddit, rest_of_caption = caption.split("⁇")
68
+ subreddit = "".join(subreddit.split())
69
+ rest_of_caption = rest_of_caption.strip()
70
+ else:
71
+ subreddit, rest_of_caption = "", caption
72
+
73
+ return subreddit, rest_of_caption
74
+
75
+ def load_models():
76
+ #download model files
77
+ download_files = [CONFIG_PATH, MODEL_PATH]
78
+ for f in download_files:
79
+ fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
80
+ os.system(f"cp {fp} ./{f}")
81
+
82
+
83
+
84
+ # load a virtex model
85
+ from huggingface_hub import hf_hub_url, cached_download
86
+
87
+ # #download model files
88
+ download_files = [CONFIG_PATH, MODEL_PATH]
89
+ for f in download_files:
90
+ fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
91
+ os.system(f"cp {fp} ./{f}")
92
+
93
+ #inference on test.jpg
94
+ virtexModel = VirTexModel()
95
+ subreddit, caption = virtexModel.predict("./test.jpg")
96
+ print(subreddit)
97
+ print(caption)
98
 
 
 
samples/test.jpg ADDED