File size: 2,755 Bytes
947767a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import datetime
import json
import os
import uuid
import numpy as np
import cattrs
from attrs import define, field
import logging
import tempfile

from settings import Settings

logger = logging.getLogger(__name__)

@define
class Session:
    """A class to store the all information about a session, including the chat
    history and the current scene."""

    session_id: str
    start_time: str
    scene: str
    chat_history_for_display: list[tuple]
    chat_counter: int
    output_dir: str

    @classmethod
    def create(cls):
        return cls.create_for_scene(Settings.default_scene)

    @classmethod
    def create_for_scene(cls, scene: str):
        output_dir = tempfile.mkdtemp()  # Create a temporary directory
        session = cls(
            session_id=str(uuid.uuid4()),
            start_time=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            scene=scene,
            chat_history_for_display=[(None, Settings.INITIAL_MSG_FOR_DISPLAY)],
            chat_counter=0,
            output_dir=output_dir
        )
        logger.info(
            f"Creating a new session {session.session_id} with scene {session.scene} and output dir {session.output_dir}."
        )
        return session

    def convert_float32(self, obj):
        """Convert all np.float32 values in the given object to Python float."""
        if isinstance(obj, np.float32):
            return float(obj)

        if isinstance(obj, list):
            return [self.convert_float32(item) for item in obj]

        if isinstance(obj, tuple):
            return tuple(self.convert_float32(item) for item in obj)

        if isinstance(obj, dict):
            return {key: self.convert_float32(value) for key, value in obj.items()}

        return obj
    
    def get_session_output_dir(self):
        # Create the directory and any parent directories if they don't exist
        os.makedirs(os.path.join(self.output_dir, self.scene, self.session_id), exist_ok=True)
        return os.path.join(self.output_dir, self.scene, self.session_id)

    def save(self) -> None:
        """Save the session as a json file."""
        logger.info(f"Saving session {self.session_id} to disk.")

        structured_data = cattrs.unstructure(self)
        structured_data.pop("chat_history_for_display", None)
        # Convert all np.float32 to float
        converted_data = self.convert_float32(structured_data)

        with open(
            os.path.join(
                self.get_session_output_dir(), f"{self.session_id}.json"
            ),
            "w",
            encoding="utf-8",
        ) as file:
            json.dump(converted_data, file, indent=4)
        logger.info(f"Session {self.session_id} saved to {self.get_session_output_dir()}.")