Spaces:
Sleeping
Sleeping
Commit
•
25ae722
1
Parent(s):
86ba518
Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,9 @@ import pickle
|
|
11 |
import pickletools
|
12 |
from transformers import AutoTokenizer, CLIPTextModelWithProjection
|
13 |
from sklearn.preprocessing import normalize, OneHotEncoder
|
|
|
|
|
|
|
14 |
|
15 |
# loading the train dataset
|
16 |
with open('clip_train.pkl', 'rb') as f:
|
@@ -29,6 +32,10 @@ with open('clip_test.pkl', 'rb') as f:
|
|
29 |
test_yv = temp_d['label']
|
30 |
test_yt = temp_d['label']
|
31 |
|
|
|
|
|
|
|
|
|
32 |
enc = OneHotEncoder(sparse=False)
|
33 |
enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
|
34 |
train_yv = enc.transform(train_yv.reshape((-1, 1))).astype(np.float64)
|
@@ -36,6 +43,45 @@ test_yv = enc.transform(test_yv.reshape((-1, 1))).astype(np.float64)
|
|
36 |
train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
|
37 |
test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
# Map the image ids to the corresponding image URLs
|
40 |
image_map_name = 'pascal_dataset.csv'
|
41 |
df = pd.read_csv(image_map_name)
|
@@ -51,14 +97,17 @@ d = 32
|
|
51 |
text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
|
52 |
text_index = faiss.read_index("text_index.index")
|
53 |
|
54 |
-
def T2Isearch(query, k=50):
|
55 |
# Encode the text query
|
56 |
-
inputs = text_tokenizer([query], padding=True, return_tensors="pt")
|
57 |
-
outputs = text_model(**inputs)
|
58 |
-
query_embedding = outputs.text_embeds
|
59 |
-
query_vector = query_embedding.detach().numpy()
|
60 |
# query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
|
61 |
-
query_vector = query_vector.reshape(1,512)
|
|
|
|
|
|
|
62 |
faiss.normalize_L2(query_vector)
|
63 |
index.nprobe = index.ntotal
|
64 |
|
@@ -66,7 +115,7 @@ def T2Isearch(query, k=50):
|
|
66 |
D, I = text_index.search(query_vector, k)
|
67 |
|
68 |
# get rank of all classes wrt to query
|
69 |
-
|
70 |
Y = train_yt
|
71 |
neighbor_ys = Y[I]
|
72 |
class_freq = np.zeros(Y.shape[1])
|
@@ -98,7 +147,7 @@ def T2Isearch(query, k=50):
|
|
98 |
if count == 5: break
|
99 |
|
100 |
query = st.text_input("Enter your search query here:")
|
101 |
-
|
102 |
if st.button("Search"):
|
103 |
if query:
|
104 |
-
T2Isearch(query)
|
|
|
11 |
import pickletools
|
12 |
from transformers import AutoTokenizer, CLIPTextModelWithProjection
|
13 |
from sklearn.preprocessing import normalize, OneHotEncoder
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch
|
17 |
|
18 |
# loading the train dataset
|
19 |
with open('clip_train.pkl', 'rb') as f:
|
|
|
32 |
test_yv = temp_d['label']
|
33 |
test_yt = temp_d['label']
|
34 |
|
35 |
+
test_xt_proj = np.load("test_text_proj.npy")
|
36 |
+
# test_xv_proj = np.load("test_image_proj.npy")
|
37 |
+
|
38 |
+
# encoding the labels
|
39 |
enc = OneHotEncoder(sparse=False)
|
40 |
enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
|
41 |
train_yv = enc.transform(train_yv.reshape((-1, 1))).astype(np.float64)
|
|
|
43 |
train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
|
44 |
test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
|
45 |
|
46 |
+
# # Model structure
|
47 |
+
# torch.manual_seed(3074)
|
48 |
+
# class imgModel(nn.Module):
|
49 |
+
# def __init__(self, in_features, out_features):
|
50 |
+
# super(imgModel, self).__init__()
|
51 |
+
# self.l1 = nn.Linear(in_features=in_features, out_features=256)
|
52 |
+
# self.bn1 = nn.BatchNorm1d(256)
|
53 |
+
# self.dl1 = nn.Dropout(p=0.2)
|
54 |
+
# self.l2 = nn.Linear(in_features=256, out_features=out_features)
|
55 |
+
|
56 |
+
# def forward(self, x):
|
57 |
+
# x = self.l1(x)
|
58 |
+
# x = torch.sigmoid(x)
|
59 |
+
# x = self.dl1(x)
|
60 |
+
# x = self.bn1(x)
|
61 |
+
|
62 |
+
# x = self.l2(x)
|
63 |
+
# x = torch.tanh(x)
|
64 |
+
# return x
|
65 |
+
|
66 |
+
torch.manual_seed(3074)
|
67 |
+
class txtModel(nn.Module):
    """Two-layer feed-forward projection head.

    The forward pass applies, in order:
    Linear -> sigmoid -> Dropout -> BatchNorm1d -> Linear -> tanh,
    so every output component lies in [-1, 1].

    Args:
        in_features: Size of each input vector.
        out_features: Size of each output vector.
        hidden_features: Width of the hidden layer. Defaults to 256, the
            value that was previously hard-coded.
        dropout: Dropout probability applied after the sigmoid. Defaults
            to 0.2, the value that was previously hard-coded.
    """

    def __init__(self, in_features, out_features, hidden_features=256, dropout=0.2):
        super().__init__()
        # Attribute names are deliberately unchanged (including ``dl2``):
        # they determine the state_dict keys of any saved checkpoint.
        self.l1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.dl2 = nn.Dropout(p=dropout)
        self.l2 = nn.Linear(in_features=hidden_features, out_features=out_features)

    def forward(self, x):
        """Project ``x`` and squash the result with tanh.

        Args:
            x: Input tensor whose last dimension is ``in_features``.
                Presumably shaped (batch, in_features) -- TODO confirm
                against the caller.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        x = torch.sigmoid(self.l1(x))
        # Dropout is applied before batch norm, matching the original order.
        x = self.dl2(x)
        x = self.bn1(x)
        return torch.tanh(self.l2(x))
|
84 |
+
|
85 |
# Map the image ids to the corresponding image URLs
|
86 |
image_map_name = 'pascal_dataset.csv'
|
87 |
df = pd.read_csv(image_map_name)
|
|
|
97 |
text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
|
98 |
text_index = faiss.read_index("text_index.index")
|
99 |
|
100 |
+
def T2Isearch(query, i, k=50):
|
101 |
# Encode the text query
|
102 |
+
# inputs = text_tokenizer([query], padding=True, return_tensors="pt")
|
103 |
+
# outputs = text_model(**inputs)
|
104 |
+
# query_embedding = outputs.text_embeds
|
105 |
+
# query_vector = query_embedding.detach().numpy()
|
106 |
# query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
|
107 |
+
# query_vector = query_vector.reshape(1,512)
|
108 |
+
|
109 |
+
query_vector = test_xt_proj[i-1]
|
110 |
+
query_vector = query_vector.reshape(1,32)
|
111 |
faiss.normalize_L2(query_vector)
|
112 |
index.nprobe = index.ntotal
|
113 |
|
|
|
115 |
D, I = text_index.search(query_vector, k)
|
116 |
|
117 |
# get rank of all classes wrt to query
|
118 |
+
|
119 |
Y = train_yt
|
120 |
neighbor_ys = Y[I]
|
121 |
class_freq = np.zeros(Y.shape[1])
|
|
|
147 |
if count == 5: break
|
148 |
|
149 |
# Streamlit driver: read the query text and a 1-based test-set row number,
# then run the search when the button is pressed.
query = st.text_input("Enter your search query here:")
i = st.text_input("Enter the index of test set from 1 - 200")
if st.button("Search"):
    if query:
        # st.text_input returns a string, while T2Isearch evaluates
        # test_xt_proj[i-1]; passing the raw string raises TypeError.
        # Convert here and reject anything that is not a valid row number
        # (0 would otherwise wrap around to the last row via index -1).
        try:
            test_index = int(i)
        except (TypeError, ValueError):
            st.error("Please enter a whole number for the test set index.")
        else:
            if 1 <= test_index <= len(test_xt_proj):
                T2Isearch(query, test_index)
            else:
                st.error(
                    f"Test set index must be between 1 and {len(test_xt_proj)}."
                )
|