# File-viewer header (web-scrape residue): Spaces: Sleeping; File size: 8,321 Bytes; revision 749745d. The viewer's line-number gutter (1-209) is omitted.
import random
import pdb
from collections import defaultdict
import numpy
import numpy as np
import math
class PosRateControllerLength():
    """Dispatch positive/negative count control to one controller per length.

    Each call samples a target length from a normal distribution centred on
    ``center_length``, clamps it to ``[1, max_length]`` and to the number of
    available examples, and delegates to the ``PosRateController`` stored at
    that index.
    """

    def __init__(self, max_length=9, center_length=8, length_std=5.0):
        """Create ``max_length + 1`` per-length controllers (indices 0..max_length).

        Args:
            max_length: largest length that can be sampled / indexed.
            center_length: mean of the normal distribution used for sampling.
            length_std: standard deviation of that distribution (default 5.0).
        """
        self.leng_to_controller = [PosRateController() for _ in range(max_length + 1)]
        self.max_length = max_length
        self.center_length = center_length
        self.length_std = length_std
        self.pos_rates = []
        self.lengths = []

    def __call__(self, pos_num, neg_num):
        """Return controlled ``(pos_num, neg_num)`` for a sampled target length."""
        # first sample the query length
        length = numpy.random.normal(self.center_length, self.length_std)
        # cap to 1 and max_length
        length = max(1, min(self.max_length, length))
        # int() guarantees a plain Python int usable as a list index, whatever
        # scalar type round() yields for a NumPy float.
        length = int(round(length))
        # Never ask for more examples than are available; with no examples at
        # all this becomes 0 and the index-0 controller is used.
        length = min(pos_num + neg_num, length)
        return self.leng_to_controller[length](pos_num, neg_num, desired_length=length)

    def update_true_pos_rate(self, pos_num, total_num):
        """Record the realised counts and forward them to the matching controller.

        Calls with ``total_num == 0`` are ignored.
        """
        if total_num == 0:
            return
        self.pos_rates.append(pos_num / total_num)
        self.lengths.append(total_num)
        # Clamp only the controller index. The true total is forwarded so the
        # per-length controller computes pos_num / total_num on the real length
        # instead of an inflated rate on the clamped one.
        index = int(min(total_num, self.max_length))
        self.leng_to_controller[index].update_true_pos_rate(pos_num, total_num)
class PosRateController():
    """Balance the positive rate of examples across evenly spaced rate bins.

    ``bins`` holds ``bin_num + 1`` target rates from 0.0 to 1.0 and
    ``bin_counter`` holds how many realised examples fell closest to each.
    Calling the controller picks one of the least-used bins at random and trims
    the positive/negative counts towards that bin's rate.
    """

    def __init__(self, bin_num=10, adhoc_bin_weights=None, control_length=-1):
        """Set up the rate bins.

        Args:
            bin_num: number of intervals; ``bin_num + 1`` target rates are created.
            adhoc_bin_weights: optional ``{bin_index: weight}`` mapping; a bin's
                counter is multiplied by its weight (default 1.0) when bins are
                compared. ``None`` means no ad-hoc weights.
            control_length: accepted for interface compatibility; not used here.
        """
        self.bins = [1.0 / bin_num * i for i in range(bin_num + 1)]
        self.bin_counter = [0 for _ in range(bin_num + 1)]
        # A fresh dict per instance: a ``{}`` default would be shared by every
        # controller created without explicit weights.
        self.adhoc_bin_weights = {} if adhoc_bin_weights is None else adhoc_bin_weights
        self.slack = 20  # we can allow some slack for the pos rate control
        self.pos_rates = []
        self.lengths = []

    def _find_closest_bin(self, pos_rate, valid_bins):
        """Return the index in ``valid_bins`` whose rate is closest to ``pos_rate``.

        Exact ties are broken by a fair coin flip.
        """
        bin_index = valid_bins[0]
        min_diff = abs(pos_rate - self.bins[bin_index])
        for candidate in valid_bins[1:]:
            diff = abs(pos_rate - self.bins[candidate])
            if diff < min_diff:
                bin_index = candidate
                min_diff = diff
            elif diff == min_diff and random.random() > 0.5:
                # Only a genuine tie consumes a random draw.
                bin_index = candidate
        return bin_index

    def __call__(self, pos_num, neg_num, desired_length=-1):
        """Return controlled ``(pos_num, neg_num)``, never exceeding the inputs.

        Args:
            pos_num: number of positives available.
            neg_num: number of negatives available.
            desired_length: when > 0, also steer the counts towards this total;
                otherwise only the rate is controlled.
        """
        if pos_num == 0 and neg_num == 0:
            return 0, 0
        if pos_num == 1 and neg_num == 0:
            return 1, 0

        has_positive = pos_num > 0
        pos_now = pos_num / (pos_num + neg_num)

        weighted = [
            count * self.adhoc_bin_weights.get(i, 1.0)
            for i, count in enumerate(self.bin_counter)
        ]
        threshold = min(weighted) + self.slack
        # these are the bins this example could go to
        valid_bins = [i for i, value in enumerate(weighted) if value <= threshold]
        target_rate = self.bins[random.choice(valid_bins)]

        if desired_length > 0:
            # control to the desired length
            desired_pos = round(desired_length * target_rate)
            pos_num = min(pos_num, desired_pos)
            if target_rate == 0:
                neg_num = min(neg_num, desired_length)
            else:
                neg_num = min(neg_num, round(pos_num / target_rate * (1 - target_rate)))
        elif pos_now < target_rate:
            # the rate is too low: drop some negative examples
            neg_num = round(pos_num / target_rate - pos_num)
        elif pos_now > target_rate:
            # the rate is too high: drop some positive examples
            pos_num = round(neg_num * target_rate / (1 - target_rate))

        # make sure we don't have all 0s -- but only keep an example of a class
        # that was actually available, never a positive that does not exist.
        if pos_num == 0 and neg_num == 0:
            if has_positive:
                pos_num = 1
            else:
                neg_num = 1
        return pos_num, neg_num

    def update_true_pos_rate(self, pos_num, total_num):
        """Count a realised example in its closest rate bin; ignore empty ones."""
        if total_num == 0:  # ignore
            return
        pos_rate = pos_num / total_num
        bin_index = self._find_closest_bin(pos_rate, list(range(len(self.bins))))
        self.bin_counter[bin_index] += 1
        self.pos_rates.append(pos_rate)
        self.lengths.append(total_num)

    def report(self,):
        """Print the mean recorded length and the per-bin counters."""
        print(np.mean(self.lengths), self.bin_counter)
