| | |
| |
|
| | import torch |
| |
|
| |
|
class Embedding(object):
    """Pretrained word embeddings: a token sequence plus a tensor of vectors.

    Args:
        tokens: Sequence of token strings; position ``i`` names row ``i`` of
            ``vectors``.
        vectors: One entry per token. Normally each entry is a flat sequence
            of floats (this is what :meth:`load` produces). If every entry
            carries an extra leading axis, only its first row is used.
        unk: Token that stands for unknown words. Falls back to ``'[UNK]'``
            when omitted.

    Attributes:
        tokens: The token sequence exactly as passed in.
        vectors: ``torch.Tensor`` with one row per token.
        pretrained: Dict mapping each token to its entry of ``vectors`` as
            passed in (not converted to a tensor).
        unk: The unknown-word token.
    """

    def __init__(self, tokens, vectors, unk=None):
        super(Embedding, self).__init__()
        self.tokens = tokens

        # Keep the *whole* vector of every token. Indexing ``v[0]`` on a flat
        # float list would keep only its first component and collapse the
        # table to one scalar per token.
        matrix = torch.tensor(vectors)
        if matrix.dim() > 2:
            # Entries with an extra leading axis: use the first row of each,
            # which is the result ``v[0]`` indexing gave for such input.
            matrix = matrix[:, 0]
        self.vectors = matrix

        self.pretrained = dict(zip(tokens, vectors))
        # Honour the caller's choice; '[UNK]' stays the default so code that
        # never passed ``unk`` keeps working.
        self.unk = '[UNK]' if unk is None else unk

    def __len__(self):
        """Return the number of tokens."""
        return len(self.tokens)

    def __contains__(self, token):
        """Return True if ``token`` has a pretrained vector."""
        return token in self.pretrained

    @property
    def dim(self):
        """Length of each embedding vector (last axis of ``vectors``)."""
        return self.vectors.size(-1)

    @property
    def unk_index(self):
        """Position of the unknown-word token within ``tokens``.

        Raises:
            AttributeError: If ``unk`` has been set to ``None``.
            ValueError: Propagated from ``tokens.index`` when the unknown
                token is not present in ``tokens``.
        """
        if self.unk is None:
            raise AttributeError('no unknown-word token is configured')
        return self.tokens.index(self.unk)

    @classmethod
    def load(cls, path, unk=None):
        """Build an :class:`Embedding` from a whitespace-separated text file.

        Each non-blank line holds a token followed by the components of its
        vector. Blank lines are skipped.

        Args:
            path: Path of the text file to read (decoded as UTF-8).
            unk: Unknown-word token, forwarded to the constructor.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If the file contains no embedding lines, or a
                component cannot be parsed as a float.
        """
        tokens, vectors = [], []
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # blank line, e.g. a trailing newline
                tokens.append(fields[0])
                vectors.append([float(x) for x in fields[1:]])

        if not tokens:
            raise ValueError('no embeddings found in {!r}'.format(path))

        # Tuples, as before, so ``tokens`` keeps the same type for callers.
        return cls(tuple(tokens), tuple(vectors), unk=unk)
| |
|