Update MemoryReadingAtomicFlow.py
Browse files
MemoryReadingAtomicFlow.py
CHANGED
@@ -15,13 +15,15 @@ class MemoryReadingAtomicFlow(AtomicFlow):
|
|
15 |
{"plan": "examples/JARVIS/plan.txt"}
|
16 |
"""
|
17 |
|
18 |
-
def __init__(self):
|
19 |
-
super().__init__()
|
20 |
self.supported_mem_name = ["plan", "logs", "code_library"]
|
21 |
|
22 |
def _check_input_data(self, input_data: Dict[str, Any]):
|
23 |
"""input data sanity check"""
|
24 |
-
|
|
|
|
|
25 |
assert mem_name in self.supported_mem_name, (f"{mem_name} is not supported in MemoryReadingAtomicFlow, "
|
26 |
f"supported names are: {self.supported_mem_name}")
|
27 |
assert os.path.exists(mem_path), f"{mem_path} does not exist."
|
@@ -50,7 +52,7 @@ class MemoryReadingAtomicFlow(AtomicFlow):
|
|
50 |
input_data: Dict[str, Any]):
|
51 |
self._check_input_data(input_data)
|
52 |
response = {}
|
53 |
-
for mem_name, mem_path in input_data.items():
|
54 |
if mem_name in ['plan', 'logs']:
|
55 |
response[mem_name] = self._read_text(mem_path)
|
56 |
elif mem_name == 'code_library' and mem_path.endswith('.py'):
|
|
|
15 |
{"plan": "examples/JARVIS/plan.txt"}
|
16 |
"""
|
17 |
|
18 |
+
def __init__(self, **kwargs):
|
19 |
+
super().__init__(**kwargs)
|
20 |
self.supported_mem_name = ["plan", "logs", "code_library"]
|
21 |
|
22 |
def _check_input_data(self, input_data: Dict[str, Any]):
|
23 |
"""input data sanity check"""
|
24 |
+
assert "memory_files" in input_data, "memory_files not passed to MemoryReadingAtomicFlow"
|
25 |
+
|
26 |
+
for mem_name, mem_path in input_data["memory_files"].items():
|
27 |
assert mem_name in self.supported_mem_name, (f"{mem_name} is not supported in MemoryReadingAtomicFlow, "
|
28 |
f"supported names are: {self.supported_mem_name}")
|
29 |
assert os.path.exists(mem_path), f"{mem_path} does not exist."
|
|
|
52 |
input_data: Dict[str, Any]):
|
53 |
self._check_input_data(input_data)
|
54 |
response = {}
|
55 |
+
for mem_name, mem_path in input_data["memory_files"].items():
|
56 |
if mem_name in ['plan', 'logs']:
|
57 |
response[mem_name] = self._read_text(mem_path)
|
58 |
elif mem_name == 'code_library' and mem_path.endswith('.py'):
|