Add `app.py`
- app.py +183 -0
- requirements.txt +26 -0
app.py
ADDED
@@ -0,0 +1,183 @@
import os
from pathlib import Path
import torch
import numpy as np
from PIL import Image
import gradio as gr
from tokenizers import Tokenizer
from torch.utils.data import Dataset
import albumentations as A
from tqdm import tqdm

from fourm.vq.vqvae import VQVAE
from fourm.models.fm import FM
from fourm.models.generate import (
    GenerationSampler,
    build_chained_generation_schedules,
    init_empty_target_modality,
    custom_text,
)
from fourm.utils.plotting_utils import decode_dict
from fourm.data.modality_info import MODALITY_INFO
from fourm.data.modality_transforms import RGBTransform
from torchvision.transforms.functional import center_crop

# Constants and configurations
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 224
TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
IMAGE_DATASET_PATH = "/home/ubuntu/GIT_REPOS/ml-4m/data/custom_data/"

# Load models
text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
vqvae = VQVAE.from_pretrained(VQVAE_PATH)
fm_model = FM.from_pretrained(FM_MODEL_PATH).eval().to(DEVICE)

# Generation configurations
cond_domains = ["caption", "metadata"]
target_domains = ["tok_dinov2_global"]
tokens_per_target = [16]
generation_config = {
    "autoregression_schemes": ["roar"],
    "decoding_steps": [1],
    "token_decoding_schedules": ["linear"],
    "temps": [2.0],
    "temp_schedules": ["onex:0.5:0.5"],
    "cfg_scales": [1.0],
    "cfg_schedules": ["constant"],
    "cfg_grow_conditioning": True,
}
top_p, top_k = 0.8, 0.0

schedule = build_chained_generation_schedules(
    cond_domains=cond_domains,
    target_domains=target_domains,
    tokens_per_target=tokens_per_target,
    **generation_config,
)

sampler = GenerationSampler(fm_model)


class ImageDataset(Dataset):
    """Loads images from a directory and resizes the shortest side to IMG_SIZE."""

    def __init__(self, path: str, img_sz=IMG_SIZE):
        self.path = Path(path)
        # Collect all files recursively (skip directories).
        self.files = [f for f in self.path.rglob("*") if f.is_file()]
        self.tfms = A.Compose([A.SmallestMaxSize(img_sz)])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert("RGB")
        img = np.array(img)
        img = self.tfms(image=img)["image"]
        return Image.fromarray(img)


dataset = ImageDataset(IMAGE_DATASET_PATH)


@torch.no_grad()
def get_image_embeddings(dataset):
    # Image embeddings are expected to be precomputed offline and cached on disk.
    cache_file = "image_emb.pt"
    if os.path.exists(cache_file):
        return torch.load(cache_file)
    raise FileNotFoundError(
        f"Missing embedding cache '{cache_file}'. Precompute the image embeddings "
        "for the dataset and save them with torch.save before launching the app."
    )


image_embeddings = get_image_embeddings(dataset).to(DEVICE)
print(image_embeddings.shape)


def get_similar_images(caption, brightness, num_items):
    """Generate a DINOv2 global embedding from the caption/metadata conditioning,
    then return the most similar dataset image by cosine similarity."""
    batched_sample = {}

    for target_mod, ntoks in zip(target_domains, tokens_per_target):
        batched_sample = init_empty_target_modality(
            batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
        )

    # Metadata conditioning string (id/value pairs as expected by the 4M metadata modality).
    metadata = f"v1=6 v0={num_items} v1=10 v0={brightness}"
    print(metadata)
    batched_sample = custom_text(
        batched_sample,
        input_text=caption,
        eos_token="[EOS]",
        key="caption",
        device=DEVICE,
        text_tokenizer=text_tokenizer,
    )
    batched_sample = custom_text(
        batched_sample,
        input_text=metadata,
        eos_token="[EOS]",
        key="metadata",
        device=DEVICE,
        text_tokenizer=text_tokenizer,
    )

    out_dict = sampler.generate(
        batched_sample,
        schedule,
        text_tokenizer=text_tokenizer,
        verbose=True,
        seed=0,
        top_p=top_p,
        top_k=top_k,
    )

    with torch.no_grad():
        dec_dict = decode_dict(
            out_dict,
            {"tok_dinov2_global": vqvae.to(DEVICE)},
            text_tokenizer,
            image_size=IMG_SIZE,
            patch_size=16,
            decoding_steps=1,
        )

    # Rank cached image embeddings against the generated query embedding.
    combined_features = dec_dict["tok_dinov2_global"]
    similarities = torch.nn.functional.cosine_similarity(
        combined_features, image_embeddings
    )
    top_indices = similarities.argsort(descending=True)[:1]
    print(top_indices, similarities[top_indices])
    return [dataset[i] for i in top_indices.cpu().numpy()]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Retrieval using 4M-21: An Any-to-Any Vision Model")
    with gr.Row():
        with gr.Column(scale=1):
            caption = gr.Textbox(
                label="Caption Description", placeholder="Enter image description..."
            )
            brightness = gr.Slider(
                minimum=0, maximum=255, value=5, step=1,
                label="Brightness", info="Adjust image brightness (0-255)"
            )
            num_items = gr.Slider(
                minimum=0, maximum=50, value=5, step=1,
                label="Number of Items", info="Number of COCO instances in image (0-50)"
            )
        with gr.Column(scale=1):
            output_images = gr.Gallery(
                label="Retrieved Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=2,
                height=512,
            )
    submit_btn = gr.Button("Retrieve Most Similar Image")
    submit_btn.click(
        fn=get_similar_images,
        inputs=[caption, brightness, num_items],
        outputs=output_images,
    )

if __name__ == "__main__":
    demo.launch(share=True)
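For reference, the retrieval step in `get_similar_images` assumes that `image_emb.pt` holds a float tensor of shape (N, D) whose rows line up with `dataset` (one DINOv2 global embedding per image), and that the decoded `tok_dinov2_global` output is a single (1, D) query vector. Below is a minimal, self-contained sketch of just that similarity/top-k step, with random tensors standing in for the real embeddings; the dimensions and cache layout are assumptions read off the code above, not part of the committed app.

import torch
import torch.nn.functional as F

# Stand-ins for the real data: N cached image embeddings and one query
# embedding decoded from the 4M model (dimensions here are illustrative).
N, D = 1000, 768
image_embeddings = torch.randn(N, D)   # what torch.load("image_emb.pt") is assumed to return
query = torch.randn(1, D)              # dec_dict["tok_dinov2_global"] for a single prompt

# Cosine similarity broadcasts the (1, D) query against all (N, D) rows,
# producing one score per cached image.
similarities = F.cosine_similarity(query, image_embeddings, dim=1)  # shape (N,)

# Pick the single best match, mirroring argsort(descending=True)[:1] in app.py.
top_scores, top_indices = similarities.topk(k=1)
print(top_indices.tolist(), top_scores.tolist())

Using topk(1) is equivalent to sorting all scores and taking the first index, but avoids sorting the full score vector.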
requirements.txt
ADDED
@@ -0,0 +1,26 @@
torch>=2.1.0
torchvision>=0.16.0
albumentations>=1.4.0
boto3>=1.26.16
braceexpand>=0.1.7
diffusers==0.20.0
einops>=0.7.0
ftfy==6.1.0
huggingface_hub>=0.20.0
matplotlib>=3.6.2
numpy>=1.26.4
opencv-python>=4.9.0.80
opencv-python-headless>=4.6.0.66
pandas>=1.5.2
Pillow>=9.3.0
PyYAML>=6.0
regex>=2022.10.31
requests>=2.31.0
scikit-learn>=1.1.3
setuptools>=61.0
tokenizers>=0.15.2
datasets>=0.17
torchmetrics>=1.3.1
tqdm>=4.64.1
wandb>=0.13.5
webdataset>=0.2.86