File size: 281 Bytes
fbf7e95
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch.nn as nn


class AttributeSelector(nn.Module):
    """Pick a fixed subset of entries out of a dictionary.

    The keys to keep are fixed at construction time. ``forward`` builds a
    new dict containing only those keys that actually occur in its input.
    The result is ordered by their position in ``attrs``. Requested keys
    missing from the input are skipped without error.
    """

    def __init__(self, attrs):
        """Remember which keys to select.

        Args:
            attrs: Collection of keys to keep. It is iterated afresh on every
                ``forward`` call, so it should be re-iterable (a list, tuple,
                etc.); a one-shot iterator would be exhausted after one call.
        """
        super().__init__()
        self.attrs = attrs

    def forward(self, sim: dict) -> dict:
        """Return the entries of ``sim`` whose keys appear in ``self.attrs``.

        Args:
            sim: Source dictionary. It is read only, never modified.

        Returns:
            A newly created dict mapping each key of ``self.attrs`` that is
            present in ``sim`` to ``sim``'s value for it, in ``attrs`` order.
        """
        selected = {}
        for name in self.attrs:
            # Silently ignore requested keys the input does not provide.
            if name in sim:
                selected[name] = sim[name]
        return selected