File size: 1,835 Bytes
b1f80ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from collections import defaultdict

class ShortTermMemory:
    """Decaying, size-bounded frequency counts of recently mentioned entities.

    Counts are kept in three separate tables, one per entity type
    ('abstract', 'location', 'time'). Each call to ``update`` bumps one
    entity's count and decays every other entity of the same type, so
    entities that stop being mentioned fade out of memory.
    """

    # Maps each accepted ``entity_type`` to the attribute holding its table.
    _ATTRIBUTE_BY_TYPE = {
        'abstract': 'abstract_entities',
        'location': 'locations',
        'time': 'times',
    }

    def __init__(self, window_size=10, decay_rate=0.5, prune_threshold=1):
        """Create an empty memory.

        Args:
            window_size: maximum number of entities retained per type; the
                highest-count entities are the ones kept.
            decay_rate: factor applied to the counts of the other entities of
                a type on every update (the result is truncated to an int).
            prune_threshold: decayed entities whose count is at or below this
                value are forgotten. The entity being updated is exempt.
        """
        # entity -> count, one table per entity type.
        self.abstract_entities = defaultdict(int)
        self.locations = defaultdict(int)
        self.times = defaultdict(int)
        self.window_size = window_size
        self.decay_rate = decay_rate
        self.prune_threshold = prune_threshold

    def update(self, entity_type, entity):
        """Record one mention of *entity* under *entity_type*.

        The mentioned entity's count is incremented by one (starting from 1
        on first mention). Every other entity of the same type has its count
        multiplied by ``decay_rate`` and truncated to an int. Decayed entities
        whose count is ``<= prune_threshold`` are dropped. Finally only the
        ``window_size`` highest-count entities are kept; ties keep their
        existing order.

        Raises:
            ValueError: if *entity_type* is not 'abstract', 'location' or
                'time'.
        """
        try:
            attr = self._ATTRIBUTE_BY_TYPE[entity_type]
        except (KeyError, TypeError):
            # TypeError: an unhashable entity_type is just as invalid.
            raise ValueError(f'Invalid entity type: {entity_type}') from None
        counts = getattr(self, attr)

        # Rebuild the table in its existing order: bump the mentioned entity,
        # decay the rest, and drop decayed entries at or below the threshold.
        # The mentioned entity is never pruned here, otherwise a first mention
        # (count 1) would be forgotten immediately.
        survivors = {}
        for other, count in counts.items():
            if other == entity:
                survivors[other] = count + 1
                continue
            decayed = int(count * self.decay_rate)
            if decayed > self.prune_threshold:
                survivors[other] = decayed
        if entity not in survivors:
            # First mention, or re-mention after having been forgotten.
            survivors[entity] = 1

        # Keep the window_size highest counts. sorted() is stable (also with
        # reverse=True), so ties keep insertion order. Store the result as a
        # defaultdict(int), the same type __init__ uses, so later updates can
        # keep incrementing missing keys safely.
        ranked = sorted(
            survivors.items(), key=lambda item: item[1], reverse=True
        )
        setattr(self, attr, defaultdict(int, ranked[:self.window_size]))

    def get_memory(self):
        """Return a snapshot of all three tables as plain dicts.

        Copies are returned so callers cannot mutate internal state. This
        includes the implicit inserts a ``defaultdict`` performs on reads of
        missing keys.
        """
        return {
            'abstract_entities': dict(self.abstract_entities),
            'locations': dict(self.locations),
            'times': dict(self.times),
        }