|
""" |
|
The atomic unit of a memory is an event. |
|
A collection of events forms an EventCollection.
|
""" |
|
|
|
import datetime
import json
import os
import time
from typing import List, Tuple, Union

from bson import ObjectId

from .common import Question, ParsedAnswer
|
|
|
|
|
def file_location_type(path: str) -> str:
    """ Classify a location string as `url` or `local`.

    A path that starts with `http://` or `https://` is reported as `url`;
    every other string (including the empty string) is reported as `local`.
    The prefix check is case-sensitive.
    """
    remote_prefixes = ('http://', 'https://')
    return 'url' if path.startswith(remote_prefixes) else 'local'
|
|
|
def read_event_file(file_location: str) -> dict:
    """ Event files loaded as json.

    Args:
        file_location: path of the JSON event file on the local filesystem.

    Returns:
        The decoded JSON content of the file.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # locale encoding when decoding it.
    with open(file_location, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
def write_event_file(file_location: str, values: dict) -> str:
    """ Event files stored as json.

    Args:
        file_location: destination of the event file. Only locations that
            `file_location_type` classifies as `local` are supported.
        values: JSON-serialisable content of the event.

    Returns:
        The `file_location` that was written.

    Raises:
        NotImplementedError: if `file_location` is not a local path.
        TypeError: if `values` contains objects `json.dump` cannot serialise.
    """
    # Decide on the location type first so that an unsupported destination is
    # rejected before anything is created on disk.
    fl_type = file_location_type(file_location)
    if fl_type != 'local':
        raise NotImplementedError(f'Uncaught file location type {fl_type}')

    # NOTE(review): the original called `prepare_dir_of`, which is neither
    # defined nor imported in this module; the parent directory is created
    # directly instead — confirm this matches the intended helper.
    parent_dir = os.path.dirname(file_location)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_location, 'w', encoding='utf-8') as f:
        json.dump(values, f)
    return file_location
|
|
|
|
|
class EventCollection(object):
    """ Class used for data processing from event files.

    Holds a list of `Event` objects and offers loading from a session
    directory, time ordering and filtering by event type label.
    """

    def __init__(self):
        # Unique identifier of the collection (string form of a bson ObjectId).
        self.id = str(ObjectId())
        # Events in insertion order until `time_sorted` is called.
        self.events = []

    def load_from_session(self, event_dir: str, session_token: str) -> "EventCollection":
        """ Load every event file found in `<event_dir>/<session_token>`.

        Each directory entry is read with `Event.load_from_event_file`, added
        to the collection, and the collection is sorted afterwards.

        Raises:
            NotImplementedError: if `event_dir` is not a local location.
            AssertionError: if the session directory does not exist.
        """
        location_type = file_location_type(event_dir)
        if location_type != 'local':
            raise NotImplementedError(f'Uncaught file location {location_type}')

        session_dir = os.path.join(event_dir, session_token)
        assert os.path.exists(session_dir), f'{session_dir} does not exist'

        for event_file in os.listdir(session_dir):
            new_event = Event().load_from_event_file(os.path.join(session_dir, event_file))
            self.add_event(new_event)

        self.time_sorted()
        return self

    def add_event(self, event: "Event") -> "EventCollection":
        """ Append `event` to the collection and return the collection. """
        assert isinstance(event, Event), f"invalid type {type(event)} for EventCollection"
        self.events.append(event)
        return self

    def time_sorted(self) -> "EventCollection":
        """ Sort the stored events in ascending order and return the collection. """
        self.events = sorted(self.events)
        return self

    def __str__(self) -> str:
        # One line-terminated entry per event, in sorted order; the stored
        # list itself is left untouched.
        return "".join(f"{str(ev)}\n" for ev in sorted(self.events))

    def __len__(self) -> int:
        return len(self.events)

    def filter_to(self, label: Union[str, List[str]]) -> List["Event"]:
        """ Return the events whose class matches one of the given type labels.

        Args:
            label: a single key of `TYPE2CLASS`, or a list of such keys.

        Returns:
            A new list with the matching events (instances of the mapped
            classes, subclasses included), in their current order.

        Raises:
            AssertionError: if a label is not a key of `TYPE2CLASS`.
        """
        labels = [label] if isinstance(label, str) else label

        permissible_classes = []
        for name in labels:
            # Report the offending label itself, not the whole request.
            assert name in TYPE2CLASS, f"unknown type {name}"
            permissible_classes.append(TYPE2CLASS[name])

        allowed = tuple(permissible_classes)
        return [ev for ev in self.events if isinstance(ev, allowed)]
|
|
|
|
|
class Event(object):
    """ Base class of all events.

    Events are ordered and compared through their `timestamp` attribute.
    Since `__eq__` is defined without `__hash__`, instances are unhashable.
    """

    def __init__(self, session_token:str = None):
        # Token of the session the event belongs to (None when not given).
        self.session_token = session_token
        # Creation time: string form of a naive `datetime.datetime.now()`.
        self.timestamp = str(datetime.datetime.now())
        # Type label. NOTE(review): the subclasses in this file do not
        # override it, so their instances also carry 'EVENT' — confirm intent.
        self.type = 'EVENT'

    @staticmethod
    def _require_event(obj) -> None:
        # Ordering is only defined between events.
        if not isinstance(obj, Event):
            raise ValueError(f'Invalid comparison of type {type(obj)}')

    def __lt__(self, obj: "Event") -> bool:
        self._require_event(obj)
        return self.timestamp < obj.timestamp

    def __gt__(self, obj: "Event") -> bool:
        self._require_event(obj)
        return self.timestamp > obj.timestamp

    def __le__(self, obj: "Event") -> bool:
        self._require_event(obj)
        return self.timestamp <= obj.timestamp

    def __ge__(self, obj: "Event") -> bool:
        self._require_event(obj)
        return self.timestamp >= obj.timestamp

    def __eq__(self, obj: object) -> bool:
        # Returning NotImplemented (instead of raising) keeps `ev == None`,
        # `ev != other` and `ev in some_list` usable with non-Event operands;
        # Python then falls back to identity comparison.
        if not isinstance(obj, Event):
            return NotImplemented
        return self.timestamp == obj.timestamp

    def load_from_event_params(self, **kwargs) -> "Event":
        """ Set every keyword argument as an attribute and return the event. """
        for key, value in kwargs.items():
            setattr(self, key, value)
        return self

    def load_from_event_file(self, filepath: str) -> "Event":
        """ Build a new event from a JSON event file.

        The class is picked from `TYPE2CLASS` using the file's `type` field
        and every field of the file is set as an attribute on the new object.
        The instance the method is called on is not modified.

        Raises:
            KeyError: if the file has no `type` field or the type is unknown.
        """
        fields = read_event_file(filepath)
        event_cls = TYPE2CLASS[fields['type']]
        try:
            event = event_cls()
        except AssertionError:
            # Some subclasses assert on constructor arguments that are only
            # known from the file; start from the base defaults instead and
            # let the file's fields fill in the rest.
            event = event_cls.__new__(event_cls)
            Event.__init__(event)
        return event.load_from_event_params(**fields)

    def export(self) -> dict:
        """ Return the event's attributes as a new dict (shallow copy). """
        # A copy, so that callers cannot alter the event through the result.
        return dict(vars(self))

    def save_to_event_file(self, filepath: str) -> str:
        """ Write the exported attributes to `filepath` and return the path.

        Raises:
            ValueError: if the event has a `latency` attribute that is None.
        """
        if hasattr(self, 'latency') and self.latency is None:
            raise ValueError("Forgot to call .tick and .tock?")
        write_event_file(filepath, self.export())
        return filepath

    @property
    def description(self) -> str:
        """ Text body of the event; not implemented on the base class. """
        raise NotImplementedError

    def __str__(self) -> str:
        return f'{str(self.timestamp)}: [{self.type}]\n{self.description}'
