shivangibithel commited on
Commit
d6c88ae
1 Parent(s): b68c187

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -12
app.py CHANGED
@@ -1,20 +1,71 @@
 
1
  import streamlit as st
2
- from transformers import pipeline
 
 
 
 
3
  from PIL import Image
 
 
4
 
5
- pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
 
6
 
7
- st.title("Hot Dog? Or Not?")
 
 
 
8
 
9
- file_name = st.file_uploader("Upload a hot dog candidate image")
 
 
 
 
 
10
 
11
- if file_name is not None:
12
- col1, col2 = st.columns(2)
 
 
 
13
 
14
- image = Image.open(file_name)
15
- col1.image(image, use_column_width=True)
16
- predictions = pipeline(image)
 
 
 
 
17
 
18
- col2.header("Probabilities")
19
- for p in predictions:
20
- col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
  import streamlit as st
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import faiss
6
+ import numpy as np
7
+ import wget
8
  from PIL import Image
9
+ from io import BytesIO
10
+ from sentence_transformers import SentenceTransformer
11
 
12
+ # dataset = load_dataset("nlphuji/flickr30k", streaming=True)
13
+ # df = pd.DataFrame.from_dict(dataset["train"])
14
 
15
+ # Load the pre-trained sentence encoder
16
+ model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ model = SentenceTransformer(model_name)
19
 
20
+ # # Load the pre-trained image model
21
+ # image_model_name = 'image_model.ckpt'
22
+ # image_model_url = 'https://huggingface.co/models/flax-community/deit-tiny-random/images/vqvae.png'
23
+ # wget.download(image_model_url, image_model_name)
24
+ # image_model = torch.load(image_model_name, map_location=torch.device('cpu'))
25
+ # image_model.eval()
26
 
27
+ # Load the FAISS index
28
+ index_name = 'index.faiss'
29
+ index_url = 'https://huggingface.co/models/flax-community/deit-tiny-random/faiss_files/faiss.index'
30
+ wget.download(index_url, index_name)
31
+ index = faiss.read_index(index_name)
32
 
33
+ # Map the image ids to the corresponding image URLs
34
+ image_map_name = 'image_map.json'
35
+ image_map_url = 'https://huggingface.co/models/flax-community/deit-tiny-random/faiss_files/image_map.json'
36
+ wget.download(image_map_url, image_map_name)
37
+ image_map = {}
38
+ with open(image_map_name, 'r') as f:
39
+ image_map = json.load(f)
40
 
41
+ def search(query, k=5):
42
+ # Encode the query
43
+ query_tokens = tokenizer.encode(query, return_tensors='pt')
44
+ query_embedding = model.encode(query_tokens).detach().numpy()
45
+
46
+ # Search for the nearest neighbors in the FAISS index
47
+ D, I = index.search(query_embedding, k)
48
+
49
+ # Map the image ids to the corresponding image URLs
50
+ image_urls = []
51
+ for i in I[0]:
52
+ image_id = str(i)
53
+ image_url = image_map[image_id]
54
+ image_urls.append(image_url)
55
+
56
+ return image_urls
57
+
58
+ st.title("Image Search App")
59
+
60
+ query = st.text_input("Enter your search query here:")
61
+ if st.button("Search"):
62
+ if query:
63
+ image_urls = search(query)
64
+
65
+ # Display the images
66
+ st.image(image_urls, width=200)
67
+
68
+ if __name__ == '__main__':
69
+ st.set_page_config(page_title='Image Search App', layout='wide')
70
+ st.cache(allow_output_mutation=True)
71
+ run_app()