File size: 546 Bytes
05744dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import numpy as np
import torch
import torch.utils.data as data
class EmbDataset(data.Dataset):
    """Map-style dataset over a NumPy embedding array stored on disk.

    The array is read once with ``np.load``. Each item is the entry at the
    requested index, returned as a freshly allocated ``torch.float32`` tensor.
    """

    def __init__(self, data_path, mmap_mode=None):
        """Load the embedding array from *data_path*.

        Args:
            data_path: Path to a ``.npy`` file readable by ``np.load``
                (presumably shaped ``(num_items, dim)`` -- confirm with callers).
            mmap_mode: Forwarded to ``np.load``. ``None`` (the default) reads
                the whole array into memory, exactly as before; ``"r"``
                memory-maps it read-only so large tables are paged in lazily.
        """
        self.data_path = data_path
        self.embeddings = np.load(data_path, mmap_mode=mmap_mode)
        # Size of the last axis, i.e. the embedding width for a 2-D array.
        self.dim = self.embeddings.shape[-1]

    def __getitem__(self, index):
        """Return the embedding at *index* as a new float32 tensor.

        The entry is copied (and cast to float32), so in-place edits to the
        returned tensor never write back into ``self.embeddings``, which may
        be a read-only memmap. Unlike the legacy ``torch.FloatTensor``
        constructor, this also handles 0-d (scalar) entries correctly.
        """
        # np.array always copies here, and its default subok=False drops any
        # memmap subclass, so torch.from_numpy gets a plain, writable array.
        row = np.array(self.embeddings[index], dtype=np.float32)
        return torch.from_numpy(row)

    def __len__(self):
        """Return the number of embeddings (the size of the first axis)."""
        return len(self.embeddings)
|