|
|
|
|
|
class ThinkEvent(Event):
    """ Event holding a sequence of (question, answer) pairs. """

    def __init__(self, session_token:str = None, qa_sequence:List[Tuple[Question, ParsedAnswer]]=None):
        """
        Args:
            session_token: forwarded to `Event.__init__`.
            qa_sequence: list of tuples, each unpacked as (question, answer).

        Raises:
            AssertionError: if `qa_sequence` is not a list of tuples.
        """
        super().__init__(session_token)
        assert isinstance(qa_sequence, list), "qa_sequence should be a list, even if len 1."
        assert all(isinstance(el, tuple) for el in qa_sequence), "Elements of qa_sequence should be tuple."

        self.qa_sequence = qa_sequence

    @property
    def description(self) -> str:
        """ Numbered listing of every question followed by its answer. """
        parts = []
        # Read the sequence from the instance; the bare name `qa_sequence`
        # does not exist in this scope.
        for idx, (q, a) in enumerate(self.qa_sequence):
            parts.append(f"#### Question {idx}\n{q!s}\n")
            parts.append(f"#### Answer{idx}\n{a!s}\n")
        return "".join(parts)
|
|
|
class ActEvent(Event):
    """ Event class registered under the 'ACT' label in TYPE2CLASS.

    Adds no state of its own; `description` is not implemented.
    """

    def __init__(self, session_token:str = None):
        super().__init__(session_token=session_token)

    @property
    def description(self) -> str:
        """ Not implemented for this event type. """
        raise NotImplementedError()
