zamborg commited on
Commit
7d1df38
·
1 Parent(s): 5281471
Files changed (3) hide show
  1. .gitignore +1 -0
  2. model.py +89 -0
  3. requirements.txt +19 -0
.gitignore CHANGED
@@ -3,3 +3,4 @@
3
  *.yaml
4
  *ipynb_checkpoints
5
  __pycache__
 
 
3
  *.yaml
4
  *ipynb_checkpoints
5
  __pycache__
6
+ *.json
model.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_url, cached_download
2
+ from PIL import Image
3
+ import os
4
+ import json
5
+ import glob
6
+ import random
7
+ from typing import Any, Dict, List
8
+ import torch
9
+ import torchvision
10
+
11
+ import wordsegment as ws
12
+
13
+ from virtex.config import Config
14
+ from virtex.factories import TokenizerFactory, PretrainingModelFactory
15
+ from virtex.utils.checkpointing import CheckpointManager
16
+
17
+ CONFIG_PATH = "config.yaml"
18
+ MODEL_PATH = "checkpoint_last5.pth"
19
+ VALID_SUBREDDITS_PATH = "subreddit_list.json"
20
+ SAMPLES_PATH = "./samples/*.jpg"
21
+
22
+ class ImageLoader():
23
+ def __init__(self):
24
+ self.transformer = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
25
+ torchvision.transforms.CenterCrop(224),
26
+ torchvision.transforms.ToTensor()])
27
+ def load(self, im_path, prompt = ""):
28
+ im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
29
+ return {"image": im, "decode_prompt": prompt}
30
+ def transform(self, image, prompt = ""):
31
+ im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
32
+ return {"image": im, "decode_prompt": prompt}
33
+
34
+ class VirTexModel():
35
+ def __init__(self):
36
+ self.config = Config(CONFIG_PATH)
37
+ ws.load()
38
+ self.device = 'cpu'
39
+ self.tokenizer = TokenizerFactory.from_config(self.config)
40
+ self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
41
+ CheckpointManager(model=self.model).load("./checkpoint_last5.pth")
42
+ self.model.eval()
43
+ self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
44
+
45
+ def predict(self, image_dict, sub_prompt = None, prompt = ""):
46
+ if sub_prompt is None:
47
+ subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
48
+ else:
49
+ subreddit_tokens = torch.tensor([self.tokenizer.token_to_id(sub_prompt)], device=self.device).long()
50
+ predictions: List[Dict[str, Any]] = []
51
+
52
+ is_valid_subreddit = False
53
+ subreddit, rest_of_caption = "", ""
54
+ while not is_valid_subreddit:
55
+
56
+ with torch.no_grad():
57
+ caption = self.model(image_dict)["predictions"][0].tolist()
58
+
59
+ if self.tokenizer.token_to_id("[SEP]") in caption:
60
+ sep_index = caption.index(self.tokenizer.token_to_id("[SEP]"))
61
+ caption[sep_index] = self.tokenizer.token_to_id("://")
62
+
63
+ caption = self.tokenizer.decode(caption)
64
+
65
+ if "://" in caption:
66
+ subreddit, rest_of_caption = caption.split("://")
67
+ subreddit = "".join(subreddit.split())
68
+ rest_of_caption = rest_of_caption.strip()
69
+ else:
70
+ subreddit, rest_of_caption = "", caption
71
+
72
+ is_valid_subreddit = True if sub_prompt is not None else subreddit in self.valid_subs
73
+
74
+
75
+ return subreddit, rest_of_caption
76
+
77
+ def download_files():
78
+ #download model files
79
+ download_files = [CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH]
80
+ for f in download_files:
81
+ fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
82
+ os.system(f"cp {fp} ./{f}")
83
+
84
+ def get_samples():
85
+ return glob.glob(SAMPLES_PATH)
86
+
87
+ def get_rand_img(samples):
88
+ return samples[random.randint(0,len(samples)-1)]
89
+
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ albumentations>=0.5.0
2
+ Cython>=0.25
3
+ ftfy==5.8
4
+ future==0.18.0
5
+ huggingface-hub==0.1.2
6
+ lmdb==0.97
7
+ loguru==0.3.2
8
+ mypy_extensions==0.4.1
9
+ lvis==0.5.3
10
+ numpy>=1.17
11
+ opencv-python==4.1.2.30
12
+ scikit-learn==0.21.3
13
+ sentencepiece>=0.1.90
14
+ torch==1.7.0
15
+ torchvision==0.8
16
+ tqdm>=4.50.0
17
+ wordsegment==1.3.1
18
+ git+git://github.com/facebookresearch/fvcore.git#egg=fvcore
19
+ git+git://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI