File size: 2,251 Bytes
6917a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from gevent import monkey
# NOTE(review): importing grequests triggers gevent's monkey.patch_all(),
# which globally patches the standard library. Replacing patch_all with a
# no-op BEFORE importing grequests keeps the rest of the process unpatched —
# presumably to avoid conflicts with torch / threaded code. Confirm before
# removing; the stub assignment must stay above `import grequests`.
def stub(*args, **kwargs):  # pylint: disable=unused-argument
    pass
monkey.patch_all = stub
import grequests
import requests

import torch
import clip
# Run CLIP on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def encode_search_query(model, search_query):
    """Encode a text query with CLIP and L2-normalize the embedding.

    Args:
        model: a CLIP model exposing ``encode_text``.
        search_query: text (or list of texts) accepted by ``clip.tokenize``.

    Returns:
        The normalized text feature tensor produced by ``model.encode_text``,
        placed on the module-level ``device``.
    """
    with torch.no_grad():
        tokens = clip.tokenize(search_query).to(device)
        features = model.encode_text(tokens)
        # Normalize so downstream dot products behave as cosine similarities.
        features /= features.norm(dim=-1, keepdim=True)
    return features


def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
    """Rank photos by cosine similarity to a text embedding.

    Both feature tensors are assumed pre-normalized, so a plain dot product
    is the cosine similarity.

    Args:
        text_features: query embedding, shape (1, dim).
        photo_features: photo embeddings, shape (n_photos, dim).
        photo_ids: identifiers aligned row-for-row with ``photo_features``.
        results_count: how many of the top matches to return.

    Returns:
        The ``results_count`` photo IDs with the highest similarity, best first.
    """
    scores = (photo_features @ text_features.T).squeeze(1)
    # Negate so an ascending argsort yields a descending similarity ranking.
    ranking = (-scores).argsort()
    top_indices = ranking[:results_count]
    return [photo_ids[idx] for idx in top_indices]


def search_unslash(search_query, photo_features, photo_ids, results_count=10, model=None):
    """Search photos by a text query using CLIP embeddings.

    Args:
        search_query: text query to encode.
        photo_features: tensor of photo feature vectors, shape (n_photos, dim).
        photo_ids: photo identifiers aligned row-for-row with ``photo_features``.
        results_count: number of best matches to return.
        model: CLIP model forwarded to ``encode_search_query``. Required at
            call time; kept as a trailing keyword with a ``None`` default so
            existing positional call sites still parse.

    Returns:
        List of the ``results_count`` best-matching photo IDs, best first.

    Raises:
        ValueError: if ``model`` is not supplied.
    """
    if model is None:
        raise ValueError("search_unslash requires a CLIP model: pass model=...")

    # BUG FIX: encode_search_query's signature is (model, search_query); the
    # original call passed only the query, which raised TypeError at runtime.
    text_features = encode_search_query(model, search_query)

    # Rank all photos against the query embedding.
    best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, results_count)

    return best_photo_ids



def filter_invalid_urls(urls, photo_ids):
    """Fetch all URLs concurrently and keep only those answering HTTP 200.

    Args:
        urls: image URLs to probe.
        photo_ids: identifiers aligned index-for-index with ``urls``.

    Returns:
        Dict with ``image_ids`` and ``image_urls`` lists containing only the
        entries whose URL responded with status 200.
    """
    responses = grequests.map(grequests.get(url) for url in urls)

    kept_ids = []
    kept_urls = []
    for url, photo_id, response in zip(urls, photo_ids, responses):
        # grequests.map yields None for failed requests; keep 200s only.
        if response and response.status_code == 200:
            kept_urls.append(url)
            kept_ids.append(photo_id)

    return dict(
        image_ids=kept_ids,
        image_urls=kept_urls
    )