File size: 8,321 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import random
import pdb
from collections import defaultdict
import numpy
import numpy as np
import math
class PosRateControllerLength():
    """Sample a target length, then delegate positive-rate control per length.

    One ``PosRateController`` is kept for every length ``0..max_length`` so the
    positive-rate histogram is balanced separately for each length.
    """

    def __init__(self, max_length=9, center_length=8, scale=5.0):
        """Create one per-length controller for lengths ``0..max_length``.

        Args:
            max_length: largest length that can be sampled / tracked.
            center_length: mean of the normal distribution lengths are drawn from.
            scale: standard deviation of that distribution (previously a
                hard-coded 5.0; the default keeps the old behaviour).
        """
        self.leng_to_controller = [PosRateController() for _ in range(max_length + 1)]
        self.max_length = max_length
        self.center_length = center_length
        self.scale = scale
        self.pos_rates = []
        self.lengths = []

    def __call__(self, pos_num, neg_num):
        """Trim ``(pos_num, neg_num)`` to a sampled length and positive rate.

        Returns the ``(pos_num, neg_num)`` pair produced by the controller of
        the sampled length.
        """
        # Draw a length from N(center_length, scale) and clip it to [1, max_length].
        length = numpy.random.normal(self.center_length, self.scale)
        length = max(1, min(self.max_length, length))
        # Explicit int: the value is used as a list index, and numpy scalars'
        # __round__ has not always returned a Python int.
        length = int(round(float(length)))
        # Never ask for more examples than are available.
        length = min(pos_num + neg_num, length)

        return self.leng_to_controller[length](pos_num, neg_num, desired_length=length)

    def update_true_pos_rate(self, pos_num, total_num):
        """Record an observed example with ``pos_num`` positives out of ``total_num``.

        Examples with ``total_num == 0`` are ignored.  Totals above
        ``max_length`` are credited to the ``max_length`` controller, with
        ``pos_num`` rescaled so that controller still sees the true positive
        rate (capping only the total used to produce rates above 1.0).
        """
        if total_num == 0:
            return
        self.pos_rates.append(pos_num / total_num)
        self.lengths.append(total_num)
        capped_total = int(min(total_num, self.max_length))
        if capped_total != total_num:
            pos_num = pos_num * capped_total / total_num
        self.leng_to_controller[capped_total].update_true_pos_rate(pos_num, capped_total)
    
