Tachi67 commited on
Commit
10d8b43
·
1 Parent(s): 6fbf1d8

Update MemoryReadingAtomicFlow.py

Browse files
Files changed (1) hide show
  1. MemoryReadingAtomicFlow.py +6 -4
MemoryReadingAtomicFlow.py CHANGED
@@ -15,13 +15,15 @@ class MemoryReadingAtomicFlow(AtomicFlow):
15
  {"plan": "examples/JARVIS/plan.txt"}
16
  """
17
 
18
- def __init__(self):
19
- super().__init__()
20
  self.supported_mem_name = ["plan", "logs", "code_library"]
21
 
22
  def _check_input_data(self, input_data: Dict[str, Any]):
23
  """input data sanity check"""
24
- for mem_name, mem_path in input_data.items():
 
 
25
  assert mem_name in self.supported_mem_name, (f"{mem_name} is not supported in MemoryReadingAtomicFlow, "
26
  f"supported names are: {self.supported_mem_name}")
27
  assert os.path.exists(mem_path), f"{mem_path} does not exist."
@@ -50,7 +52,7 @@ class MemoryReadingAtomicFlow(AtomicFlow):
50
  input_data: Dict[str, Any]):
51
  self._check_input_data(input_data)
52
  response = {}
53
- for mem_name, mem_path in input_data.items():
54
  if mem_name in ['plan', 'logs']:
55
  response[mem_name] = self._read_text(mem_path)
56
  elif mem_name == 'code_library' and mem_path.endswith('.py'):
 
15
  {"plan": "examples/JARVIS/plan.txt"}
16
  """
17
 
18
+ def __init__(self, **kwargs):
19
+ super().__init__(**kwargs)
20
  self.supported_mem_name = ["plan", "logs", "code_library"]
21
 
22
  def _check_input_data(self, input_data: Dict[str, Any]):
23
  """input data sanity check"""
24
+ assert "memory_files" in input_data, "memory_files not passed to MemoryReadingAtomicFlow"
25
+
26
+ for mem_name, mem_path in input_data["memory_files"].items():
27
  assert mem_name in self.supported_mem_name, (f"{mem_name} is not supported in MemoryReadingAtomicFlow, "
28
  f"supported names are: {self.supported_mem_name}")
29
  assert os.path.exists(mem_path), f"{mem_path} does not exist."
 
52
  input_data: Dict[str, Any]):
53
  self._check_input_data(input_data)
54
  response = {}
55
+ for mem_name, mem_path in input_data["memory_files"].items():
56
  if mem_name in ['plan', 'logs']:
57
  response[mem_name] = self._read_text(mem_path)
58
  elif mem_name == 'code_library' and mem_path.endswith('.py'):