Spaces:
Sleeping
Sleeping
Commit
·
2d30c9d
1
Parent(s):
9217b31
Update app.py
Browse files
app.py
CHANGED
@@ -15,28 +15,24 @@ from sklearn.preprocessing import normalize, OneHotEncoder
|
|
15 |
# loading the train dataset
|
16 |
with open('clip_train.pkl', 'rb') as f:
|
17 |
temp_d = pickle.load(f)
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
train_yt = temp_d['label'] # Array of labels
|
22 |
-
|
23 |
-
|
24 |
-
# train_yt = np.load("train_yt.npy")
|
25 |
|
26 |
# loading the test dataset
|
27 |
with open('clip_test.pkl', 'rb') as f:
|
28 |
temp_d = pickle.load(f)
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
test_yt = temp_d['label']
|
33 |
|
34 |
-
# test_xt = np.load("test_xt.npy")
|
35 |
-
|
36 |
enc = OneHotEncoder(sparse=False)
|
37 |
enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
|
38 |
-
|
39 |
-
|
40 |
train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
|
41 |
test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
|
42 |
|
@@ -55,18 +51,15 @@ d = 32
|
|
55 |
text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
|
56 |
text_index = faiss.read_index("text_index.index")
|
57 |
|
58 |
-
def T2Isearch(query,
|
59 |
# Encode the text query
|
60 |
-
inputs = text_tokenizer([query
|
61 |
outputs = text_model(**inputs)
|
62 |
query_embedding = outputs.text_embeds
|
63 |
query_vector = query_embedding.detach().numpy()
|
64 |
-
query_vector =
|
65 |
-
query_vector = query_vector.reshape(1,1024)
|
66 |
faiss.normalize_L2(query_vector)
|
67 |
-
|
68 |
-
|
69 |
-
# text_index.nprobe = 100
|
70 |
|
71 |
# Search for the nearest neighbors in the FAISS text index
|
72 |
D, I = text_index.search(query_vector, k)
|
@@ -104,7 +97,7 @@ def T2Isearch(query,focussed_word, k=50):
|
|
104 |
if count == 5: break
|
105 |
|
106 |
query = st.text_input("Enter your search query here:")
|
107 |
-
|
108 |
if st.button("Search"):
|
109 |
if query:
|
110 |
-
T2Isearch(query
|
|
|
15 |
# loading the train dataset
|
16 |
with open('clip_train.pkl', 'rb') as f:
|
17 |
temp_d = pickle.load(f)
|
18 |
+
train_xv = temp_d['image'].astype(np.float64) # Array of image features : np ndarray
|
19 |
+
train_xt = temp_d['text'].astype(np.float64) # Array of text features : np ndarray
|
20 |
+
train_yv = temp_d['label'] # Array of labels
|
21 |
train_yt = temp_d['label'] # Array of labels
|
22 |
+
ids = list(temp_d['ids']) # image names == len(images)
|
|
|
|
|
23 |
|
24 |
# loading the test dataset
|
25 |
with open('clip_test.pkl', 'rb') as f:
|
26 |
temp_d = pickle.load(f)
|
27 |
+
test_xv = temp_d['image'].astype(np.float64)
|
28 |
+
test_xt = temp_d['text'].astype(np.float64)
|
29 |
+
test_yv = temp_d['label']
|
30 |
test_yt = temp_d['label']
|
31 |
|
|
|
|
|
32 |
enc = OneHotEncoder(sparse=False)
|
33 |
enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
|
34 |
+
train_yv = enc.transform(train_yv.reshape((-1, 1))).astype(np.float64)
|
35 |
+
test_yv = enc.transform(test_yv.reshape((-1, 1))).astype(np.float64)
|
36 |
train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
|
37 |
test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
|
38 |
|
|
|
51 |
text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
|
52 |
text_index = faiss.read_index("text_index.index")
|
53 |
|
54 |
+
def T2Isearch(query, k=50):
|
55 |
# Encode the text query
|
56 |
+
inputs = text_tokenizer([query], padding=True, return_tensors="pt")
|
57 |
outputs = text_model(**inputs)
|
58 |
query_embedding = outputs.text_embeds
|
59 |
query_vector = query_embedding.detach().numpy()
|
60 |
+
query_vector = query_vector.reshape(1,512)
|
|
|
61 |
faiss.normalize_L2(query_vector)
|
62 |
+
index.nprobe = index.ntotal
|
|
|
|
|
63 |
|
64 |
# Search for the nearest neighbors in the FAISS text index
|
65 |
D, I = text_index.search(query_vector, k)
|
|
|
97 |
if count == 5: break
|
98 |
|
99 |
query = st.text_input("Enter your search query here:")
|
100 |
+
|
101 |
if st.button("Search"):
|
102 |
if query:
|
103 |
+
T2Isearch(query)
|