shivangibithel committed on
Commit
e3bc95e
1 Parent(s): 8c1e611

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -17
app.py CHANGED
@@ -9,34 +9,29 @@ from PIL import Image
9
  from io import BytesIO
10
  from sentence_transformers import SentenceTransformer
11
 
12
- # dataset = load_dataset("nlphuji/flickr30k", streaming=True)
13
- # df = pd.DataFrame.from_dict(dataset["train"])
14
 
15
  # Load the pre-trained sentence encoder
16
- model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  model = SentenceTransformer(model_name)
19
 
20
- # # Load the pre-trained image model
21
- # image_model_name = 'image_model.ckpt'
22
- # image_model_url = 'https://huggingface.co/models/flax-community/deit-tiny-random/images/vqvae.png'
23
- # wget.download(image_model_url, image_model_name)
24
- # image_model = torch.load(image_model_name, map_location=torch.device('cpu'))
25
- # image_model.eval()
26
-
27
  # Load the FAISS index
28
  index_name = 'index.faiss'
29
- index_url = 'https://huggingface.co/models/flax-community/deit-tiny-random/faiss_files/faiss.index'
30
  wget.download(index_url, index_name)
31
  index = faiss.read_index(index_name)
32
 
33
  # Map the image ids to the corresponding image URLs
34
- image_map_name = 'image_map.json'
35
- image_map_url = 'https://huggingface.co/models/flax-community/deit-tiny-random/faiss_files/image_map.json'
36
  wget.download(image_map_url, image_map_name)
37
- image_map = {}
38
  with open(image_map_name, 'r') as f:
39
- image_map = json.load(f)
 
 
 
40
 
41
  def search(query, k=5):
42
  # Encode the query
@@ -49,8 +44,9 @@ def search(query, k=5):
49
  # Map the image ids to the corresponding image URLs
50
  image_urls = []
51
  for i in I[0]:
52
- image_id = str(i)
53
- image_url = image_map[image_id]
 
54
  image_urls.append(image_url)
55
 
56
  return image_urls
 
9
  from io import BytesIO
10
  from sentence_transformers import SentenceTransformer
11
 
12
+ dataset = load_dataset("imagefolder", data_files="https://huggingface.co/datasets/nlphuji/flickr30k/blob/main/flickr30k-images.zip")
 
13
 
14
  # Load the pre-trained sentence encoder
15
+ model_name = "sentence-transformers/all-distilroberta-v1"
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = SentenceTransformer(model_name)
18
 
 
 
 
 
 
 
 
19
  # Load the FAISS index
20
  index_name = 'index.faiss'
21
+ index_url = 'https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/blob/main/faiss_flickr8k.index'
22
  wget.download(index_url, index_name)
23
  index = faiss.read_index(index_name)
24
 
25
  # Map the image ids to the corresponding image URLs
26
+ image_map_name = 'captions.json'
27
+ image_map_url = 'https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/blob/main/captions.json'
28
  wget.download(image_map_url, image_map_name)
29
+
30
  with open(image_map_name, 'r') as f:
31
+ caption_dict = json.load(f)
32
+
33
+ image_list = list(caption_dict.keys())
34
+ caption_list = list(caption_dict.values())
35
 
36
  def search(query, k=5):
37
  # Encode the query
 
44
  # Map the image ids to the corresponding image URLs
45
  image_urls = []
46
  for i in I[0]:
47
+ text_id = i
48
+ image_id = str(image_list[i])
49
+ image_url = "https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/blob/main/Images/" + image_id
50
  image_urls.append(image_url)
51
 
52
  return image_urls