|
import random |
|
from math import sqrt |
|
|
|
|
|
class NaiveDB:
    """Minimal in-memory vector store ranked by brute-force cosine similarity.

    The store keeps five parallel lists indexed by story position:

    * ``stories`` - the text entries,
    * ``vecs``    - one embedding (iterable of numbers) per story,
    * ``norms``   - cached Euclidean norm of each entry in ``vecs``,
    * ``flags``   - ``True`` when the story may be returned by ``search``,
    * ``metas``   - one metadata object per story (a dict by default).

    ``last_search_ids`` holds the story indices selected by the most recent
    call to ``search``.
    """

    def __init__(self):
        # ``verbose`` must exist before init_db() runs because init_db reads it.
        self.verbose = False
        self.init_db()

    def init_db(self):
        """Reset every internal list to empty."""
        if self.verbose:
            print("call init_db")
        self.stories = []
        self.norms = []
        self.vecs = []
        self.flags = []
        self.metas = []
        self.last_search_ids = []

    def build_db(self, stories, vecs, flags=None, metas=None):
        """Load stories and their vectors, then cache the vector norms.

        Args:
            stories: The text entries.
            vecs: One vector per story, in the same order as ``stories``.
            flags: Optional per-story enabled flags. When omitted (or empty)
                every story is enabled.
            metas: Optional per-story metadata. When omitted (or empty) each
                story gets its own empty dict.
        """
        self.stories = stories
        self.vecs = vecs
        self.flags = flags if flags else [True for _ in self.stories]
        self.metas = metas if metas else [{} for _ in self.stories]
        self.recompute_norm()

    def save(self, file_path):
        """Not implemented: only prints a warning; ``file_path`` is unused."""
        print(
            "warning! directly save folder from dbtype NaiveDB has not been implemented yet, try use role_from_hf to load role instead")

    def load(self, file_path):
        """Not implemented: only prints a warning; ``file_path`` is unused."""
        print(
            "warning! directly load folder from dbtype NaiveDB has not been implemented yet, try use role_from_hf to load role instead")

    def recompute_norm(self):
        """Recompute the cached Euclidean norm of every stored vector."""
        self.norms = [sqrt(sum(x ** 2 for x in vec)) for vec in self.vecs]

    def get_stories_with_id(self, ids):
        """Return the stories located at the given indices, in that order."""
        return [self.stories[i] for i in ids]

    def clean_flag(self):
        """Re-enable every story."""
        self.flags = [True for _ in self.stories]

    def disable_story_with_ids(self, close_ids):
        """Exclude the stories at ``close_ids`` from subsequent searches."""
        for story_id in close_ids:
            self.flags[story_id] = False

    def close_last_search(self):
        """Exclude the hits of the most recent search from subsequent searches."""
        for story_id in self.last_search_ids:
            self.flags[story_id] = False

    def search(self, query_vector, n_results):
        """Return context windows around the stories most similar to the query.

        Enabled stories are ranked by cosine similarity to ``query_vector``.
        The indices of the best ``n_results`` are stored in
        ``last_search_ids``. For each hit ``i`` the returned string joins
        stories ``i - 3`` through ``i + 4`` (clipped to the list bounds) with
        newlines.

        Special cases:

        * If the query norm is below ``1e-20`` the direction of the query is
          meaningless, so every enabled story gets a random score.
        * If a stored vector has zero norm (or the norm product underflows to
          zero) its similarity is defined as ``0.0`` instead of raising
          ``ZeroDivisionError``.
        * Flags are honoured only when ``flags`` has the same length as
          ``vecs``; otherwise every story is treated as enabled.

        Args:
            query_vector: The query embedding (a re-iterable sequence of
                numbers).
            n_results: Maximum number of hits to return.

        Returns:
            A list of newline-joined story windows, best match first.
        """
        if self.verbose:
            print("call search")

        # The norm cache can go stale if ``vecs`` was replaced directly.
        if len(self.norms) != len(self.vecs):
            self.recompute_norm()

        query_norm = sqrt(sum(x ** 2 for x in query_vector))

        # Both conditions are loop-invariant, so evaluate them once.
        degenerate_query = query_norm < 1e-20
        use_flags = len(self.flags) == len(self.vecs)

        similarities = []
        for idx, (vec, norm) in enumerate(zip(self.vecs, self.norms)):
            if use_flags and not self.flags[idx]:
                continue

            if degenerate_query:
                similarities.append((random.random(), idx))
                continue

            denominator = query_norm * norm
            if denominator == 0:
                # Zero-length stored vector: cosine similarity is undefined,
                # so rank it as "unrelated" rather than crashing the search.
                similarities.append((0.0, idx))
                continue

            dot_product = sum(q * v for q, v in zip(query_vector, vec))
            similarities.append((dot_product / denominator, idx))

        # Stable sort: ties keep their original index order.
        similarities.sort(key=lambda pair: pair[0], reverse=True)
        self.last_search_ids = [idx for _, idx in similarities[:n_results]]

        # Context window: 3 stories before the hit and up to 4 after it.
        before, after = 3, 4
        stories_length = len(self.stories)
        search_id_range = [(max(0, i - before), min(i + after, stories_length))
                           for i in self.last_search_ids]

        # ``end`` is inclusive, hence the +1; slicing clips at the list end.
        top_stories = ["\n".join(self.stories[start:end + 1])
                       for start, end in search_id_range]
        return top_stories
|
|