"""Compare term statistics (vocabulary overlap / divergence) across BEIR dataset Lucene indexes."""
import argparse | |
import numpy as np | |
from pyserini.index.lucene import IndexReader | |
def index2stats(index_path):
    """Read a Lucene index and collect per-term statistics.

    Returns a 3-tuple:
        (collection_freqs, doc_freqs, index_stats)
    where collection_freqs maps term -> total occurrences (cf),
    doc_freqs maps term -> number of documents containing it (df),
    and index_stats is pyserini's index-level stats dict.
    """
    reader = IndexReader(index_path)
    collection_freqs = {}
    doc_freqs = {}
    # terms() yields each term once; record both frequencies in a single pass.
    for term in reader.terms():
        collection_freqs[term.term] = int(term.cf)
        doc_freqs[term.term] = int(term.df)
    return collection_freqs, doc_freqs, reader.stats()
def count_total(d):
    """Return the sum of all values in the mapping *d*."""
    return sum(d.values())
def kl_divergence(d1, d2):
    """Return KL(d1 || d2) over the keys present in BOTH dicts.

    Keys of d1 missing from d2 are silently skipped (matching the
    original behavior), so this is an approximation when the supports
    differ.
    """
    total = 0.0
    for term, p in d1.items():
        q = d2.get(term)
        if q is not None:
            total += p * np.log(p / q)
    return total
def js_divergence(d1, d2):
    """Return the Jensen-Shannon divergence between two frequency dicts.

    Builds the 50/50 mixture distribution over the union of keys, then
    averages the two KL divergences against it.
    """
    mixture = {term: 0.5 * p for term, p in d1.items()}
    for term, p in d2.items():
        mixture[term] = mixture.get(term, 0.0) + 0.5 * p
    return 0.5 * (kl_divergence(d1, mixture) + kl_divergence(d2, mixture))
def jaccard(d1, d2):
    """Return the Jaccard similarity of the two dicts' key sets."""
    keys1 = set(d1)
    keys2 = set(d2)
    return len(keys1 & keys2) / len(keys1 | keys2)
def weighted_jaccard(d1, d2):
    """Return the weighted Jaccard similarity of two term->weight dicts.

    Defined as sum_t min(d1[t], d2[t]) / sum_t max(d1[t], d2[t]) over the
    union of terms (a missing term contributes 0 to the min and its
    present weight to the max).

    Fix: the original raised ZeroDivisionError when both dicts were
    empty; two identical (empty) vocabularies now return 1.0.
    """
    term_union = set(d1).union(set(d2))
    min_sum = max_sum = 0
    for t in term_union:
        if t not in d1:
            max_sum += d2[t]        # only in d2: min contribution is 0
        elif t not in d2:
            max_sum += d1[t]        # only in d1: min contribution is 0
        else:
            min_sum += min(d1[t], d2[t])
            max_sum += max(d1[t], d2[t])
    if max_sum == 0:
        # Both vocabularies empty (or all-zero weights): treat as identical.
        return 1.0
    return float(min_sum) / float(max_sum)
def cf2freq(d):
    """Normalize raw counts into relative frequencies that sum to 1.

    An empty dict maps to an empty dict (no division is performed).
    """
    total = float(sum(d.values()))
    return {term: float(count) / total for term, count in d.items()}
def df2idf(d, n):
    """Return the raw inverse document frequency n/df for every term.

    NOTE(review): despite the name, this is the plain ratio n/df with no
    logarithm — callers in this file pass n=1, producing 1/df.

    Fix: removed the dead local `total = n`, which was assigned and never
    read.
    """
    return {term: float(n) / float(df) for term, df in d.items()}
def filter_freq_dict(freq_d, threshold=0.0001):
    """Return a copy of *freq_d* keeping only entries strictly above *threshold*."""
    return {term: freq for term, freq in freq_d.items() if freq > threshold}
