zamborg committed
Commit a4c3b59
1 Parent(s): 2f0a51c

I think this works

app.py CHANGED
@@ -42,11 +42,14 @@ if uploaded_image is None and submitted:
 else:
     image_file = sample_image if sample_image is not None else random_image

-    image = uploaded_image if uploaded_image is not None else Image.open()
+    image = uploaded_image if uploaded_image is not None else Image.open(image_file)

     image_dict = imageLoader.transform(image)

-    show.image(st.image(image_dict["image"]), "Target Image")
+    image = imageLoader.to_image(image_dict["image"].squeeze(0))
+
+    show = st.image(image)
+    show.image(image, "Your Image")

     with st.spinner("Generating Caption"):
         subreddit, caption = virtexModel.predict(image_dict)
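Note: a minimal sketch of how the updated display path in app.py fits together, assuming `imageLoader`, `image_file`, and the Streamlit setup are defined earlier in the file (outside this hunk):

from PIL import Image
import streamlit as st

image = Image.open(image_file)                                  # PIL image from the chosen file
image_dict = imageLoader.transform(image)                       # {"image": 1xCxHxW float tensor}
image = imageLoader.to_image(image_dict["image"].squeeze(0))    # drop batch dim, back to PIL
show = st.image(image)                                          # render the image
show.image(image, "Your Image")                                 # re-render with a caption

The `.squeeze(0)` matters because `ImageLoader.transform` returns a batched tensor, while the new `to_image` helper (added in model.py below) expects a CxHxW tensor.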
model.py CHANGED
@@ -30,6 +30,8 @@ class ImageLoader():
     def transform(self, image):
         im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
         return {"image": im}
+    def to_image(self, tensor):
+        return torchvision.transforms.ToPILImage()(tensor)

 class VirTexModel():
     def __init__(self):
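Note: a small hedged sketch of the round trip the new `to_image` helper enables, assuming model.py already imports `torchvision` (ToPILImage lives in `torchvision.transforms`) and that `ImageLoader()` can be constructed as elsewhere in model.py; the file name is a hypothetical placeholder:

from PIL import Image

loader = ImageLoader()                                        # constructor not shown in this hunk
image_dict = loader.transform(Image.open("example.jpg"))      # hypothetical input file
pil_image = loader.to_image(image_dict["image"].squeeze(0))   # CxHxW tensor back to a PIL image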
virtex/virtex/data/__init__.py CHANGED
@@ -10,7 +10,6 @@ from .datasets.downstream import (
     VOC07ClassificationDataset,
     ImageDirectoryDataset,
 )
-from .datasets.redcaps import TarfileDataset


 __all__ = [
virtex/virtex/data/datasets/redcaps.py CHANGED
@@ -1,129 +0,0 @@
-import glob
-import os
-import random
-from typing import Callable
-
-import numpy as np
-import torch
-from torch.utils.data import IterableDataset
-import webdataset as wds
-import wordsegment as ws
-
-from virtex.data.tokenizers import SentencePieceBPETokenizer
-from virtex.data import transforms as T
-import virtex.utils.distributed as dist
-
-ws.load()
-
-
-class TarfileDataset(IterableDataset):
-    def __init__(
-        self,
-        data_root: str,
-        batch_size: int,
-        tokenizer: SentencePieceBPETokenizer,
-        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
-        shuffle_buffer_size: int = 3000,  # Set -1 to turn off shuffle.
-        max_caption_length: int = 50,
-    ):
-        super().__init__()
-
-        self.tokenizer = tokenizer
-        self.image_transform = image_transform
-        self.max_caption_length = max_caption_length
-
-        self.padding_idx = tokenizer.token_to_id("<unk>")
-        self.sos_idx = tokenizer.token_to_id("[SOS]")
-        self.eos_idx = tokenizer.token_to_id("[EOS]")
-        self.sep_idx = tokenizer.token_to_id("[SEP]")
-
-        # Glob expand all paths in data root.
-        all_data_paths = []
-        for dr in data_root.split(" "):
-            all_data_paths.extend(glob.glob(dr))
-
-        # Deterministic shuffle across GPU process.
-        all_data_paths = sorted(all_data_paths)
-        random.Random(0).shuffle(all_data_paths)
-
-        # Shard the data paths as per gpu process.
-        all_data_paths = all_data_paths[dist.get_rank()::dist.get_world_size()]
-
-        self._dset = (
-            wds.WebDataset(all_data_paths)
-            .shuffle(shuffle_buffer_size, initial=shuffle_buffer_size)
-            .decode("rgb8", handler=wds.warn_and_continue)
-            .map(self._preprocess)
-            .batched(batch_size)
-        )
-        # Perform word-segmentation of all subreddit names (that's how the
-        # tokenizer was prepared). Subreddit names can be obtained from
-        # TAR file names: `{subreddit}_{year}_{index}.tar`.
-        if "redcaps" in data_root:
-            self.subreddit_segs = {
-                sub: " ".join(ws.segment(ws.clean(sub))) for sub in
-                set([os.path.basename(p).split("_")[0] for p in all_data_paths])
-            }
-
-    def _preprocess(self, annotation):
-        image, caption = annotation["jpg"], annotation["json"]["caption"]
-
-        # Transform image-caption pair and convert image from HWC to CHW format.
-        # Pass in caption to image_transform due to paired horizontal flip.
-        # Caption won't be tokenized/processed here.
-        image_caption = self.image_transform(image=image, caption=caption)
-        image, caption = image_caption["image"], image_caption["caption"]
-        image = np.transpose(image, (2, 0, 1))
-
-        # Tokenize caption.
-        _caption_tokens = self.tokenizer.encode(caption)
-
-        # Get subreddit name if it exists, and tokenize it. Only for RedCaps.
-        if "subreddit" in annotation["json"]:
-            subreddit = annotation["json"]["subreddit"].lower()
-            subreddit = self.subreddit_segs[subreddit]
-
-            # Add special [SEP] token after subreddit.
-            _subreddit_tokens = self.tokenizer.encode(subreddit) + [self.sep_idx]
-        else:
-            _subreddit_tokens = []
-
-        # Create forward and backward caption with subreddit token at the start.
-        caption_tokens = (
-            [self.sos_idx] + _subreddit_tokens + _caption_tokens + [self.eos_idx]
-        )[: self.max_caption_length]
-
-        noitpac_tokens = (
-            [self.eos_idx] + _subreddit_tokens + _caption_tokens[::-1] + [self.sos_idx]
-        )[: self.max_caption_length]
-
-        return image, caption_tokens, noitpac_tokens, len(caption_tokens)
-
-    def __len__(self):
-        raise NotImplementedError
-
-    def __iter__(self):
-
-        for batch in iter(self._dset):
-            # Collate the batch properly here. `image` and `caption_lengths`
-            # are already tensors.
-            image, caption_tokens, noitpac_tokens, caption_lengths = batch
-
-            # Pad `caption_tokens` and `masked_labels` up to this length.
-            caption_tokens = torch.nn.utils.rnn.pad_sequence(
-                [torch.tensor(c, dtype=torch.long) for c in caption_tokens],
-                batch_first=True,
-                padding_value=self.padding_idx,
-            )
-            noitpac_tokens = torch.nn.utils.rnn.pad_sequence(
-                [torch.tensor(c, dtype=torch.long) for c in noitpac_tokens],
-                batch_first=True,
-                padding_value=self.padding_idx,
-            )
-            caption_lengths = torch.tensor(caption_lengths, dtype=torch.long)
-            yield {
-                "image": torch.tensor(image, dtype=torch.float),
-                "caption_tokens": caption_tokens,
-                "noitpac_tokens": noitpac_tokens,
-                "caption_lengths": caption_lengths,
-            }
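Note: for reference, a hedged sketch of the forward and reversed ("noitpac") token sequences the removed `TarfileDataset._preprocess` built; the token ids below are illustrative placeholders, not real vocabulary ids:

sos_idx, eos_idx, sep_idx = 1, 2, 3
_subreddit_tokens = [10, sep_idx]        # word-segmented subreddit + [SEP]
_caption_tokens = [20, 21, 22]           # tokenized caption
max_caption_length = 50

caption_tokens = ([sos_idx] + _subreddit_tokens + _caption_tokens + [eos_idx])[:max_caption_length]
noitpac_tokens = ([eos_idx] + _subreddit_tokens + _caption_tokens[::-1] + [sos_idx])[:max_caption_length]

print(caption_tokens)   # [1, 10, 3, 20, 21, 22, 2]
print(noitpac_tokens)   # [2, 10, 3, 22, 21, 20, 1]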
virtex/virtex/factories.py CHANGED
@@ -194,8 +194,6 @@ class PretrainingDatasetFactory(Factory):
         "masked_lm": vdata.MaskedLmDataset,
         "token_classification": vdata.TokenClassificationDataset,
         "multilabel_classification": vdata.MultiLabelClassificationDataset,
-        "virtex_web": vdata.TarfileDataset,
-        "miniclip_web": vdata.TarfileDataset,
     }

     @classmethod