import numpy as np
import torch
import torch.utils.data as data


class EmbDataset(data.Dataset):
    """Map-style dataset serving rows of a NumPy array as float32 tensors.

    The array is read once with ``numpy.load`` and indexed along its first
    axis. It is presumably 2-D ``(num_items, dim)`` embeddings — TODO confirm
    against the files actually passed in.

    Attributes:
        data_path: Path the array was loaded from.
        embeddings: The loaded array (an ``np.memmap`` when ``mmap_mode`` is set).
        dim: Size of the array's last axis (the embedding width for 2-D input).
    """

    def __init__(self, data_path, *, mmap_mode=None):
        """Load the array stored at *data_path*.

        Args:
            data_path: Path to a ``.npy`` file holding a single array. (An
                ``.npz`` archive would load as an ``NpzFile`` and fail on
                ``.shape``.)
            mmap_mode: Forwarded to ``numpy.load``. ``None`` (the default,
                matching the previous behaviour) reads the whole array into
                memory; e.g. ``"r"`` memory-maps it read-only so large tables
                need not fit in RAM.
        """
        self.data_path = data_path
        self.embeddings = np.load(data_path, mmap_mode=mmap_mode)
        self.dim = self.embeddings.shape[-1]

    def __getitem__(self, index):
        """Return row *index* as a new float32 tensor.

        The row is copied and cast into a fresh, writable float32 buffer, so the
        returned tensor never aliases ``self.embeddings`` (callers may mutate it
        safely). This also avoids PyTorch's warning about non-writable arrays
        when the data is memory-mapped read-only.
        """
        row = np.array(self.embeddings[index], dtype=np.float32)
        return torch.from_numpy(row)

    def __len__(self):
        """Return the number of rows (length of the first axis)."""
        return len(self.embeddings)