File size: 467 Bytes
b94cb82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright (c) Meta Platforms, Inc. and affiliates

import json
import numpy as np
import os
import random

from tqdm import tqdm


def balance_sampling(matched_entry_ids, entry_prob):
    """Randomly decide whether a sample is kept, based on its matched entries.

    For each matched entry, in order, one value is drawn from
    ``random.random()`` and compared against ``entry_prob[entry_id]``.
    Evaluation stops at the first draw that falls below its entry's
    probability, so later entries consume no random draws.

    Args:
        matched_entry_ids: iterable of entry ids; each one is used to index
            ``entry_prob``.
        entry_prob: container indexable by entry id, holding the value each
            draw is compared against.

    Returns:
        True if at least one draw is below its entry's probability, otherwise
        False (including when ``matched_entry_ids`` is empty).
    """
    # this can be placed in a pipeline or on-the-fly in a data loader.
    # see a numpy impl. at metaclip.indexing.balance_sampling.balance_sampling
    return any(
        random.random() < entry_prob[eid]
        for eid in matched_entry_ids
    )