from collections import defaultdict


class ShortTermMemory:
    """Frequency-based short-term memory for recently seen entities.

    Three independent stores are kept, one per entity type: abstract
    entities, locations, and times. Each store maps an entity to an
    integer count. Every call to :meth:`update` reinforces one entity
    and decays every other entity in the same store, so entities that
    stop being mentioned fade out.

    Attributes:
        abstract_entities: Counts for entities of type ``'abstract'``.
        locations: Counts for entities of type ``'location'``.
        times: Counts for entities of type ``'time'``.
        window_size: Maximum number of entities kept per store.
        decay_rate: Multiplier applied to the counts of all entities
            other than the one being updated.
    """

    # Maps the public entity-type name to the attribute holding its store.
    _STORE_ATTRS = {
        'abstract': 'abstract_entities',
        'location': 'locations',
        'time': 'times',
    }

    def __init__(self, window_size=10, decay_rate=0.5):
        """Create an empty memory.

        Args:
            window_size: Maximum number of entities retained per store
                after each update.
            decay_rate: Factor by which the counts of non-updated
                entities are multiplied (then truncated to ``int``) on
                each update of their store.
        """
        self.abstract_entities = defaultdict(int)
        self.locations = defaultdict(int)
        self.times = defaultdict(int)
        self.window_size = window_size
        self.decay_rate = decay_rate

    def _store_attr(self, entity_type):
        """Return the attribute name of the store for *entity_type*.

        Raises:
            ValueError: If *entity_type* is not a known entity type.
        """
        try:
            return self._STORE_ATTRS[entity_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which are invalid types too.
            raise ValueError(f'Invalid entity type: {entity_type}') from None

    def update(self, entity_type, entity):
        """Reinforce *entity* and decay the rest of its store.

        The count of *entity* is incremented by one. Every other entity
        in the same store has its count multiplied by ``decay_rate`` and
        truncated to an integer; decayed entities whose count is no
        longer greater than 1 are forgotten. The store is then trimmed
        to the ``window_size`` entities with the highest counts (ties
        keep their existing relative order).

        Args:
            entity_type: One of ``'abstract'``, ``'location'`` or
                ``'time'``.
            entity: The (hashable) entity that was just observed.

        Raises:
            ValueError: If *entity_type* is not a known entity type. The
                memory is left unchanged in that case.
        """
        attr = self._store_attr(entity_type)
        counts = getattr(self, attr)

        # ``.get`` keeps this working even if the attribute was replaced
        # by a plain dict from outside the class.
        counts[entity] = counts.get(entity, 0) + 1

        survivors = {}
        for key, count in counts.items():
            if key != entity:
                count = int(count * self.decay_rate)
                if count <= 1:
                    continue  # faded out of short-term memory
            # The just-updated entity is exempt from the threshold:
            # otherwise a first sighting (count == 1) would be deleted
            # in the same call that recorded it.
            survivors[key] = count

        ranked = sorted(survivors.items(), key=lambda item: item[1], reverse=True)

        # Rebind as a defaultdict so the next update can increment
        # entities that are not present yet.
        setattr(self, attr, defaultdict(int, ranked[:self.window_size]))

    def get_memory(self):
        """Return a snapshot of all three stores.

        Returns:
            A dict with the keys ``'abstract_entities'``, ``'locations'``
            and ``'times'``, each mapping to a plain ``dict`` of
            ``entity -> count``. The dicts are copies, so reading or
            modifying them does not affect the memory.
        """
        return {
            'abstract_entities': dict(self.abstract_entities),
            'locations': dict(self.locations),
            'times': dict(self.times),
        }