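# NOTE: the text below is page-header residue captured together with the file;
# it is kept only as a comment because it is not Python code.
# zdou0830's picture / desco / 749745d / raw / history blame / 8.32 kB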
import random
import pdb
from collections import defaultdict
import numpy
import numpy as np
import math
class PosRateControllerLength():
    """Sample a target query length, then delegate to a per-length controller.

    One ``PosRateController`` is kept for every length in
    ``0 .. max_length``.  Calling the object draws a length from a normal
    distribution, clamps it, and asks the controller for that length to
    decide how many positives / negatives to keep.
    """

    def __init__(self, max_length=9, center_length=8, length_std=5.0):
        """
        Args:
            max_length: largest length that can be sampled; also the index of
                the last per-length controller.
            center_length: mean of the normal distribution lengths are drawn
                from.
            length_std: standard deviation of that normal distribution
                (defaults to 5.0, the value that used to be hard-coded).
        """
        # Indexed directly by length; index 0 is used when a query has no
        # examples at all (pos_num + neg_num == 0).
        self.leng_to_controller = [PosRateController() for _ in range(max_length + 1)]
        self.max_length = max_length
        self.center_length = center_length
        self.length_std = length_std
        self.pos_rates = []  # true positive rate of every reported example
        self.lengths = []    # true total length of every reported example

    def __call__(self, pos_num, neg_num):
        """Return the ``(pos_num, neg_num)`` to keep for one example.

        Args:
            pos_num: number of available positives.
            neg_num: number of available negatives.
        """
        # First sample the query length.
        length = numpy.random.normal(self.center_length, self.length_std)
        # Clamp to [1, max_length], round to an integer index, and never ask
        # for more items than are available.
        length = max(1, min(self.max_length, length))
        length = int(round(length))
        length = min(pos_num + neg_num, length)
        return self.leng_to_controller[length](pos_num, neg_num, desired_length=length)

    def update_true_pos_rate(self, pos_num, total_num):
        """Record the composition that was actually produced for one example.

        Args:
            pos_num: number of positives actually used.
            total_num: total number of items actually used; 0 is ignored.
        """
        if total_num == 0:
            return
        self.pos_rates.append(pos_num / total_num)
        self.lengths.append(total_num)
        # Only the controller *index* is capped.  The true total is forwarded
        # so the positive rate recorded by the per-length controller is
        # pos_num / total_num rather than pos_num / capped_total, which would
        # be inflated (and could exceed 1) whenever total_num > max_length.
        bucket = int(min(total_num, self.max_length))
        self.leng_to_controller[bucket].update_true_pos_rate(pos_num, total_num)
class PosRateController():
    """Balance the positive rate of generated examples across histogram bins.

    The interval [0, 1] is split into ``bin_num + 1`` evenly spaced target
    rates.  ``update_true_pos_rate`` counts how many examples landed near each
    rate; ``__call__`` picks one of the least-filled bins at random and trims
    the positive / negative counts towards that bin's rate.
    """

    def __init__(self, bin_num=10, adhoc_bin_weights=None, control_length=-1):
        """
        Args:
            bin_num: number of intervals; ``bin_num + 1`` target rates are
                created (0, 1/bin_num, ..., 1).
            adhoc_bin_weights: optional mapping ``bin index -> multiplier``
                applied to that bin's counter when looking for under-filled
                bins (missing bins use 1.0).  ``None`` means no weights.
            control_length: accepted for interface compatibility; it is not
                used by this class.
        """
        self.bins = [1.0 / bin_num * i for i in range(bin_num + 1)]
        self.bin_counter = [0] * (bin_num + 1)
        # A fresh dict per instance: a ``{}`` default in the signature would
        # be shared (and mutated) across every controller.
        self.adhoc_bin_weights = {} if adhoc_bin_weights is None else adhoc_bin_weights
        self.slack = 20  # bins within this many (weighted) counts of the minimum stay eligible
        self.pos_rates = []
        self.lengths = []

    def _find_closest_bin(self, pos_rate, valid_bins):
        """Return the index from ``valid_bins`` whose rate is nearest ``pos_rate``.

        Exact ties are broken by a coin flip.
        """
        bin_index = valid_bins[0]
        min_diff = abs(pos_rate - self.bins[bin_index])
        for candidate in valid_bins[1:]:
            diff = abs(pos_rate - self.bins[candidate])
            if diff < min_diff:
                bin_index, min_diff = candidate, diff
            elif diff == min_diff and random.random() > 0.5:
                # ``elif``: a strict improvement must not also trigger the
                # tie-break (it would only burn a random draw).
                bin_index = candidate
        return bin_index

    def __call__(self, pos_num, neg_num, desired_length=-1):
        """Return the ``(pos_num, neg_num)`` to keep for one example.

        Args:
            pos_num: number of available positives.
            neg_num: number of available negatives.
            desired_length: when > 0, also steer the total towards this
                length; otherwise only the rate is controlled.

        Returns:
            A ``(pos_num, neg_num)`` tuple.  It is ``(0, 0)`` only when both
            inputs are 0.
        """
        if pos_num == 0 and neg_num == 0:
            return 0, 0
        if pos_num == 1 and neg_num == 0:
            return 1, 0

        available_pos = pos_num
        pos_now = pos_num / (pos_num + neg_num)

        # Bins whose weighted count is within ``slack`` of the least-filled
        # bin are the ones this example may be steered towards.
        weighted = [
            count * self.adhoc_bin_weights.get(i, 1.0)
            for i, count in enumerate(self.bin_counter)
        ]
        threshold = min(weighted) + self.slack
        valid_bins = [i for i, value in enumerate(weighted) if value <= threshold]
        target_rate = self.bins[random.choice(valid_bins)]

        if desired_length > 0:
            # Control towards the desired length at the target rate.
            desired_pos = round(desired_length * target_rate)
            pos_num = min(pos_num, desired_pos)
            if target_rate == 0:
                neg_num = min(neg_num, desired_length)
            else:
                neg_num = min(neg_num, round(pos_num / target_rate * (1 - target_rate)))
        elif pos_now < target_rate:
            # Too few positives for the target rate: drop some negatives.
            neg_num = round(pos_num / target_rate - pos_num)
        elif pos_now > target_rate:
            # Too many positives for the target rate: drop some positives.
            pos_num = round(neg_num * target_rate / (1 - target_rate))

        # Never return an empty selection.  Keep one positive when one is
        # available; otherwise keep one negative instead of inventing a
        # positive that does not exist.
        if pos_num == 0 and neg_num == 0:
            if available_pos > 0:
                pos_num = 1
            else:
                neg_num = 1
        return pos_num, neg_num

    def update_true_pos_rate(self, pos_num, total_num):
        """Count one produced example in the bin nearest its positive rate.

        Args:
            pos_num: number of positives actually used.
            total_num: total number of items actually used; 0 is ignored.
        """
        if total_num == 0:  # ignore
            return
        pos_rate = pos_num / total_num
        bin_index = self._find_closest_bin(pos_rate, list(range(len(self.bins))))
        self.bin_counter[bin_index] += 1
        self.pos_rates.append(pos_rate)
        self.lengths.append(total_num)

    def report(self,):
        """Print the mean recorded length and the per-bin counters."""
        print(np.mean(self.lengths), self.bin_counter)
from scipy.stats import norm
class PosRateControllerV2():
    """Balance examples over joint ``(total length, positive count)`` bins.

    Every pair ``(i, j)`` with ``1 <= i <= max_length`` and ``0 <= j <= i``
    has a counter.  Lengths are weighted by a discretised normal distribution
    centred on ``center_length``, so a bin's fill level is
    ``count / weight[length]``.
    """

    def __init__(self, max_length, center_length, scale=4.0):
        """
        Args:
            max_length: largest total length that has a bin.
            center_length: mean of the normal distribution used to weight
                lengths.
            scale: standard deviation of that normal distribution.
        """
        self.max_length = max_length
        self.center_length = center_length
        self.bins = defaultdict(int)
        for total in range(1, max_length + 1):
            for pos in range(total + 1):
                self.bins[(total, pos)] = 0

        # Weight of length i = probability mass of the normal distribution on
        # [i - 0.5, i + 0.5], renormalised to sum to 1 over 1..max_length.
        dis = norm(loc=center_length, scale=scale)
        raw = {
            i: dis.cdf(i + 0.5) - dis.cdf(i - 0.5)
            for i in range(1, max_length + 1)
        }
        total_weight = sum(raw.values())
        self.weights = {i: w / total_weight for i, w in raw.items()}

        self.weights_pos_rate = {}
        self.pos_rates = []
        self.slack = 10  # bins within this much weighted count of the minimum stay eligible

    def __call__(self, pos_num, neg_num, max_cap_num=-1):
        """Return the ``(pos_num, neg_num)`` to keep for one example.

        Args:
            pos_num: number of available positives.
            neg_num: number of available negatives.
            max_cap_num: optional upper bound on the total length; -1
                disables the bound.

        Returns:
            ``(positives, negatives)`` of the chosen bin, or the inputs
            unchanged when no bin is reachable.
        """
        # Bins reachable by only dropping examples.  Keys outside the
        # configured grid (they can be created by update_true_pos_rate on the
        # defaultdict) are skipped: they have no weight and would raise
        # KeyError below.
        valid_keys = [
            key for key in self.bins
            if key[0] in self.weights
            and 0 <= key[1] <= key[0]
            and key[0] <= pos_num + neg_num
            and key[1] <= pos_num
            and key[0] - key[1] <= neg_num
            and (max_cap_num == -1 or key[0] <= max_cap_num)
        ]
        if not valid_keys:
            print(pos_num, neg_num)
            return pos_num, neg_num

        # Keep only the bins that are (nearly) the least filled.
        fill = {key: self.bins[key] / self.weights[key[0]] for key in valid_keys}
        threshold = min(fill.values()) + self.slack
        candidates = [key for key in valid_keys if fill[key] <= threshold]
        if not candidates:
            return pos_num, neg_num

        # Among those, drop as few positives as possible.  ``min`` keeps the
        # first minimal key in iteration order and has no distance ceiling,
        # unlike the previous ``min_diff = 100`` sentinel, which left large
        # inputs uncontrolled.
        total, pos = min(candidates, key=lambda key: abs(key[1] - pos_num))
        return pos, total - pos

    def update_true_pos_rate(self, pos_num, total_num):
        """Count one produced example in its ``(total_num, pos_num)`` bin.

        Args:
            pos_num: number of positives actually used.
            total_num: total number of items actually used; 0 is ignored.
        """
        if total_num == 0:
            return
        self.bins[(total_num, pos_num)] += 1
        self.pos_rates.append(pos_num / total_num)

    def report(self):
        """Print all bin counters, but only on every 1000th recorded example."""
        if len(self.pos_rates) % 1000 != 0:
            return
        for i in range(1, self.max_length + 1):
            print("length", i, sum([self.bins[(i, j)] for j in range(0, i + 1)]))
            for j in range(0, i + 1):
                print(" pos", j, " ", self.bins[(i, j)])
        print("\n\n")
'''
import matplotlib.pyplot as plt
# draw a histogram
plt.hist(data, bins = 10)
plt.show()
'''