Spaces:
Runtime error
Runtime error
from datasets import load_dataset | |
from PIL import Image | |
import os | |
import pandas as pd | |
from transformers import AutoFeatureExtractor,AutoModel | |
from faiss.contrib.inspect_tools import get_flat_data | |
import pymde | |
import numpy as np | |
def get_embedding(model_name, viz_dat):
    """Return a 2-D layout of image embeddings joined with dataset metadata.

    Two artifacts are cached under ``./indexes`` and reused when present:

    * ``<model>.faiss`` -- a FAISS index over the per-image embeddings
      produced by ``model_name``;
    * ``<model>.npy`` -- the low-dimensional layout computed from that index
      with ``pymde.preserve_neighbors``.

    Args:
        model_name: Identifier passed to ``from_pretrained``. For an
            ``"org/name"`` identifier the ``name`` part is used as the cache
            file stem; an identifier without ``/`` is used whole.
        viz_dat: Dataset object exposing ``load_faiss_index``, ``map``,
            ``add_faiss_index``, ``save_faiss_index``, ``get_index`` and
            column access by name (presumably a Hugging Face
            ``datasets.Dataset`` -- confirm against the caller). It must
            provide the columns ``image``, ``gender``, ``masterCategory``
            and ``subCategory``.

    Returns:
        A ``pandas.DataFrame`` with columns ``x``, ``y``, ``image``,
        ``gender``, ``masterCategory`` and ``subCategory``.
    """
    # Keep the historical "second path segment" naming so existing cache
    # files are still found, but do not crash on identifiers without a "/".
    name_parts = model_name.split("/")
    model_id = name_parts[1] if len(name_parts) > 1 else name_parts[0]

    index_dir = "./indexes"
    index_file = f"{index_dir}/{model_id}.faiss"
    embedding_file = f"{index_dir}/{model_id}.npy"

    if os.path.exists(index_file):
        viz_dat.load_faiss_index("embeddings", index_file)
    else:
        # Imported here because torch is only needed when embeddings have to
        # be computed; the cached path does not touch the model at all.
        import torch

        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        # To run on GPU, move the model (and the inputs inside ``embed``)
        # to the device, e.g. ``model.to("cuda:0")``.

        def embed(batch):
            """Attach an ``embeddings`` column to one batch of images."""
            inputs = feature_extractor(
                images=batch["image"], return_tensors="pt"
            )
            # Inference only: skip autograd bookkeeping for each batch.
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)
            # Which output field holds the embedding depends on the model.
            batch["embeddings"] = outputs.pooler_output.detach().cpu().numpy()
            return batch

        # Add embeddings to the dataset, then index and cache them.
        viz_dat = viz_dat.map(embed, batched=True, batch_size=20)
        viz_dat.add_faiss_index(column="embeddings")
        # The cache directory may not exist on a fresh checkout.
        os.makedirs(index_dir, exist_ok=True)
        viz_dat.save_faiss_index("embeddings", index_file)

    if os.path.exists(embedding_file):
        embedding = np.load(embedding_file)
    else:
        # Pull the raw vectors back out of the FAISS index so this also works
        # when the index was loaded from disk rather than computed above.
        index = viz_dat.get_index("embeddings").faiss_index
        embeddings = get_flat_data(index)
        embedding = (
            pymde.preserve_neighbors(embeddings, verbose=True).embed().numpy()
        )
        os.makedirs(index_dir, exist_ok=True)
        np.save(embedding_file, embedding)

    embedding = pd.DataFrame(embedding, columns=["x", "y"])
    # Carry the metadata columns alongside each point.
    for column in ("image", "gender", "masterCategory", "subCategory"):
        embedding[column] = viz_dat[column]
    return embedding