from scipy.stats import norm
class PosRateControllerV2():
    """Balance examples over a grid of ``(total_length, positive_count)`` bins.

    Every length ``1..max_length`` has one bin per possible positive count.
    Lengths are weighted by the probability mass a normal distribution centred
    on ``center_length`` assigns to them, so bin counts are compared relative
    to how often each length is wanted.
    """

    def __init__(self, max_length, center_length, scale=4.0):
        """Build the bin grid and the normalised per-length weights.

        Args:
            max_length: largest total length that has bins.
            center_length: mean of the normal distribution over lengths.
            scale: standard deviation of that distribution.
        """
        self.max_length = max_length
        self.center_length = center_length
        self.bins = defaultdict(int)
        for length in range(1, max_length + 1):
            for positives in range(0, length + 1):
                self.bins[(length, positives)] = 0
        # calculate the weights according to a normal distribution centered on center_length
        dis = norm(loc=center_length, scale=scale)
        self.weights = {
            length: dis.cdf(length + 0.5) - dis.cdf(length - 0.5)
            for length in range(1, max_length + 1)
        }
        # renormalize the weights
        total_weight = sum(self.weights.values())
        for length in self.weights:
            self.weights[length] /= total_weight
        self.weights_pos_rate = {}
        self.pos_rates = []
        self.slack = 10

    def _cost(self, key):
        """Return the bin count scaled by its length weight (inf if weight is 0)."""
        weight = self.weights[key[0]]
        # A weight that underflowed to zero would otherwise yield nan/inf.
        return self.bins[key] / weight if weight > 0 else float("inf")

    def __call__(self, pos_num, neg_num, max_cap_num=-1):
        """Return ``(pos_num, neg_num)`` of the best under-filled reachable bin.

        A bin is reachable when it needs no more positives or negatives than
        are available and, if ``max_cap_num`` is not -1, its length does not
        exceed ``max_cap_num``. Among reachable bins whose weighted count is
        within ``slack`` of the minimum, the one whose positive count is
        closest to ``pos_num`` wins (first one found on ties). When no bin is
        reachable the inputs are printed and returned unchanged.
        """
        total = pos_num + neg_num
        valid_keys = [
            key for key in self.bins
            # update_true_pos_rate may have added keys outside the grid; they
            # have no weight and must not be candidates.
            if key[0] in self.weights
            and 0 <= key[1] <= key[0]
            and key[0] <= total
            and key[1] <= pos_num
            and key[0] - key[1] <= neg_num
            and (max_cap_num == -1 or key[0] <= max_cap_num)
        ]
        if not valid_keys:
            print(pos_num, neg_num)
            return pos_num, neg_num

        costs = {key: self._cost(key) for key in valid_keys}
        threshold = min(costs.values()) + self.slack
        candidates = [key for key in valid_keys if costs[key] <= threshold]  # rescreened
        # find the bin where we drop the minimal number of positives; min()
        # keeps the first minimum and needs no magic upper bound on the diff.
        closest_key = min(candidates, key=lambda key: abs(key[1] - pos_num))
        return closest_key[1], closest_key[0] - closest_key[1]

    def update_true_pos_rate(self, pos_num, total_num):
        """Count a realised example in its ``(total_num, pos_num)`` bin."""
        if total_num == 0:
            return
        self.bins[(total_num, pos_num)] += 1
        self.pos_rates.append(pos_num / total_num)

    def report(self):
        """Print per-length and per-bin counts once every 1000 recorded rates."""
        if len(self.pos_rates) % 1000 != 0:
            return
        for i in range(1, self.max_length + 1):
            print("length", i, sum([self.bins[(i, j)] for j in range(0, i + 1)]))
            for j in range(0, i + 1):
                print("    pos", j, "  ", self.bins[(i, j)])
        print("\n\n")
'''
import matplotlib.pyplot as plt
# draw a histogram
plt.hist(data, bins = 10)
plt.show()
'''
|