Spaces:
Runtime error
Runtime error
File size: 2,172 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import random
from colbert.infra.provenance import Provenance
from utility.utils.save_metadata import save_metadata
from utility.supervision.triples import sample_for_query
from colbert.utils.utils import print_message
from colbert.data.ranking import Ranking
from colbert.data.examples import Examples
# Upper bound on the number of triples Triples.create() keeps: when more are
# generated, a random sample of exactly this many is retained.
MAX_NUM_TRIPLES = 40_000_000
class Triples:
    """Generate, sub-sample, shuffle and save training examples built from a ranking.

    Typical use: construct with a ranking, call ``create()`` to build the
    examples (called "triples" in this file), then ``save()`` to write them out
    together with provenance describing how they were produced.
    """

    def __init__(self, ranking, seed=12345):
        """Load the ranking and seed the random number generator.

        Args:
            ranking: Anything accepted by ``Ranking.cast``.
            seed: Seed applied to the module-level ``random`` generator and
                recorded in the provenance written by ``save()``.
        """
        # NOTE: this seeds the *global* RNG, which affects every other user of
        # ``random`` in the process.
        random.seed(seed)  # TODO: Use internal RNG instead..
        self.seed = seed

        ranking = Ranking.cast(ranking)
        self.ranking_provenance = ranking.provenance()
        self.qid2rankings = ranking.todict()

        # Filled in by create(); save() refuses to run while they are unset.
        self.positives = None
        self.depth = None
        self.Triples = None
        self._triples = None

    def create(self, positives, depth):
        """Sample triples for every query in the ranking, then cap and shuffle them.

        Args:
            positives: Collection of 2-item entries ``(maxBest, maxDepth)`` with
                ``maxBest <= maxDepth``. It is traversed more than once, so it
                must be re-iterable (e.g. a list). Its exact semantics are
                defined by ``sample_for_query``.
            depth: Passed through unchanged to ``sample_for_query``.

        Returns:
            The shuffled list of sampled triples (at most ``MAX_NUM_TRIPLES``
            items). The same data, wrapped in ``Examples``, is stored on
            ``self.Triples``.

        Raises:
            AssertionError: If an entry of ``positives`` does not have exactly
                two items or has ``maxBest > maxDepth``.
        """
        assert all(len(x) == 2 for x in positives)
        assert all(maxBest <= maxDepth for maxBest, maxDepth in positives), positives

        self.positives = positives
        self.depth = depth

        triples = []
        non_empty_qids = 0

        for processing_idx, qid in enumerate(self.qid2rankings):
            # The trailing ``False, None`` arguments are forwarded as-is; their
            # meaning is defined by sample_for_query.
            sampled = sample_for_query(qid, self.qid2rankings[qid], positives, depth, False, None)
            non_empty_qids += (len(sampled) > 0)
            triples.extend(sampled)

            # Periodic progress report (also fires for the very first query).
            if processing_idx % 10_000 == 0:
                print_message(f"#> Done with {processing_idx+1} questions!\t\t "
                              f"{len(triples) / 1000}k triples for {non_empty_qids} unique QIDs.")

        print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..")
        print_message(f"#> len(Triples) = {len(triples)}")

        if len(triples) > MAX_NUM_TRIPLES:
            triples = random.sample(triples, MAX_NUM_TRIPLES)

        ### Prepare the triples ###
        print_message("#> Shuffling the triples...")
        random.shuffle(triples)

        # Keep the raw list for save(); self.Triples remains the Examples
        # wrapper for callers that read that attribute.
        self._triples = triples
        self.Triples = Examples(data=triples)

        return triples

    def save(self, new_path):
        """Write the triples produced by ``create()`` to ``new_path`` with provenance.

        Args:
            new_path: Destination handed to ``Examples.save``.

        Raises:
            RuntimeError: If ``create()` has not been called on this instance.
        """
        triples = getattr(self, "_triples", None)
        if triples is None:
            raise RuntimeError("Triples.create() must be called before Triples.save().")

        provenance = Provenance()
        provenance.source = 'Triples::create'
        provenance.seed = self.seed
        provenance.positives = self.positives
        provenance.depth = self.depth
        provenance.ranking = self.ranking_provenance

        # Build the Examples from the raw list -- the same input create() hands
        # to Examples -- rather than nesting one Examples inside another.
        Examples(data=triples, provenance=provenance).save(new_path)
|