# NetsPresso_QA / scripts/beir/compare_domains.py
# Uploaded by geonmin-kim via huggingface_hub (commit d6585f5)
import argparse
import numpy as np
from pyserini.index.lucene import IndexReader
def index2stats(index_path):
    """Collect term statistics from a Lucene index.

    Returns a tuple (cf_dict, df_dict, stats): cf_dict maps each term to
    its collection frequency, df_dict maps each term to its document
    frequency, and stats is the reader's overall index statistics.
    """
    reader = IndexReader(index_path)
    cf_dict = {}
    df_dict = {}
    for term in reader.terms():
        cf_dict[term.term] = int(term.cf)
        df_dict[term.term] = int(term.df)
    return cf_dict, df_dict, reader.stats()
def count_total(d):
    """Return the sum of all values in the mapping ``d`` (0 if empty)."""
    return sum(d.values())
def kl_divergence(d1, d2):
    """Kullback-Leibler divergence of sparse distribution d1 from d2.

    Sums p * log(p / q) over keys present in BOTH dicts; terms missing
    from ``d2`` are silently skipped (same contract as the original).
    """
    return float(sum(p * np.log(p / d2[w])
                     for w, p in d1.items()
                     if w in d2))
def js_divergence(d1, d2):
    """Jensen-Shannon divergence between two sparse distributions.

    Builds the pointwise mixture M = (d1 + d2) / 2 over the union of
    keys, then averages KL(d1 || M) and KL(d2 || M).
    """
    mixture = {w: 0.5 * p for w, p in d1.items()}
    for w, p in d2.items():
        mixture[w] = mixture.get(w, 0.0) + 0.5 * p
    return 0.5 * (kl_divergence(d1, mixture) + kl_divergence(d2, mixture))
def jaccard(d1, d2):
    """Jaccard similarity |A ∩ B| / |A ∪ B| between the key sets of d1 and d2."""
    keys1 = set(d1)
    keys2 = set(d2)
    return len(keys1 & keys2) / len(keys1 | keys2)
def weighted_jaccard(d1, d2):
    """Weighted Jaccard similarity sum(min) / sum(max) over the key union.

    A key missing from one dict is treated as having value 0, so it
    contributes its value to the max-sum and nothing to the min-sum.

    Returns 1.0 when both dicts are empty — two empty distributions are
    identical. (The original raised ZeroDivisionError in that case.)
    """
    min_sum = 0.0
    max_sum = 0.0
    for t in set(d1) | set(d2):
        v1 = d1.get(t, 0)
        v2 = d2.get(t, 0)
        min_sum += min(v1, v2)
        max_sum += max(v1, v2)
    if max_sum == 0:
        return 1.0  # both inputs empty (or all-zero): identical
    return float(min_sum) / float(max_sum)
def cf2freq(d):
    """Normalize raw term counts to relative frequencies (values sum to 1)."""
    total = float(sum(d.values()))
    return {term: count / total for term, count in d.items()}
def df2idf(d, n):
    """Map each term's document frequency to a raw inverse document
    frequency n / df (no log damping).

    d: mapping term -> document frequency (must be non-zero).
    n: total number of documents in the collection.

    The original kept an unused local ``total = n``; it is removed here.
    """
    return {t: float(n) / float(df) for t, df in d.items()}
def filter_freq_dict(freq_d, threshold=0.0001):
    """Return a copy of ``freq_d`` keeping only entries strictly above ``threshold``."""
    return {term: value for term, value in freq_d.items() if value > threshold}
