|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
def retrieve_st_by_image(image_embeddings, all_text_embeddings, dataframe, k=3):
    """
    Retrieve the top-k most similar ST samples for an image embedding.

    Similarity is the raw dot product between the image embedding and every
    ST embedding; no normalisation is applied here, so callers that want
    cosine similarity must pass already-normalised embeddings.

    :param image_embeddings: A numpy array or torch tensor containing the image
        embedding (documented shape: [1, embedding_dim]).
    :param all_text_embeddings: A numpy array or torch tensor containing the ST
        embeddings (documented shape: [n_samples, embedding_dim]).
    :param dataframe: A pandas DataFrame whose 'img_idx' column holds the
        identifier of each ST sample. Rows are looked up positionally, so the
        frame is assumed to be in the same order as ``all_text_embeddings``.
    :param k: The number of top similar samples to retrieve. Default is 3.
        Values larger than the number of candidates are clamped.
    :return: A list with the 'img_idx' values of the top-k samples, ordered
        from most to least similar.
    """
    # torch.as_tensor is a no-op for tensors and wraps numpy arrays, so both
    # documented input kinds reach torch.topk as tensors.
    text_bank = torch.as_tensor(all_text_embeddings)
    query = torch.as_tensor(image_embeddings)

    # Matmul needs matching dtype/device; align the (small) query with the
    # (large) bank and promote rather than truncate precision.
    common_dtype = torch.promote_types(query.dtype, text_bank.dtype)
    query = query.to(device=text_bank.device, dtype=common_dtype)
    text_bank = text_bank.to(dtype=common_dtype)

    dot_similarity = (query @ text_bank.T).squeeze(0)

    # torch.topk raises if k exceeds the size of the ranked dimension.
    k = min(k, dot_similarity.shape[-1])
    _, indices = torch.topk(dot_similarity, k)

    image_filenames = dataframe['img_idx'].values

    # Convert once to plain Python ints instead of indexing with 0-d tensors.
    return [image_filenames[idx] for idx in indices.tolist()]
|
|
|
|
|
|
|
|
|