class PosRateController():
    """Steer per-example positive rates towards a balanced bin histogram.

    The rate range [0, 1] is covered by ``bin_num + 1`` evenly spaced bin
    centres ``0, 1/bin_num, ..., 1``.  ``update_true_pos_rate`` records which
    bin each observed example's positive rate was closest to; ``__call__``
    picks a random target bin among the least-filled ones (within ``slack``)
    and trims ``pos_num`` / ``neg_num`` towards that bin's rate.
    """

    def __init__(self, bin_num=10, adhoc_bin_weights=None, control_length=-1, slack=20):
        """Create an empty controller.

        Args:
            bin_num: number of intervals; ``bin_num + 1`` bin centres are used.
            adhoc_bin_weights: optional ``{bin_index: weight}`` multipliers
                applied to a bin's counter when comparing fill levels (missing
                bins use 1.0).  Defaults to a fresh dict per instance; the old
                ``{}`` default was a single dict shared by every instance.
            control_length: accepted for backward compatibility; unused.
            slack: a bin is a valid target when its weighted counter is at
                most ``slack`` above the smallest weighted counter.
        """
        self.bins = [1.0 / bin_num * i for i in range(bin_num + 1)]
        self.bin_counter = [0] * (bin_num + 1)

        self.adhoc_bin_weights = {} if adhoc_bin_weights is None else adhoc_bin_weights
        self.slack = slack
        self.pos_rates = []
        self.lengths = []

    def _find_closest_bin(self, pos_rate, valid_bins):
        """Return the entry of ``valid_bins`` whose bin centre is nearest ``pos_rate``.

        Exact ties are broken uniformly at random; ``random`` is consulted only
        when there is an actual tie.
        """
        diffs = [abs(pos_rate - self.bins[i]) for i in valid_bins]
        min_diff = min(diffs)
        closest = [i for i, diff in zip(valid_bins, diffs) if diff == min_diff]
        if len(closest) == 1:
            return closest[0]
        return random.choice(closest)

    def __call__(self, pos_num, neg_num, desired_length=-1):
        """Trim ``(pos_num, neg_num)`` towards a randomly chosen under-filled bin.

        Args:
            pos_num, neg_num: available positive / negative counts.
            desired_length: when > 0, positives are capped at
                ``round(desired_length * rate)`` and negatives at the count
                matching ``rate`` (or at ``desired_length`` when ``rate`` is 0).
                Otherwise only the rate is controlled.

        Returns:
            ``(pos, neg)``.  For non-negative integer inputs the result never
            exceeds the inputs and is ``(0, 0)`` only if both inputs are 0.
        """
        if pos_num == 0 and neg_num == 0:
            return 0, 0
        if pos_num == 1 and neg_num == 0:
            return 1, 0

        orig_pos = pos_num
        pos_now = pos_num / (pos_num + neg_num)

        # Candidate target bins: those whose weighted fill level is within
        # `slack` of the emptiest bin.
        weighted = [count * self.adhoc_bin_weights.get(i, 1.0)
                    for i, count in enumerate(self.bin_counter)]
        min_weighted = min(weighted)
        valid_bins = [i for i, w in enumerate(weighted) if w <= min_weighted + self.slack]
        target = self.bins[random.choice(valid_bins)]

        if desired_length > 0:
            # Control to the desired length as well as the rate.
            desired_pos = round(desired_length * target)
            pos_num = min(pos_num, desired_pos)
            if target == 0:
                neg_num = min(neg_num, desired_length)
            else:
                neg_num = min(neg_num, round(pos_num / target * (1 - target)))
        elif pos_now < target:
            # Rate too low: drop negatives.
            neg_num = round(pos_num / target - pos_num)
        elif pos_now > target:
            # Rate too high: drop positives (target < pos_now <= 1, so 1 - target > 0).
            pos_num = round(neg_num * target / (1 - target))

        # Never return an empty selection, and never invent a positive that
        # was not available: fall back to a single example of an available class.
        if pos_num == 0 and neg_num == 0:
            return (1, 0) if orig_pos > 0 else (0, 1)

        return pos_num, neg_num

    def update_true_pos_rate(self, pos_num, total_num):
        """Record an observed example and bump the counter of its nearest bin.

        Examples with ``total_num == 0`` are ignored.
        """
        if total_num == 0:
            return
        pos_rate = pos_num / total_num
        bin_index = self._find_closest_bin(pos_rate, list(range(len(self.bins))))
        self.bin_counter[bin_index] += 1
        self.pos_rates.append(pos_rate)
        self.lengths.append(total_num)

    def report(self):
        """Print the mean recorded length (nan if none) and the bin counters."""
        mean_length = np.mean(self.lengths) if self.lengths else float("nan")
        print(mean_length, self.bin_counter)
