import logging
import os

import faiss
import torch

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class FaissIndex:
    def __init__(
        self,
        embedding_size=None,
        faiss_index_location=None,
        indexer=faiss.IndexFlatIP,
    ):
        if embedding_size or faiss_index_location:
            self.embedding_size = embedding_size
        else:
            raise ValueError("Must provide embedding_size or faiss_index_location")

        self.faiss_index_location = faiss_index_location
        if faiss_index_location and os.path.exists(faiss_index_location):
            # Load an existing index and take its embedding size as authoritative.
            self.index = faiss.read_index(faiss_index_location)
            logger.info(f"Setting embedding size ({self.index.d}) to match saved index")
            self.embedding_size = self.index.d
            if os.path.exists(faiss_index_location + ".ids"):
                with open(faiss_index_location + ".ids") as f:
                    self.id_list = f.read().strip().split("\n")
            elif self.index.ntotal > 0:
                raise ValueError("Index file exists but ids file does not")
            else:
                self.id_list = []
        else:
            if faiss_index_location:
                os.makedirs(os.path.dirname(faiss_index_location) or ".", exist_ok=True)
            # Defer index creation until the first add(), when vectors are available.
            self.index = None
            self.indexer = indexer
            self.id_list = []

    def faiss_init(self):
        index = self.indexer(self.embedding_size)
        if self.faiss_index_location:
            faiss.write_index(index, self.faiss_index_location)
        self.index = index

    def add(self, inputs, ids, normalize=True):
        if not self.index:
            self.faiss_init()
        if normalize:
            # normalize_L2 works in place and expects a float32, C-contiguous array.
            faiss.normalize_L2(inputs)
        self.index.add(inputs)
        self.id_list.extend(ids)
        if self.faiss_index_location:
            faiss.write_index(self.index, self.faiss_index_location)
            with open(self.faiss_index_location + ".ids", "a") as f:
                f.write("\n".join(ids) + "\n")

    def search(self, embedding, k=10, normalize=True):
        if len(embedding.shape) == 1:
            # FAISS expects a 2D (n_queries, embedding_size) array.
            embedding = embedding.reshape(1, -1)
        if normalize:
            faiss.normalize_L2(embedding)
        D, I = self.index.search(embedding, k)
        labels = [self.id_list[i] for i in I.squeeze()]
        return D, I, labels

    def reset(self):
        if self.index:
            self.index.reset()
        self.id_list = []
        if self.faiss_index_location:
            try:
                os.remove(self.faiss_index_location)
                os.remove(self.faiss_index_location + ".ids")
            except FileNotFoundError:
                pass

    def __len__(self):
        if self.index:
            return self.index.ntotal
        return 0


class VectorSearch:
    def __init__(self):
        self.places = self.load("places")
        self.objects = self.load("objects")

    def load(self, index_name):
        return FaissIndex(
            faiss_index_location=f"faiss_indices/{index_name}.index",
        )

    def top_places(self, query_vec, k=5):
        if isinstance(query_vec, torch.Tensor):
            query_vec = query_vec.detach().numpy()
        *_, results = self.places.search(query_vec, k=k)
        return results

    def top_objects(self, query_vec, k=5):
        if isinstance(query_vec, torch.Tensor):
            query_vec = query_vec.detach().numpy()
        *_, results = self.objects.search(query_vec, k=k)
        return results

    def prompt_activities(self, query_vec, k=5, one_shot=False):
        places = self.top_places(query_vec, k=k)
        objects = self.top_objects(query_vec, k=k)
        place_str = f"Places: {', '.join(places)}. "
        object_str = f"Objects: {', '.join(objects)}. "
        act_str = "I might be doing these 3 activities: "
        zs = place_str + object_str + act_str
        example = (
            "Places: kitchen. Objects: coffee maker. "
            f"{act_str}eating, making breakfast, grinding coffee.\n "
        )
        fs = example + place_str + object_str + act_str
        if one_shot:
            return (zs, fs)
        return zs, places, objects

    def prompt_summary(self, state_history: list, k=5):
        rec_strings = ["Event log:"]
        for rec in state_history:
            rec_strings.append(
                f"Places: {', '.join(rec.places)}. "
                f"Objects: {', '.join(rec.objects)}. "
                f"Activities: {', '.join(rec.activities)} "
            )
        question = "How would you summarize these events in a few full sentences? "
        return "\n".join(rec_strings) + "\n" + question
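

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of building an index and querying it, assuming vectors are
# float32 NumPy arrays. The path, embedding size, and ids below are hypothetical;
# normalize=True with IndexFlatIP makes the scores cosine similarities.
if __name__ == "__main__":
    import numpy as np

    index = FaissIndex(
        embedding_size=512,
        faiss_index_location="faiss_indices/demo.index",  # hypothetical path
    )

    vectors = np.random.rand(3, 512).astype("float32")
    index.add(vectors, ids=["kitchen", "office", "garage"])

    query = np.random.rand(512).astype("float32")
    D, I, labels = index.search(query, k=2)
    print(labels)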