Spaces:
Sleeping
Sleeping
import random | |
import pdb | |
from collections import defaultdict | |
import numpy | |
import numpy as np | |
import math | |
class PosRateControllerLength():
    """Sample a target length, then delegate to a per-length PosRateController.

    One ``PosRateController`` is kept for every length in ``0..max_length``.
    On each call a length is drawn from a normal distribution, clamped, and
    the controller stored at that index decides how many positives and
    negatives to keep.
    """

    def __init__(self, max_length=9, center_length=8, length_std=5.0):
        """Create the per-length controllers.

        Args:
            max_length: largest length that can be sampled; also the last
                index of ``leng_to_controller``.
            center_length: mean of the normal distribution the length is
                drawn from.
            length_std: standard deviation of that distribution (the
                previously hard-coded 5.0).
        """
        self.leng_to_controller = [PosRateController() for _ in range(max_length + 1)]
        self.max_length = max_length
        self.center_length = center_length
        self.length_std = length_std
        self.pos_rates = []
        self.lengths = []

    def __call__(self, pos_num, neg_num):
        """Return the ``(pos_num, neg_num)`` chosen by the sampled-length controller.

        Args:
            pos_num: number of positives available.
            neg_num: number of negatives available.
        """
        # First sample the query length, then clamp it to [1, max_length].
        sampled = numpy.random.normal(self.center_length, self.length_std)
        clamped = max(1, min(self.max_length, sampled))
        # int() keeps the value usable as a list index even if round() hands
        # back a NumPy scalar instead of a Python int.
        length = int(round(clamped))
        # Never ask for more items than are available.
        length = min(pos_num + neg_num, length)
        pos_num, neg_num = self.leng_to_controller[length](
            pos_num, neg_num, desired_length=length
        )
        return pos_num, neg_num

    def update_true_pos_rate(self, pos_num, total_num):
        """Record the composition of a sample that was actually produced.

        Args:
            pos_num: number of positives in the sample.
            total_num: total number of items in the sample; 0 is ignored.
        """
        if total_num == 0:
            return
        self.pos_rates.append(pos_num / total_num)
        self.lengths.append(total_num)
        # The cap only selects which controller receives the update. The real
        # total is forwarded so the rate it computes (pos_num / total_num) is
        # not inflated when total_num exceeds max_length.
        bucket = int(min(total_num, self.max_length))
        self.leng_to_controller[bucket].update_true_pos_rate(pos_num, total_num)
class PosRateController():
    """Steer the positive rate of samples so that rate bins fill up evenly.

    The interval [0, 1] is divided into ``bin_num + 1`` evenly spaced target
    rates. Each call picks one of the currently least-filled bins at random
    and trims the available positives/negatives toward that bin's rate.
    Bin counts only change through ``update_true_pos_rate``.
    """

    def __init__(self, bin_num=10, adhoc_bin_weights=None, control_length=-1):
        """Set up the bins.

        Args:
            bin_num: number of intervals; there are ``bin_num + 1`` bin rates.
            adhoc_bin_weights: optional mapping ``bin index -> multiplier``
                applied to that bin's count when bins are compared; missing
                indices use 1.0.
            control_length: accepted for backward compatibility; not used.
        """
        self.bins = [1.0 / bin_num * i for i in range(bin_num + 1)]
        self.bin_counter = [0 for _ in range(bin_num + 1)]
        # A fresh dict per instance; a `{}` default would be shared by all.
        self.adhoc_bin_weights = {} if adhoc_bin_weights is None else adhoc_bin_weights
        self.slack = 20  # bins within this margin of the emptiest bin are eligible
        self.pos_rates = []
        self.lengths = []

    def _find_closest_bin(self, pos_rate, valid_bins):
        """Return the index from ``valid_bins`` whose bin rate is nearest ``pos_rate``.

        A strictly closer bin always wins; on an exact tie the later
        candidate replaces the current one with probability 0.5.
        """
        bin_index = valid_bins[0]
        min_diff = abs(pos_rate - self.bins[bin_index])
        for candidate in valid_bins[1:]:
            diff = abs(pos_rate - self.bins[candidate])
            if diff < min_diff:
                bin_index, min_diff = candidate, diff
            elif diff == min_diff and random.random() > 0.5:
                bin_index = candidate
        return bin_index

    def __call__(self, pos_num, neg_num, desired_length=-1):
        """Choose how many positives and negatives to keep.

        Args:
            pos_num: number of positives available.
            neg_num: number of negatives available.
            desired_length: when > 0, the sample is sized toward this length;
                otherwise only the rate is adjusted.

        Returns:
            ``(pos_num, neg_num)`` after trimming. The result is only
            ``(0, 0)`` when both inputs were 0.
        """
        if pos_num == 0 and neg_num == 0:
            return 0, 0
        if pos_num == 1 and neg_num == 0:
            return 1, 0

        had_positive = pos_num > 0
        pos_now = pos_num / (pos_num + neg_num)

        # Bins whose weighted count is within `slack` of the emptiest bin are
        # the ones this example may be sent to.
        weighted = [
            count * self.adhoc_bin_weights.get(i, 1.0)
            for i, count in enumerate(self.bin_counter)
        ]
        threshold = min(weighted) + self.slack
        valid_bins = [i for i, load in enumerate(weighted) if load <= threshold]
        target = self.bins[random.choice(valid_bins)]

        if desired_length > 0:
            # Control to the desired length at the target rate.
            desired_pos = round(desired_length * target)
            pos_num = min(pos_num, desired_pos)
            if target == 0:
                neg_num = min(neg_num, desired_length)
            else:
                neg_num = min(neg_num, round(pos_num / target * (1 - target)))
        elif pos_now < target:
            # Too few positives for the target rate: drop negatives.
            neg_num = round(pos_num / target - pos_num)
        elif pos_now > target:
            # Too many positives for the target rate: drop positives.
            pos_num = round(neg_num * target / (1 - target))

        # Never return an empty sample. Keep one item of a kind the caller
        # actually had, rather than inventing a positive that does not exist.
        if pos_num == 0 and neg_num == 0:
            if had_positive:
                pos_num = 1
            else:
                neg_num = 1
        return pos_num, neg_num

    def update_true_pos_rate(self, pos_num, total_num):
        """Count a produced sample in the bin nearest its positive rate.

        Args:
            pos_num: number of positives in the sample.
            total_num: total number of items in the sample; 0 is ignored.
        """
        if total_num == 0:  # ignore
            return
        pos_rate = pos_num / total_num
        bin_index = self._find_closest_bin(pos_rate, list(range(len(self.bins))))
        self.bin_counter[bin_index] += 1
        self.pos_rates.append(pos_rate)
        self.lengths.append(total_num)

    def report(self,):
        """Print the mean recorded length and the per-bin counts."""
        print(np.mean(self.lengths), self.bin_counter)
