|
from collections import defaultdict |
|
|
|
class ShortTermMemory:
    """Decaying, size-bounded frequency counts of recently mentioned entities.

    Three independent stores are kept, one per entity type (``'abstract'``,
    ``'location'``, ``'time'``).

    Each :meth:`update` does the following to the store for its type:

    * increments the mentioned entity's count;
    * multiplies every *other* entity's count in that store by
      ``decay_rate`` (truncated with ``int()``);
    * drops entities whose count has decayed to zero;
    * keeps only the ``window_size`` highest counts.
    """

    # Maps the entity-type names accepted by update() to the attribute that
    # holds the counts for that type.
    _ATTRIBUTE_FOR_TYPE = {
        'abstract': 'abstract_entities',
        'location': 'locations',
        'time': 'times',
    }

    def __init__(self, window_size=10, decay_rate=0.5):
        """Create empty stores.

        Args:
            window_size: Maximum number of entities retained per type after
                each update (highest counts win; ties keep earlier entries).
            decay_rate: Factor applied to the counts of the entities that
                were *not* mentioned in an update; the product is truncated
                with ``int()``.
        """
        self.abstract_entities = defaultdict(int)
        self.locations = defaultdict(int)
        self.times = defaultdict(int)
        self.window_size = window_size
        self.decay_rate = decay_rate

    def update(self, entity_type, entity):
        """Record one mention of ``entity`` in the store for ``entity_type``.

        Args:
            entity_type: One of ``'abstract'``, ``'location'`` or ``'time'``.
            entity: Hashable key identifying the mentioned entity.

        Raises:
            ValueError: If ``entity_type`` is not one of the supported types.
        """
        try:
            attribute = self._ATTRIBUTE_FOR_TYPE[entity_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable entity_type values, which the
            # original equality-based dispatch also rejected as invalid.
            raise ValueError(f'Invalid entity type: {entity_type}') from None

        # Work on a copy: the stored mapping is only replaced once the new
        # state has been computed. .get() avoids relying on the stored
        # mapping being a defaultdict.
        counts = dict(getattr(self, attribute))
        counts[entity] = counts.get(entity, 0) + 1

        # Decay every entity except the one just mentioned. Dict order is
        # preserved, so tie-breaking in the sort below matches insertion order.
        decay = self.decay_rate
        decayed = {
            key: (count if key == entity else int(count * decay))
            for key, count in counts.items()
        }

        # Keep only entities that still have a positive count. (A threshold
        # of "> 1" would discard every freshly seen entity, whose count is 1.)
        survivors = [(key, count) for key, count in decayed.items() if count > 0]

        # sort() is stable, so equal counts keep their existing relative order.
        survivors.sort(key=lambda item: item[1], reverse=True)

        # Keep the store a defaultdict(int), the same type __init__ creates.
        setattr(self, attribute, defaultdict(int, survivors[:self.window_size]))

    def get_memory(self):
        """Return the three stores keyed by their attribute names.

        The values are the live mappings held by this instance (not copies),
        so mutating them mutates the memory.
        """
        return {
            'abstract_entities': self.abstract_entities,
            'locations': self.locations,
            'times': self.times
        }
|
|