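# NOTE(review): the lines below appear to be file-page metadata captured along
# with this source; they are commented out so the module parses.
# 欧卫
# 'add_app_files'
# 58627fa
# raw
# history blame
# No virus
# 1.62 kB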
from baleen.utils.loaders import *
from baleen.condenser.condense import Condenser
class Baleen:
    """Multi-hop retriever alternating between a searcher and a condenser.

    Each hop queries ``searcher`` with the original query plus the text of the
    facts gathered so far. It then hands the fresh candidate passage ids to
    ``condenser``, whose updated fact list becomes the context for the next hop.
    """

    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
        """Load the sentence-level collection and store the retrieval components.

        Args:
            collectionX_path: Path handed to ``load_collectionX``.
            searcher: Object exposing ``search(query, context=..., k=...)``.
            condenser: Object exposing ``condense(query, backs=..., ranking=...)``.
        """
        # NOTE(review): presumably maps (pid, sid) -> sentence text; ``search``
        # only relies on ``.get((pid, sid), '')`` -- confirm against the loader.
        self.collectionX = load_collectionX(collectionX_path)
        self.searcher = searcher
        self.condenser = condenser

    def search(self, query, num_hops, depth=100, verbose=False):
        """Run ``num_hops`` rounds of retrieve-then-condense for ``query``.

        Args:
            query: Question passed unchanged to the searcher and the condenser.
            num_hops: Number of retrieval hops; must be a positive integer.
            depth: Total number of passage ids to collect across all hops. Each
                hop asks the searcher for ``depth`` candidates and contributes
                up to ``depth // num_hops`` new ids to the returned bag.
            verbose: Currently unused; accepted for API compatibility.

        Returns:
            ``(stage2_L3x, pids_bag, stage1_preds)``. The first and last items
            come from the final ``condenser.condense`` call. ``pids_bag`` is the
            set of passage ids gathered over all hops.

        Raises:
            ValueError: If ``num_hops`` is not positive or ``depth`` is not
                divisible by ``num_hops``.
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if num_hops <= 0:
            raise ValueError(f"num_hops={num_hops} must be a positive integer.")
        if depth % num_hops != 0:
            raise ValueError(f"depth={depth} must be divisible by num_hops={num_hops}.")
        k = depth // num_hops

        searcher = self.searcher
        condenser = self.condenser
        collectionX = self.collectionX

        facts = []  # (pid, sid) pairs returned by the condenser
        stage1_preds = None
        context = None  # no context on the first hop
        pids_bag = set()

        for hop_idx in range(num_hops):
            # searcher.search is assumed to return parallel (pids, ranks, scores)
            # sequences; zip(*...) turns them into per-passage triples.
            ranking = zip(*searcher.search(query, context=context, k=depth))

            # Passages already represented among the facts are not re-offered.
            fact_pids = {pid for pid, _ in facts}
            # After this hop the bag should hold k more ids than before.
            bag_quota = k * (hop_idx + 1)
            candidates = []

            for pid, _rank, _score in ranking:
                if len(candidates) < k and pid not in fact_pids:
                    candidates.append(pid)
                if len(pids_bag) < bag_quota:
                    pids_bag.add(pid)

            stage1_preds, facts, stage2_L3x = condenser.condense(
                query, backs=facts, ranking=candidates
            )
            # The next hop searches with the text of the current facts.
            context = ' [SEP] '.join(
                collectionX.get((pid, sid), '') for pid, sid in facts
            )

        # Post-condition: holds when the searcher returns ``depth`` distinct pids.
        assert len(pids_bag) == depth, (
            f"collected {len(pids_bag)} pids, expected depth={depth}."
        )
        return stage2_L3x, pids_bag, stage1_preds