youtube-recommender / recommendation.py
withana's picture
initial commit
c57ad5c
raw
history blame
1.65 kB
from datasets import load_dataset, concatenate_datasets
from sentence_transformers import SentenceTransformer
from torchvision import transforms
from models.encoder import Encoder
from indexer import Indexer
import dotenv
import torch
import os
# Load variables from a local .env file into the environment; the Hugging Face
# token is read from the 'HF' variable below.
dotenv.load_dotenv()

# Sentence-embedding model intended for encoding titles.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Image encoder with weights restored onto the CPU.
encoder = Encoder()
encoder.load_state_dict(torch.load('./models/encoder.bin', map_location=torch.device('cpu')))
# The encoder is only used for inference here, so put dropout / batch-norm
# style layers (if any) into evaluation behaviour.
encoder.eval()

# Merge both splits into one dataset; the recommendation indices are looked up
# against this combined dataset.
dataset = load_dataset("Ransaka/youtube_recommendation_data", token=os.environ.get('HF'))
dataset = concatenate_datasets([dataset['train'], dataset['test']])

# Precomputed vectors backing the two indexes. Loaded with the same CPU mapping
# as the encoder weights so artifacts saved from a GPU still load on a CPU-only
# host.
latent_data = torch.load("data/latent_data_final.bin", map_location=torch.device('cpu'))
embeddings = torch.load("data/embeddings.bin", map_location=torch.device('cpu'))

text_embedding_index = Indexer(embeddings)
image_embedding_index = Indexer(latent_data)
def get_recommendations(image, title):
    """Return dataset entries retrieved for a query image and title.

    The image is encoded with ``encoder`` and searched in
    ``image_embedding_index``; a title query vector is searched in
    ``text_embedding_index``. Both candidate lists are merged (image
    candidates first), de-duplicated while keeping that order, and the
    matching rows are read from ``dataset``.

    Args:
        image: Query image exposing ``convert("L")`` (presumably a PIL image).
        title: Query title. Currently unused; see the note on
            ``title_embeds`` below.

    Returns:
        A dict with keys ``"image"`` and ``"title"``, each holding a list with
        one entry per unique candidate, in matching order.
    """
    # NOTE(review): the title query is a random placeholder, so `title` is
    # ignored and the title-side candidates change on every call. The intended
    # call appears to be `model.encode(title, normalize_embeddings=True)`;
    # confirm its output width matches the vectors held by
    # `text_embedding_index` before enabling it.
    title_embeds = torch.randn(1, 768)

    # Single-channel (grayscale) tensor built from the query image.
    # NOTE(review): no resize or batch dimension is added here; confirm this is
    # the input layout the encoder expects.
    image = transforms.ToTensor()(image.convert("L"))

    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        image_embeds = encoder(image).numpy()

    image_candidates = image_embedding_index.topk(image_embeds)
    title_candidates = text_embedding_index.topk(title_embeds)

    # Element 0 of each result holds the candidate indices. Cast to plain ints
    # so duplicates compare by value, then de-duplicate while preserving the
    # order in which candidates were returned.
    ranked = list(image_candidates[0]) + list(title_candidates[0])
    final_candidates = list(dict.fromkeys(int(idx) for idx in ranked))

    results_dict = {"image": [], "title": []}
    for candidate in final_candidates:
        # Row-first access reads only this row, instead of materialising an
        # entire column for every candidate.
        row = dataset[candidate]
        results_dict["image"].append(row["image"])
        results_dict["title"].append(row["title"])
    return results_dict