def print_results(datasets, results, save_file):
    """Write the pairwise metric matrix as a TSV table to *save_file*.

    The first row is a tab-prefixed header of dataset names; each
    following row is a dataset name followed by results[d1][d2] formatted
    to 4 decimal places, in the order given by *datasets*.

    Fix: use a `with` block so the file handle is closed even if a lookup
    or formatting error is raised mid-write.
    """
    with open(save_file, 'w') as f:
        f.write("\t{}\n".format("\t".join(datasets)))
        for d1 in datasets:
            f.write(d1)
            for d2 in datasets:
                f.write("\t{:.4f}".format(results[d1][d2]))
            f.write("\n")
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--index_path', type=str, help='path to indexes of all the beir dataset', required=True)
    parser.add_argument('--index_name_format', type=str, help='define your own index dir path name', default="/lucene-index-beir-{}")
    parser.add_argument('--compare_metric', type=str, help='the metric for comparing two vocab, choose from: jaccard, weight_jaccard, df_filter, tf_filter, kl_divergence, js_divergence', default="weight_jaccard")
    parser.add_argument('--compare_threshold', type=float, help='when choosing df_filter, or tf_filter, you can choolse the threshold', default=0.0001)
    parser.add_argument('--output_path', type=str, help='path to save the stat results', required=True)
    args = parser.parse_args()

    beir_datasets = ['trec-covid', 'bioasq', 'nfcorpus', 'nq', 'hotpotqa', 'climate-fever', 'fever', 'dbpedia-entity', 'fiqa', 'signal1m', 'trec-news', 'robust04', 'arguana', 'webis-touche2020', 'quora', 'cqadupstack', 'scidocs', 'scifact', 'msmarco']

    # BUG FIX: the original `cfs = dfs = stats = {}` bound all three names to a
    # SINGLE dict, so `dfs[d] = df` overwrote `cfs[d] = cf` for every dataset.
    # It also wrote `stat[d] = stat` (into pyserini's per-index stats dict)
    # instead of `stats[d] = stat`.
    cfs = {}    # dataset -> {term: collection frequency (int)}
    dfs = {}    # dataset -> {term: document frequency (int)}
    stats = {}  # dataset -> index-level stats from pyserini
    for d in beir_datasets:
        cf, df, stat = index2stats(args.index_path + args.index_name_format.format(d))
        cfs[d] = cf
        dfs[d] = df
        stats[d] = stat

    results = {}
    for d1 in beir_datasets:
        metric_d1 = {}
        for d2 in beir_datasets:
            if d1 == d2:
                # A dataset vs itself: maximal similarity / zero divergence.
                if args.compare_metric in ["jaccard", "weight_jaccard", "df_filter", "tf_filter"]:
                    metric_d1[d2] = 1
                elif args.compare_metric in ["kl_divergence", "js_divergence"]:
                    metric_d1[d2] = 0
            else:
                if args.compare_metric == "jaccard":
                    metric_d1[d2] = jaccard(cfs[d1], cfs[d2])
                elif args.compare_metric == "weight_jaccard":
                    new_d1 = filter_freq_dict(cf2freq(cfs[d1]))
                    new_d2 = filter_freq_dict(cf2freq(cfs[d2]))
                    metric_d1[d2] = weighted_jaccard(new_d1, new_d2)
                elif args.compare_metric == "df_filter":
                    # BUG FIX: --compare_threshold was parsed but never used;
                    # its help text says it applies to df_filter / tf_filter.
                    new_d1 = filter_freq_dict(cf2freq(cfs[d1]), args.compare_threshold)
                    new_d2 = filter_freq_dict(cf2freq(cfs[d2]), args.compare_threshold)
                    metric_d1[d2] = jaccard(new_d1, new_d2)
                elif args.compare_metric == "tf_filter":
                    new_d1 = filter_freq_dict(df2idf(dfs[d1], 1), args.compare_threshold)
                    new_d2 = filter_freq_dict(df2idf(dfs[d2], 1), args.compare_threshold)
                    metric_d1[d2] = jaccard(new_d1, new_d2)
                elif args.compare_metric == "kl_divergence":
                    new_d1 = filter_freq_dict(cf2freq(cfs[d1]))
                    new_d2 = filter_freq_dict(cf2freq(cfs[d2]))
                    metric_d1[d2] = kl_divergence(new_d1, new_d2)
                elif args.compare_metric == "js_divergence":
                    new_d1 = filter_freq_dict(cf2freq(cfs[d1]))
                    new_d2 = filter_freq_dict(cf2freq(cfs[d2]))
                    metric_d1[d2] = js_divergence(new_d1, new_d2)
                else:
                    raise NotImplementedError
        results[d1] = metric_d1
    print_results(beir_datasets, results, args.output_path)