aroraaman committed
Commit
453e23a
1 Parent(s): f8cbcea

Add `app.py`

Files changed (2)
  1. app.py +183 -0
  2. requirements.txt +26 -0
app.py ADDED
@@ -0,0 +1,183 @@
+ import os
+ from pathlib import Path
+ import torch
+ import numpy as np
+ from PIL import Image
+ import gradio as gr
+ from tokenizers import Tokenizer
+ from torch.utils.data import Dataset
+ import albumentations as A
+ from tqdm import tqdm
+
+ from fourm.vq.vqvae import VQVAE
+ from fourm.models.fm import FM
+ from fourm.models.generate import (
+     GenerationSampler,
+     build_chained_generation_schedules,
+     init_empty_target_modality,
+     custom_text,
+ )
+ from fourm.utils.plotting_utils import decode_dict
+ from fourm.data.modality_info import MODALITY_INFO
+ from fourm.data.modality_transforms import RGBTransform
+ from torchvision.transforms.functional import center_crop
+
+ # Constants and configurations
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ IMG_SIZE = 224
+ TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
+ FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
+ VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
+ IMAGE_DATASET_PATH = "/home/ubuntu/GIT_REPOS/ml-4m/data/custom_data/"
+
+ # Load models
+ text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
+ vqvae = VQVAE.from_pretrained(VQVAE_PATH)
+ fm_model = FM.from_pretrained(FM_MODEL_PATH).eval().to(DEVICE)
+
+ # Generation configurations
+ cond_domains = ["caption", "metadata"]
+ target_domains = ["tok_dinov2_global"]
+ tokens_per_target = [16]
+ generation_config = {
+     "autoregression_schemes": ["roar"],
+     "decoding_steps": [1],
+     "token_decoding_schedules": ["linear"],
+     "temps": [2.0],
+     "temp_schedules": ["onex:0.5:0.5"],
+     "cfg_scales": [1.0],
+     "cfg_schedules": ["constant"],
+     "cfg_grow_conditioning": True,
+ }
+ top_p, top_k = 0.8, 0.0
+
+ schedule = build_chained_generation_schedules(
+     cond_domains=cond_domains,
+     target_domains=target_domains,
+     tokens_per_target=tokens_per_target,
+     **generation_config,
+ )
+
+ sampler = GenerationSampler(fm_model)
+
+
+ class ImageDataset(Dataset):
+     def __init__(self, path: str, img_sz=IMG_SIZE):
+         self.path = Path(path)
+         self.files = list(self.path.rglob("*"))
+         self.tfms = A.Compose(
+             [A.SmallestMaxSize(img_sz)])
+
+     def __len__(self):
+         return len(self.files)
+
+     def __getitem__(self, idx):
+         img = Image.open(self.files[idx]).convert("RGB")
+         img = np.array(img)
+         img = self.tfms(image=img)["image"]
+         return Image.fromarray(img)
+
+
+ dataset = ImageDataset(IMAGE_DATASET_PATH)
+
+
+ @torch.no_grad()
+ def get_image_embeddings(dataset):
+     # Only the cache-loading path is implemented here; image_emb.pt is assumed
+     # to be precomputed offline (see the sketch after this file). Failing loudly
+     # when the cache is missing avoids silently returning None.
+     cache_file = "image_emb.pt"
+     if os.path.exists(cache_file):
+         return torch.load(cache_file)
+     raise FileNotFoundError(
+         f"{cache_file} not found; precompute the gallery image embeddings first."
+     )
+
+
+ image_embeddings = get_image_embeddings(dataset).to(DEVICE)
+ print(image_embeddings.shape)
+
+
+ def get_similar_images(caption, brightness, num_items):
+     # Build a single-sample batch: empty target tokens for tok_dinov2_global,
+     # plus tokenized caption and metadata text as conditioning.
+     batched_sample = {}
+
+     for target_mod, ntoks in zip(target_domains, tokens_per_target):
+         batched_sample = init_empty_target_modality(
+             batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
+         )
+
+     metadata = f"v1=6 v0={num_items} v1=10 v0={brightness}"
+     print(metadata)
+     batched_sample = custom_text(
+         batched_sample,
+         input_text=caption,
+         eos_token="[EOS]",
+         key="caption",
+         device=DEVICE,
+         text_tokenizer=text_tokenizer,
+     )
+     batched_sample = custom_text(
+         batched_sample,
+         input_text=metadata,
+         eos_token="[EOS]",
+         key="metadata",
+         device=DEVICE,
+         text_tokenizer=text_tokenizer,
+     )
+
+     # Generate DINOv2 global tokens conditioned on the caption and metadata.
+     out_dict = sampler.generate(
+         batched_sample,
+         schedule,
+         text_tokenizer=text_tokenizer,
+         verbose=True,
+         seed=0,
+         top_p=top_p,
+         top_k=top_k,
+     )
+
+     # Decode the generated tokens back into a DINOv2 global feature vector.
+     with torch.no_grad():
+         dec_dict = decode_dict(
+             out_dict,
+             {"tok_dinov2_global": vqvae.to(DEVICE)},
+             text_tokenizer,
+             image_size=IMG_SIZE,
+             patch_size=16,
+             decoding_steps=1,
+         )
+
+     # Rank the cached gallery embeddings by cosine similarity and return the
+     # closest image(s).
+     combined_features = dec_dict["tok_dinov2_global"]
+     similarities = torch.nn.functional.cosine_similarity(
+         combined_features, image_embeddings
+     )
+     top_indices = similarities.argsort(descending=True)[:1]
+     print(top_indices, similarities[top_indices])
+     return [dataset[i] for i in top_indices.cpu().numpy()]
+
+
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Image Retrieval using 4M-21: An Any-to-Any Vision Model")
+     with gr.Row():
+         with gr.Column(scale=1):
+             caption = gr.Textbox(
+                 label="Caption Description", placeholder="Enter image description..."
+             )
+             brightness = gr.Slider(
+                 minimum=0, maximum=255, value=5, step=1,
+                 label="Brightness", info="Adjust image brightness (0-255)"
+             )
+             num_items = gr.Slider(
+                 minimum=0, maximum=50, value=5, step=1,
+                 label="Number of Items", info="Number of COCO instances in image (0-50)"
+             )
+         with gr.Column(scale=1):
+             output_images = gr.Gallery(
+                 label="Retrieved Images",
+                 show_label=True,
+                 elem_id="gallery",
+                 columns=2,
+                 rows=2,
+                 height=512,
+             )
+     submit_btn = gr.Button("Retrieve Most Similar Image")
+     submit_btn.click(
+         fn=get_similar_images,
+         inputs=[caption, brightness, num_items],
+         outputs=output_images,
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
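
Note: `app.py` loads `image_emb.pt` but does not include the code that builds it. The sketch below shows only the caching pattern that `get_image_embeddings` expects; `encode_image` is a hypothetical placeholder for whatever pipeline produces embeddings in the same space as the decoded `tok_dinov2_global` features above, and is not part of this commit or of ml-4m.

# Hypothetical offline precomputation of image_emb.pt (not part of this commit).
# encode_image(img) is a placeholder that must return a 1D tensor in the same
# embedding space as the decoded "tok_dinov2_global" features used in app.py.
import torch
from tqdm import tqdm

def build_embedding_cache(dataset, encode_image, cache_file="image_emb.pt"):
    embeddings = []
    for idx in tqdm(range(len(dataset))):
        img = dataset[idx]                      # PIL image from ImageDataset
        embeddings.append(encode_image(img))    # assumed shape: (emb_dim,)
    image_embeddings = torch.stack(embeddings)  # shape: (num_images, emb_dim)
    torch.save(image_embeddings, cache_file)    # later loaded by get_image_embeddings()
    return image_embeddings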
requirements.txt ADDED
@@ -0,0 +1,26 @@
+ torch>=2.1.0
+ torchvision>=0.16.0
+ albumentations>=1.4.0
+ boto3>=1.26.16
+ braceexpand>=0.1.7
+ diffusers==0.20.0
+ einops>=0.7.0
+ ftfy==6.1.0
+ huggingface_hub>=0.20.0
+ matplotlib>=3.6.2
+ numpy>=1.26.4
+ opencv-python>=4.9.0.80
+ opencv-python-headless>=4.6.0.66
+ pandas>=1.5.2
+ Pillow>=9.3.0
+ PyYAML>=6.0
+ regex>=2022.10.31
+ requests>=2.31.0
+ scikit-learn>=1.1.3
+ setuptools>=61.0
+ tokenizers>=0.15.2
+ datasets>=0.17
+ torchmetrics>=1.3.1
+ tqdm>=4.64.1
+ wandb>=0.13.5
+ webdataset>=0.2.86