Spaces:
Runtime error
Runtime error
from colbert.infra.run import Run | |
import os | |
import ujson | |
from colbert.utils.utils import print_message | |
from colbert.infra.provenance import Provenance | |
from utility.utils.save_metadata import get_metadata_only | |
class Examples:
    """A list of training examples held in memory, optionally read from a JSON-lines file.

    Attributes:
        nway: Optional limit used when loading from disk; when truthy, each
            parsed line is truncated to its first ``nway + 1`` items.
        path: The file path the examples were loaded from, or ``None``.
        data: The examples themselves (one entry per example).
    """

    def __init__(self, path=None, data=None, nway=None, provenance=None):
        """Build the collection from in-memory ``data`` or from the file at ``path``.

        Args:
            path: Path to a file with one JSON document per line. Read only
                when no non-empty ``data`` is supplied.
            data: Pre-loaded examples. Non-empty ``data`` always takes
                precedence over ``path``.
            nway: See the class docstring; only affects loading from ``path``.
            provenance: Where these examples came from. Falls back to ``path``
                and then to a fresh ``Provenance()``.
        """
        self.__provenance = provenance or path or Provenance()
        self.nway = nway
        self.path = path

        # An explicitly empty dataset with no backing file is valid; without
        # this guard `data or ...` would fall through to open(None) and crash.
        # When a path IS given, empty data still defers to the file (as before).
        if data or (data is not None and path is None):
            self.data = data
        else:
            self.data = self._load_file(path)

    def provenance(self):
        """Return the provenance value recorded at construction time."""
        return self.__provenance

    def toDict(self):
        """Return the serializable representation of this object (its provenance)."""
        return self.provenance()

    def _load_file(self, path):
        """Parse ``path`` as JSON lines and return the list of examples.

        When ``self.nway`` is truthy each parsed line is cut down to its first
        ``nway + 1`` items; otherwise ``self.nway`` itself is used as the slice
        bound (so ``None`` keeps the whole line).
        """
        nway = self.nway + 1 if self.nway else self.nway

        with open(path) as f:
            return [ujson.loads(line)[:nway] for line in f]

    def tolist(self, rank=None, nranks=None):
        """Return the examples as a new list, optionally sharded for one worker.

        NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
        In particular, each subset is perfectly represented in every batch! However, since we never
        repeat passes over the data, we never repeat any particular triple, and the split across
        nodes is random (since the underlying file is pre-shuffled), there's no concern here.

        Args:
            rank: Index of the calling worker, in ``range(nranks)``.
            nranks: Total number of workers.

        Returns:
            All examples when neither argument is truthy; otherwise the
            examples whose index ``i`` satisfies ``i % nranks == rank``.

        Raises:
            AssertionError: If ``rank`` is not in ``range(nranks)``.
        """
        if rank or nranks:
            assert rank in range(nranks), (rank, nranks)
            # Start the stride at `rank` (not 0) so each worker receives its
            # own disjoint shard instead of every worker getting shard 0.
            return [self.data[idx] for idx in range(rank, len(self.data), nranks)]

        return list(self.data)

    def save(self, new_path):
        """Write the examples to ``new_path`` plus a ``<new_path>.meta`` sidecar.

        The examples are written one JSON document per line. The sidecar holds
        a JSON object with ``metadata`` (from ``get_metadata_only()``) and
        ``provenance`` keys. Both files are opened through ``Run().open``.

        Args:
            new_path: Destination path; its final component must contain a
                ``json`` extension segment.

        Returns:
            The ``name`` of the file object the examples were written to.

        Raises:
            AssertionError: If the file name has no ``json`` extension segment.
        """
        assert 'json' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."

        print_message(f"#> Writing {len(self.data) / 1000_000.0}M examples to {new_path}")

        with Run().open(new_path, 'w') as f:
            for example in self.data:
                ujson.dump(example, f)
                f.write('\n')

            # Capture the name while the handle is still open; this is what
            # callers get back.
            output_path = f.name
            print_message(f"#> Saved examples with {len(self.data)} lines to {f.name}")

        with Run().open(f'{new_path}.meta', 'w') as f:
            d = {}
            d['metadata'] = get_metadata_only()
            d['provenance'] = self.provenance()
            line = ujson.dumps(d, indent=4)
            f.write(line)

        return output_path

    @classmethod
    def cast(cls, obj, nway=None):
        """Coerce ``obj`` into an instance of this class.

        Args:
            obj: A ``str`` (treated as a file path), a ``list`` (treated as
                in-memory data), or an existing instance (returned unchanged).
            nway: Forwarded to the constructor for ``str``/``list`` inputs.
                Must be ``None`` when ``obj`` is already an instance.

        Returns:
            An instance of ``cls``.

        Raises:
            AssertionError: If ``obj`` is an instance and ``nway`` is given,
                or if ``obj`` is of an unsupported type.
        """
        if isinstance(obj, str):
            return cls(path=obj, nway=nway)

        if isinstance(obj, list):
            return cls(data=obj, nway=nway)

        if isinstance(obj, cls):
            assert nway is None, nway
            return obj

        # Raised explicitly (rather than `assert False`) so the failure is not
        # optimized away under `python -O`.
        raise AssertionError(f"obj has type {type(obj)} which is not compatible with cast()")