Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # coding=utf-8 | |
| import torch | |
| from data.parser.json_parser import example_from_json | |
class AbstractParser(torch.utils.data.Dataset):
    """Map-style dataset built from a JSON-like mapping of raw entries.

    Each value of ``data`` is converted with ``example_from_json`` and stored
    in ``self.examples``; ``self.fields`` maps attribute names to the field
    objects used to process those attributes when an item is fetched.
    """

    def __init__(self, fields, data, filter_pred=None):
        """Build the example list and the name -> field mapping.

        Args:
            fields: Either a dict or an iterable accepted by ``dict()``.
                When a dict, each value is a single entry or a list of
                entries; presumably every entry is a ``(name, field)`` pair
                (or a pair of parallel tuples) -- confirm against callers.
            data: Mapping whose values are passed to ``example_from_json``;
                entries are visited in the order of ``sorted(data.items())``.
            filter_pred: Optional predicate; when given, only examples for
                which it returns a truthy value are kept.
        """
        super().__init__()

        # Examples are created with `fields` exactly as passed in, i.e.
        # before the dict form is flattened below.
        self.examples = [
            example_from_json(entry, fields)
            for _key, entry in sorted(data.items())
        ]

        if isinstance(fields, dict):
            # Flatten the dict values into one list of entries.
            flattened = []
            for entry in fields.values():
                flattened.extend(entry if isinstance(entry, list) else [entry])
            fields = flattened

        if filter_pred is not None:
            self.examples = [ex for ex in self.examples if filter_pred(ex)]

        self.fields = dict(fields)

        # Unpack field tuples: a tuple key is replaced by one entry per
        # element, pairing each name with the matching item of its value.
        for names, group in list(self.fields.items()):
            if not isinstance(names, tuple):
                continue
            self.fields.update(zip(names, group))
            del self.fields[names]

    def __getitem__(self, i):
        """Return a dict of processed attributes for the ``i``-th example.

        Names whose field is ``None`` are left out of the result; every other
        field processes the matching example attribute with ``device=None``.
        """
        example = self.examples[i]
        return {
            name: field.process(getattr(example, name), device=None)
            for name, field in self.fields.items()
            if field is not None
        }

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.examples)

    def get_examples(self, attr):
        """Yield attribute ``attr`` of every example.

        Yields nothing when ``attr`` is not a key of ``self.fields``.
        """
        if attr not in self.fields:
            return
        for example in self.examples:
            yield getattr(example, attr)