|
|
|
class ActErrorEvent(Event):
    """ Event of an error in the execution of an action.

    Keeps the given exception object on the instance; `description` is not
    implemented.
    """

    def __init__(self, session_token:str = None, exception:Exception=None):
        super().__init__(session_token=session_token)
        # Exception handed in by the caller (None when not provided).
        self.exception = exception

    @property
    def description(self) -> str:
        """ Not implemented for this event type. """
        raise NotImplementedError()
|
|
|
|
|
class ObserveEvent(Event):
    """ Event class registered under the 'OBSERVE' label in TYPE2CLASS.

    Adds no state of its own; `description` is not implemented.
    """

    def __init__(self, session_token:str = None):
        super().__init__(session_token=session_token)

    @property
    def description(self) -> str:
        """ Not implemented for this event type. """
        raise NotImplementedError()
|
|
|
class EvaluateEvent(Event):
    """ Event pairing a completion question with its parsed evaluation. """

    def __init__(self, session_token:str = None,
                 completion_question:Question = None,
                 completion_eval:ParsedAnswer = None):
        """
        Raises:
            AssertionError: if `completion_question` is not a `Question` or
                `completion_eval` is not a `ParsedAnswer`.
        """
        super().__init__(session_token=session_token)
        assert isinstance(completion_question, Question)
        assert isinstance(completion_eval, ParsedAnswer)
        self.completion_question, self.completion_eval = completion_question, completion_eval

    @property
    def description(self) -> str:
        """ Question and answer rendered under two headings. """
        sections = (
            "## Evaluation Question\n",
            "\n",
            str(self.completion_question),
            "\n",
            "## Evaluation Answer\n",
            "\n",
            str(self.completion_eval),
            "\n",
        )
        return "".join(sections)
|
|
|
class FeedbackEvent(Event):
    """ Event carrying an optional feedback `Question`. """

    def __init__(self, session_token:str = None,
                 feedback:Union[None, Question]=None):
        """
        Raises:
            AssertionError: if `feedback` is neither None nor a `Question`.
        """
        super().__init__(session_token=session_token)
        assert feedback is None or isinstance(feedback, Question)
        self.feedback = feedback

    @property
    def description(self) -> str:
        """ The feedback rendered under a reflection heading. """
        return "".join(("## Reflection:\n", str(self.feedback), "\n"))
|
|
|
|
|
class InteractEvent(Event):
    """ Event class registered under the 'INTERACT' label in TYPE2CLASS.

    Adds no state of its own; `description` is not implemented.
    """

    def __init__(self, session_token:str = None):
        super().__init__(session_token=session_token)

    @property
    def description(self) -> str:
        """ Not implemented for this event type. """
        raise NotImplementedError()
|
|
|
|
|
# Maps an event type label to the class that represents it. Used by
# Event.load_from_event_file to pick the class for a file's `type` field and
# by EventCollection.filter_to to translate labels into classes.
TYPE2CLASS = {
    'EVENT': Event,

    'THINK': ThinkEvent,
    'ACT': ActEvent,
    'ACTERROR': ActErrorEvent,
    'OBSERVE': ObserveEvent,
    'EVALUATE': EvaluateEvent,
    'FEEDBACK': FeedbackEvent,
    'INTERACT': InteractEvent,
}