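# Source: file-viewer page for this module (page chrome kept as a comment).
# Author avatar: cleopatro's picture
# Commit message: added main and stm
# Commit: b1f80ab
# Viewer links: raw / history / blame / contribute / delete
# Scan status: No virus
# File size: 1.84 kB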
from collections import defaultdict
class ShortTermMemory:
    """Frequency-based short-term memory over three entity categories.

    Each category ('abstract', 'location', 'time') keeps a mapping from
    entity to an integer count. Observing an entity increments its count,
    decays the counts of every other entity in the same category, prunes
    entities whose decayed count is no longer above 1, and keeps only the
    ``window_size`` highest-count entities.

    Attributes are kept as ``defaultdict(int)`` for the whole lifetime of
    the instance so repeated updates keep working.
    """

    # Maps the public entity-type label to the attribute holding its counts.
    _STORE_BY_TYPE = {
        'abstract': 'abstract_entities',
        'location': 'locations',
        'time': 'times',
    }

    def __init__(self, window_size=10, decay_rate=0.5):
        """Create an empty memory.

        Args:
            window_size: Maximum number of entities retained per category.
            decay_rate: Multiplier applied to the counts of all entities
                other than the one being observed; the product is
                truncated with ``int``.
        """
        self.abstract_entities = defaultdict(int)
        self.locations = defaultdict(int)
        self.times = defaultdict(int)
        self.window_size = window_size
        self.decay_rate = decay_rate

    def update(self, entity_type, entity):
        """Record one observation of ``entity`` in the given category.

        Args:
            entity_type: One of 'abstract', 'location' or 'time'.
            entity: Hashable entity to record.

        Raises:
            ValueError: If ``entity_type`` is not a known category.
        """
        try:
            store_name = self._STORE_BY_TYPE[entity_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable labels, which are just as invalid
            # as unknown ones.
            raise ValueError(f'Invalid entity type: {entity_type}') from None

        counts = getattr(self, store_name)

        # Rebuild the mapping in its existing insertion order so that ties
        # in the stable sort below resolve the same way on every call.
        refreshed = {}
        for name, count in counts.items():
            if name == entity:
                refreshed[name] = count + 1
                continue
            decayed = int(count * self.decay_rate)
            # Drop other entities once their decayed count is <= 1.
            if decayed > 1:
                refreshed[name] = decayed
        # The entity just observed is never pruned by the threshold;
        # otherwise a first sighting (count 1) would vanish immediately
        # and nothing could ever accumulate.
        refreshed.setdefault(entity, 1)

        # Keep only the highest-count entities, up to the window size.
        ranked = sorted(refreshed.items(), key=lambda item: item[1], reverse=True)

        # Store a defaultdict again (not a plain dict) so the attribute
        # keeps the type it was given in __init__.
        setattr(self, store_name, defaultdict(int, ranked[:self.window_size]))

    def get_memory(self):
        """Return a snapshot of all three categories.

        Returns:
            A dict with keys 'abstract_entities', 'locations' and 'times',
            each mapping to a plain ``dict`` copy of the current counts.
            Copies are returned so callers cannot alter the memory by
            mutating (or merely indexing) the result.
        """
        return {
            'abstract_entities': dict(self.abstract_entities),
            'locations': dict(self.locations),
            'times': dict(self.times),
        }