Update clip_inferencing.py
clip_inferencing.py  CHANGED  (+0 −42)
@@ -1,15 +1,10 @@
 import torch
 import torch.nn.functional as F
-from transformers import DistilBertTokenizer
-from tqdm.autonotebook import tqdm
 import pickle

 from clip_model import CLIPModel
 from configuration import CFG

-import matplotlib.pyplot as plt
-import cv2
-
 def load_model(model_path):
     model = CLIPModel().to(CFG.device)
     model.load_state_dict(torch.load(model_path, map_location=CFG.device))
@@ -25,41 +20,4 @@ def load_image_embeddings():
     with open("pickles/image_embeddings.pkl", 'rb') as file:
         image_embeddings = pickle.load(file)
     return image_embeddings
-
-def find_matches(model, image_embeddings, query, image_filenames, n=9):
-    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
-    encoded_query = tokenizer([query])
-    batch = {
-        key: torch.tensor(values).to(CFG.device)
-        for key, values in encoded_query.items()
-    }
-    with torch.no_grad():
-        text_features = model.text_encoder(
-            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
-        )
-        text_embeddings = model.text_projection(text_features)
-
-    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
-    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
-    dot_similarity = text_embeddings_n @ image_embeddings_n.T
-
-    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
-    matches = [image_filenames[idx] for idx in indices[::5]]
-
-    _, axes = plt.subplots(3, 3, figsize=(10, 10))
-    for match, ax in zip(matches, axes.flatten()):
-        image = cv2.imread(f"Images/{match}")
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        ax.imshow(image)
-        ax.axis("off")
-
-    plt.show()
-
-def inference(query):
-    valid_df = load_df()
-    image_embeddings = load_image_embeddings()
-    find_matches(load_model(model_path="model/best.pt"),
-                 image_embeddings,
-                 query=query,
-                 image_filenames=valid_df['image'].values, n=9)
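For reference, the retrieval path removed by this commit can still be reproduced against the two helpers that remain in clip_inferencing.py. The sketch below mirrors the deleted find_matches()/inference() logic; it assumes the same project layout as before (DistilBertTokenizer with CFG.text_tokenizer, a checkpoint at model/best.pt, and the pickled image embeddings), and the function name search_images plus the explicit image_filenames argument are illustrative, not part of the repository.

# Sketch only: reproduces the removed find_matches()/inference() behaviour
# outside clip_inferencing.py. search_images is an illustrative name.
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer

from clip_inferencing import load_model, load_image_embeddings
from configuration import CFG

def search_images(query, image_filenames, n=9):
    model = load_model("model/best.pt")          # checkpoint path used by the removed inference()
    image_embeddings = load_image_embeddings()   # pre-computed image features from the pickle

    # Encode the text query with the trained text tower.
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {key: torch.tensor(values).to(CFG.device) for key, values in encoded_query.items()}
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    # Cosine similarity between the query and every pre-computed image embedding.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # Take the top n*5 hits and keep every 5th index, exactly as the removed code did.
    _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    return [image_filenames[idx] for idx in indices[::5]]

The returned filenames can then be displayed however the caller prefers, for example with the matplotlib/cv2 grid that the deleted code used.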