File size: 2,774 Bytes
58627fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import tqdm
import ujson
from colbert.infra.provenance import Provenance

from colbert.infra.run import Run
from colbert.utils.utils import print_message, groupby_first_item
from utility.utils.save_metadata import get_metadata_only


def numericize(v):
    """Convert one textual field of a ranking line into a number.

    Tokens containing a ``.`` become ``float``; all other tokens are tried
    as ``int`` first. Tokens that ``int()`` rejects but ``float()`` accepts
    (scientific notation such as ``'1e-05'``, or ``'inf'``/``'nan'``) are
    returned as ``float`` instead of failing.

    Raises:
        ValueError: if the token is neither a valid int nor a valid float.
    """
    if '.' in v:
        return float(v)

    try:
        return int(v)
    except ValueError:
        # str() of a small/large float has no '.', e.g. str(1e-05) == '1e-05',
        # so a dot-less token is not necessarily an integer.
        return float(v)


def load_ranking(path):  # works with annotated and un-annotated ranked lists
    """Read a tab-separated ranking file into a list of numeric rows.

    Each line is stripped of surrounding whitespace, split on tabs, and every
    resulting field is converted with ``numericize``. Returns one list per
    line, in file order.
    """
    print_message("#> Loading the ranked lists from", path)

    rows = []
    with open(path) as handle:
        for raw_line in handle:
            fields = raw_line.strip().split('\t')
            rows.append([numericize(field) for field in fields])

    return rows


class Ranking:
    """Ranked lists kept both grouped by their first column and as flat rows.

    ``data`` is the grouped form: a dict passed in as-is, or the result of
    ``groupby_first_item`` over the flat rows. ``flat_ranking`` is the
    row-per-entry form returned by ``tolist()`` and written by ``save()``.
    """

    def __init__(self, path=None, data=None, metrics=None, provenance=None):
        """Build a ranking from in-memory ``data`` or from the file at ``path``.

        ``data`` is used whenever it is not None (an empty dict or list is
        valid data); otherwise the rows are read via ``_load_file(path)``.
        ``metrics`` is accepted for interface compatibility but is not used
        here. The provenance is the first truthy value among ``provenance``,
        ``path`` and a fresh ``Provenance()``.
        """
        self.__provenance = provenance or path or Provenance()

        # Test against None explicitly: `data or ...` would treat an empty
        # dict/list as missing and wrongly try to load from `path`.
        if data is None:
            data = self._load_file(path)

        self.data = self._prepare_data(data)

    def provenance(self):
        """Return the provenance value chosen at construction time."""
        return self.__provenance

    def toDict(self):
        """Return a dict holding only this ranking's provenance.

        Not to be confused with ``todict()``, which returns the grouped
        ranking data.
        """
        return {'provenance': self.provenance()}

    def _prepare_data(self, data):
        """Set ``self.flat_ranking`` and return the grouped form of ``data``.

        A dict is returned unchanged, and its flat form is built by prefixing
        each entry of every sub-ranking with that sub-ranking's key. Any other
        input is stored as the flat form and grouped with
        ``groupby_first_item``.
        """
        # TODO: Handle list of lists???
        if isinstance(data, dict):
            self.flat_ranking = [(qid, *rest) for qid, subranking in data.items() for rest in subranking]
            return data

        self.flat_ranking = data
        return groupby_first_item(tqdm.tqdm(self.flat_ranking))

    def _load_file(self, path):
        """Read the ranking rows stored at ``path`` using ``load_ranking``."""
        return load_ranking(path)

    def todict(self):
        """Return a shallow ``dict`` copy of the grouped ranking data."""
        return dict(self.data)

    def tolist(self):
        """Return a shallow ``list`` copy of the flat ranking rows."""
        return list(self.flat_ranking)

    def items(self):
        """Return the ``items()`` view of the grouped ranking data."""
        return self.data.items()

    def _load_tsv(self, path):
        """Placeholder for a TSV-specific loader; not implemented."""
        raise NotImplementedError

    def _load_jsonl(self, path):
        """Placeholder for a JSONL-specific loader; not implemented."""
        raise NotImplementedError

    def save(self, new_path):
        """Write the flat ranking as TSV plus a ``.meta`` JSON sidecar.

        Both files are opened through ``Run().open``. Each row is written as
        tab-separated ``str()`` values, with booleans written as ``0``/``1``.
        The sidecar, at ``f'{new_path}.meta'``, holds the output of
        ``get_metadata_only()`` and this ranking's provenance.

        Returns:
            The ``name`` attribute of the file object the ranking was
            written to.

        Raises:
            AssertionError: if the last component of ``new_path`` has no
                ``tsv`` part among its dot-separated pieces.
        """
        assert 'tsv' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."

        with Run().open(new_path, 'w') as f:
            for items in self.flat_ranking:
                # Booleans are serialized as 0/1 rather than 'True'/'False'.
                cells = [str(int(x) if isinstance(x, bool) else x) for x in items]
                f.write('\t'.join(cells) + '\n')

            output_path = f.name
            print_message(f"#> Saved ranking of {len(self.data)} queries and {len(self.flat_ranking)} lines to {f.name}")

        with Run().open(f'{new_path}.meta', 'w') as f:
            d = {
                'metadata': get_metadata_only(),
                'provenance': self.provenance(),
            }
            f.write(ujson.dumps(d, indent=4))

        return output_path

    @classmethod
    def cast(cls, obj):
        """Coerce ``obj`` into an instance of this class.

        A string is treated as a path to load, a dict or list as in-memory
        data, and an existing instance (including instances of subclasses)
        is returned unchanged.

        Raises:
            AssertionError: if ``obj`` is none of the supported types.
        """
        if isinstance(obj, str):
            return cls(path=obj)

        if isinstance(obj, (dict, list)):
            return cls(data=obj)

        if isinstance(obj, cls):
            return obj

        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(f"obj has type {type(obj)} which is not compatible with cast()")