from scipy.stats import norm | |
class PosRateControllerV2():
    """Balance samples over ``(length, positive count)`` bins.

    Every pair ``(length, pos)`` with ``1 <= length <= max_length`` and
    ``0 <= pos <= length`` is a bin. Each length carries a weight taken from
    a normal distribution, and a bin's count is divided by the weight of its
    length when bins are compared.
    """

    def __init__(self, max_length, center_length, scale=4.0):
        """Create the bins and the per-length weights.

        Args:
            max_length: largest sample length that gets a bin.
            center_length: mean of the normal distribution over lengths.
            scale: standard deviation of that distribution.
        """
        self.max_length = max_length
        self.center_length = center_length
        self.bins = defaultdict(int)
        for length in range(1, max_length + 1):
            for pos in range(length + 1):
                self.bins[(length, pos)] = 0

        # Weight of a length = probability mass of the normal distribution on
        # the unit interval around it, renormalized to sum to 1.
        dis = norm(loc=center_length, scale=scale)
        self.weights = {
            length: dis.cdf(length + 0.5) - dis.cdf(length - 0.5)
            for length in range(1, max_length + 1)
        }
        total_weight = sum(self.weights.values())
        for length in self.weights:
            self.weights[length] /= total_weight

        self.weights_pos_rate = {}
        self.pos_rates = []
        self.slack = 10

    def __call__(self, pos_num, neg_num, max_cap_num=-1):
        """Pick the bin to fill and return its ``(positives, negatives)``.

        Args:
            pos_num: number of positives available.
            neg_num: number of negatives available.
            max_cap_num: optional upper bound on the sample length; -1
                disables it.

        Returns:
            ``(pos, neg)`` of the chosen bin, or the inputs unchanged when no
            bin fits (the inputs are also printed in that case).
        """
        total = pos_num + neg_num
        # A bin is usable when it needs no more positives/negatives than are
        # available. Keys recorded by update_true_pos_rate that fall outside
        # the weighted range (length > max_length, or a positive count outside
        # 0..length) are skipped instead of raising KeyError below.
        valid_keys = [
            key
            for key in self.bins
            if key[0] in self.weights
            and 0 <= key[1] <= key[0]
            and key[0] <= total
            and key[1] <= pos_num
            and key[0] - key[1] <= neg_num
            and (max_cap_num == -1 or key[0] <= max_cap_num)
        ]
        if not valid_keys:
            print(pos_num, neg_num)
            return pos_num, neg_num

        # Keep only bins whose weighted count is within `slack` of the minimum.
        load = {key: self.bins[key] / self.weights[key[0]] for key in valid_keys}
        threshold = min(load.values()) + self.slack
        candidates = [key for key in valid_keys if load[key] <= threshold]

        # Among those, take the bin whose positive count is nearest pos_num;
        # min() keeps the first such bin on ties. No fixed sentinel is used,
        # so a very large pos_num still resolves to a bin.
        closest = min(candidates, key=lambda key: abs(key[1] - pos_num))
        return closest[1], closest[0] - closest[1]

    def update_true_pos_rate(self, pos_num, total_num):
        """Count a produced sample in its ``(total_num, pos_num)`` bin.

        Args:
            pos_num: number of positives in the sample.
            total_num: total number of items in the sample; 0 is ignored.
        """
        if total_num == 0:
            return
        self.bins[(total_num, pos_num)] += 1
        self.pos_rates.append(pos_num / total_num)

    def report(self):
        """Print all bin counts, but only when the sample count is a multiple of 1000."""
        if len(self.pos_rates) % 1000 != 0:
            return
        for length in range(1, self.max_length + 1):
            print("length", length, sum([self.bins[(length, pos)] for pos in range(0, length + 1)]))
            for pos in range(0, length + 1):
                print(" pos", pos, " ", self.bins[(length, pos)])
        print("\n\n")
''' | |
import matplotlib.pyplot as plt | |
# draw a histogram
plt.hist(data, bins = 10) | |
plt.show() | |
''' | |