def print_results(datasets, results, save_file):
    """Write the pairwise metric matrix to ``save_file`` as a TSV table.

    The first row is a tab-prefixed header of dataset names; each
    following row is a dataset name and results[d1][d2] to 4 decimals.

    Uses a context manager so the file is always closed — the original
    leaked the handle if any write raised.
    """
    with open(save_file, 'w') as f:
        f.write("\t{}\n".format("\t".join(datasets)))
        for d1 in datasets:
            f.write(d1)
            for d2 in datasets:
                f.write("\t{:.4f}".format(results[d1][d2]))
            f.write("\n")
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--index_path', type=str, help='path to indexes of all the beir dataset', required=True)
    parser.add_argument('--index_name_format', type=str, help='define your own index dir path name', default="/lucene-index-beir-{}")
    parser.add_argument('--compare_metric', type=str, help='the metric for comparing two vocab, choose from: jaccard, weight_jaccard, df_filter, tf_filter, kl_divergence, js_divergence', default="weight_jaccard")
    parser.add_argument('--compare_threshold', type=float, help='when choosing df_filter, or tf_filter, you can choose the threshold', default=0.0001)
    parser.add_argument('--output_path', type=str, help='path to save the stat results', required=True)
    args = parser.parse_args()

    beir_datasets = ['trec-covid', 'bioasq', 'nfcorpus', 'nq', 'hotpotqa', 'climate-fever', 'fever', 'dbpedia-entity', 'fiqa', 'signal1m', 'trec-news', 'robust04', 'arguana', 'webis-touche2020', 'quora', 'cqadupstack', 'scidocs', 'scifact', 'msmarco']

    # BUGFIX: the original `cfs = dfs = stats = {}` bound all three names to
    # the SAME dict object, so the three statistic tables silently shared and
    # overwrote each other's entries. Use three distinct dicts.
    cfs = {}    # dataset -> {term: collection frequency}
    dfs = {}    # dataset -> {term: document frequency}
    stats = {}  # dataset -> index-level statistics
    for d in beir_datasets:
        cf, df, stat = index2stats(args.index_path + args.index_name_format.format(d))
        cfs[d] = cf
        dfs[d] = df
        # BUGFIX: the original wrote `stat[d] = stat`, inserting the per-dataset
        # stat dict into itself instead of recording it in `stats`.
        stats[d] = stat

    # Build the full pairwise similarity/divergence matrix.
    results = {}
    for d1 in beir_datasets:
        metric_d1 = {}
        for d2 in beir_datasets:
            if d1 == d2:
                # A dataset compared with itself: similarity metrics are 1,
                # divergence metrics are 0.
                if args.compare_metric in ["jaccard", "weight_jaccard", "df_filter", "tf_filter"]:
                    metric_d1[d2] = 1
                elif args.compare_metric in ["kl_divergence", "js_divergence"]:
                    metric_d1[d2] = 0
            else:
                if args.compare_metric == "jaccard":
                    metric_d1[d2] = jaccard(cfs[d1], cfs[d2])
                elif args.compare_metric == "weight_jaccard":
                    new_d1 = filter_freq_dict(cf2freq(cfs[d1]))
                    new_d2 = filter_freq_dict(cf2freq(cfs[d2]))
                    metric_d1[d2] = weighted_jaccard(new_d1, new_d2)
                elif args.compare_metric == "df_filter":
                    new_d1 = filter_freq_dict(cf2freq(cfs[d1]))
                    new_d2 = filter_freq_dict(cf2freq(cfs[d2]))
                    metric_d1[d2] = jaccard(new_d1, new_d2)
                elif args.compare_metric == "tf_filter":
                    # NOTE(review): n=1 makes df2idf return 1/df; looks like a
                    # placeholder for the true document count — confirm intent.
                    new_d1 = filter_freq_dict(df2idf(dfs[d1], 1))
                    new_d2 = filter_freq_dict(df2idf(dfs[d2], 1))
                    metric_d1[d2] = jaccard(new_d1, new_d2)
                elif args.compare_metric == "kl_divergence":
                    new_d1 = filter_freq_dict(cf2freq(cfs[d1]))
                    new_d2 = filter_freq_dict(cf2freq(cfs[d2]))
                    metric_d1[d2] = kl_divergence(new_d1, new_d2)
                elif args.compare_metric == "js_divergence":
                    new_d1 = filter_freq_dict(cf2freq(cfs[d1]))
                    new_d2 = filter_freq_dict(cf2freq(cfs[d2]))
                    metric_d1[d2] = js_divergence(new_d1, new_d2)
                else:
                    raise NotImplementedError
        results[d1] = metric_d1
    print_results(beir_datasets, results, args.output_path)