Geonmo committed
Commit f7c2a85
1 Parent(s): 088cc0d

update app.py

Files changed (3):
  1. .gitignore +4 -0
  2. app.py +15 -16
  3. requirements.txt +1 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+__pycache__
+*.swp
+hf_models/
+pretrained_models/
app.py CHANGED
@@ -6,11 +6,12 @@ CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
 import os
 import time
 from argparse import ArgumentParser
+import json
 
 import numpy as np
 import torch
 import gradio as gr
-from clip_retrieval.clip_client import ClipClient
+import faiss
 
 from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
 from models import build_text_encoder, Phi, PIC2WORD
@@ -19,6 +20,7 @@ import transformers
 from huggingface_hub import hf_hub_url, cached_download
 
 
+
 def parse_args():
     parser = ArgumentParser()
     parser.add_argument("--lincir_ckpt_path", default=None, type=str,
@@ -100,6 +102,7 @@ def load_models(args):
     }
 
 
+@torch.no_grad()
 def predict(images, input_text, model_name):
     start_time = time.time()
     input_images = model_dict['clip_preprocess'](images, return_tensors='pt')['pixel_values'].to(model_dict['device'])
@@ -125,18 +128,15 @@ def predict(images, input_text, model_name):
     clip_text_time = time.time() - start_time
 
     start_time = time.time()
-    try:
-        results = client.query(embedding_input=text_embeddings[0].tolist())
-        output = ''
-    except:
-        results = []
-        output = 'The server for image retrieval is not working. Please try again later.'
-    retrieval_time = time.time() - start_time
 
+    _, results = faiss_index.search(text_embeddings.cpu().numpy(), k=10)
+
+    retrieval_time = time.time() - start_time
 
+    output = ''
 
-    for idx, result in enumerate(results):
-        image_url = result['url']
+    for idx, retrieved_idx in enumerate(results[0]):
+        image_url = image_urls[retrieved_idx]
         output += f'![image]({image_url})\n'
 
     time_output = {'CLIP visual extractor': clip_image_time,
@@ -180,7 +180,7 @@ def test_fps(batch_size=1):
 if __name__ == '__main__':
     args = parse_args()
 
-    global model_dict, client
+    global model_dict, faiss_index, image_urls
 
     model_dict = load_models(args)
 
@@ -189,19 +189,18 @@ if __name__ == '__main__':
         test_fps(1)
         exit()
 
+    faiss_index = faiss.read_index('./clip_large.index', faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
 
-    client = ClipClient(url="https://knn.laion.ai/knn-service",
-                        indice_name="laion5B-H-14" if args.clip_model_name == "huge" else "laion5B-L-14",
-                        )
+    image_urls = json.load(open('./image_urls.json'))
 
-    title = 'Zeroshot CIR demo'
+    title = 'Zeroshot CIR demo to search high-quality AI images'
 
     md_title = f'''# {title}
     [LinCIR](https://arxiv.org/abs/2312.01998): Language-only Training of Zero-shot Composed Image Retrieval
     [SEARLE](https://arxiv.org/abs/2303.15247): Zero-shot Composed Image Retrieval with Textual Inversion
     [Pic2Word](https://arxiv.org/abs/2302.03084): Mapping Pictures to Words for Zero-shot Composed Image Retrieval
 
-    K-NN index for the retrieval results are entirely trained using the entire Laion-5B imageset. This is made possible thanks to the great work of [rom1504](https://github.com/rom1504/clip-retrieval).
+    K-NN index for the retrieval results are entirely trained using [the upscaled midjourney v5 images (444,901)](https://huggingface.co/datasets/wanng/midjourney-v5-202304-clean).
     '''
 
     with gr.Blocks(title=title) as demo:
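Note: the new retrieval path reads two prebuilt artifacts, `./clip_large.index` and `./image_urls.json`, that are not produced anywhere in this commit. The sketch below shows one plausible way such a pair could be built; it is an assumption, not the author's build script. In particular, the `openai/clip-vit-large-patch14` checkpoint, the inner-product `IndexFlatIP` index type, and the `image_url` column name in the midjourney dataset are all guesses that would need to match the real app.

```python
# Hypothetical build script for ./clip_large.index and ./image_urls.json.
# Everything here is an assumption, not part of the commit: the encoder
# checkpoint, the IndexFlatIP index type, and the 'image_url' column name.
import io
import json

import faiss
import requests
import torch
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14').to(device).eval()
processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')


@torch.no_grad()
def embed(images):
    # CLIP image features, L2-normalized so inner product equals cosine similarity.
    pixel_values = processor(images=images, return_tensors='pt')['pixel_values'].to(device)
    features = model.get_image_features(pixel_values=pixel_values)
    return torch.nn.functional.normalize(features, dim=-1).cpu().numpy().astype('float32')


dataset = load_dataset('wanng/midjourney-v5-202304-clean', split='train')
index = faiss.IndexFlatIP(model.config.projection_dim)  # 768-d for ViT-L/14
image_urls = []

batch, batch_urls = [], []
for row in dataset:
    url = row['image_url']  # hypothetical column name; check the dataset schema
    try:
        image = Image.open(io.BytesIO(requests.get(url, timeout=10).content)).convert('RGB')
    except Exception:
        continue  # skip unreachable or broken images
    batch.append(image)
    batch_urls.append(url)
    if len(batch) == 64:
        index.add(embed(batch))
        image_urls.extend(batch_urls)
        batch, batch_urls = [], []

if batch:
    index.add(embed(batch))
    image_urls.extend(batch_urls)

faiss.write_index(index, './clip_large.index')
with open('./image_urls.json', 'w') as f:
    json.dump(image_urls, f)
```

If the index is built this way, row order is the only link between the index and `image_urls`, which is exactly how `predict` maps `results[0]` back to URLs; it would also mean `text_embeddings` should be L2-normalized before `faiss_index.search`, otherwise the ranking is by raw inner product rather than cosine similarity.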
requirements.txt CHANGED
@@ -6,3 +6,4 @@ accelerate
 datasets
 spacy
 git+https://github.com/rom1504/clip-retrieval
+faiss