from scipy.stats import norm
class PosRateControllerV2():
    """Balance sampled examples over a grid of ``(length, pos_count)`` cells.

    Every cell ``(length, pos)`` with ``1 <= length <= max_length`` and
    ``0 <= pos <= length`` has a counter.  Each length gets a weight equal to
    the (renormalised) mass of ``N(center_length, scale)`` on
    ``[length - 0.5, length + 0.5]``; ``__call__`` prefers cells whose
    ``count / weight`` is smallest.
    """

    def __init__(self, max_length, center_length, scale=4.0):
        self.max_length = max_length
        self.center_length = center_length
        self.bins = defaultdict(int)
        for length in range(1, max_length + 1):
            for pos in range(0, length + 1):
                self.bins[(length, pos)] = 0

        # Probability mass of N(center_length, scale) around each integer length.
        dis = norm(loc=center_length, scale=scale)
        self.weights = {}
        for i in range(1, max_length + 1):
            self.weights[i] = dis.cdf(i + 0.5) - dis.cdf(i - 0.5)
        # Renormalise over 1..max_length.  If the total is not positive (all
        # mass underflowed, or the distribution is invalid) use uniform
        # weights instead of dividing by zero / propagating nan.
        total_weight = sum(self.weights.values())
        if total_weight > 0:
            for i in self.weights:
                self.weights[i] /= total_weight
        elif self.weights:
            for i in self.weights:
                self.weights[i] = 1.0 / len(self.weights)

        self.weights_pos_rate = {}  # kept for compatibility; not used here
        self.pos_rates = []
        self.slack = 10

    def _weighted_count(self, key):
        """Counter of cell ``key`` divided by its length weight (inf for weight 0)."""
        weight = self.weights[key[0]]
        if weight <= 0:
            return float("inf")
        return self.bins[key] / weight

    def __call__(self, pos_num, neg_num, max_cap_num=-1):
        """Choose how many positives / negatives to keep.

        Args:
            pos_num, neg_num: available positive / negative examples.
            max_cap_num: if not -1, an upper bound on the chosen length.

        Returns:
            ``(pos, neg)`` of the chosen cell, or the inputs unchanged when no
            cell fits the available counts (e.g. both are 0).
        """
        # Cells reachable by only dropping examples, within the optional cap.
        valid_keys = [
            key for key in self.bins
            if key[0] in self.weights
            and key[0] <= pos_num + neg_num
            and key[1] <= pos_num
            and key[0] - key[1] <= neg_num
            and (max_cap_num == -1 or key[0] <= max_cap_num)
        ]
        if not valid_keys:
            return pos_num, neg_num

        # Keep the cells whose weighted count is within `slack` of the least
        # filled one (the minimum itself always survives, so this is non-empty).
        min_counter = min(self._weighted_count(key) for key in valid_keys)
        valid_keys = [key for key in valid_keys
                      if self._weighted_count(key) <= min_counter + self.slack]

        # Cell whose positive count is closest to pos_num; first one in grid
        # order on ties.  No distance cap, so large inputs are still controlled.
        closest_key = min(valid_keys, key=lambda key: abs(key[1] - pos_num))
        return closest_key[1], closest_key[0] - closest_key[1]

    def update_true_pos_rate(self, pos_num, total_num):
        """Record an observed example with ``pos_num`` positives of ``total_num``.

        Examples with ``total_num == 0`` are ignored.  Only cells on the
        controller's grid are counted; other combinations (e.g. ``total_num >
        max_length``) are not added, since a cell without a length weight made
        ``__call__`` raise ``KeyError``.  The positive rate is recorded either way.
        """
        if total_num == 0:
            return
        key = (total_num, pos_num)
        if key in self.bins:
            self.bins[key] += 1
        self.pos_rates.append(pos_num / total_num)

    def report(self):
        """Print per-cell counters when the number of recorded rates is a multiple of 1000 (including 0)."""
        if len(self.pos_rates) % 1000 != 0:
            return

        for i in range(1, self.max_length + 1):
            print("length", i, sum([self.bins[(i, j)] for j in range(0, i + 1)]))
            for j in range(0, i + 1):
                print("  pos", j, " ", self.bins[(i, j)])
        print("\n\n")
        

# NOTE(review): the triple-quoted string below is a no-op expression statement
# that only preserves old histogram-plotting code; it never executes, and
# `data` is not defined anywhere in the visible code.
'''
import matplotlib.pyplot as plt
# drop a histogram
plt.hist(data, bins = 10)
plt.show()
'''