from __future__ import annotations
from typing import List, Type, Protocol, TypeVar, Dict, Set
import os
import json
import uuid

from domain.domain_protocol import DomainProtocol

DomainT = TypeVar('DomainT', bound=DomainProtocol)
# Cache directory for persisted maps: <parent of this module's directory>/.bin/maps
MAP_BIN = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), ".bin", "maps")


class DomainDAO(Protocol[DomainT]):
    """Insert/read interface that every domain DAO implementation provides."""

    def insert(self, domain_objs: List[DomainT]):
        ...

    def read_by_id(self, domain_id: str) -> DomainT:
        ...

    def read_all(self) -> Set[DomainT]:
        ...


class InMemDomainDAO(DomainDAO[DomainT]):
    """Dictionary-backed DAO keyed by domain object id, with optional JSON-lines persistence."""

    _id_to_domain_obj: Dict[str, DomainT]

    def __init__(self):
        self._id_to_domain_obj = {}

    def insert(self, domain_objs: List[DomainT]):
        new_id_to_domain_obj = {domain_obj.id: domain_obj for domain_obj in domain_objs}
        if len(new_id_to_domain_obj) != len(domain_objs):
            raise ValueError("Duplicate IDs exist within incoming domain_objs")
        if duplicate_ids := set(new_id_to_domain_obj.keys()) & set(self._id_to_domain_obj.keys()):
            raise ValueError(f"Duplicate IDs exist in DB: {duplicate_ids}")
        self._id_to_domain_obj.update(new_id_to_domain_obj)

    def read_by_id(self, domain_id: str) -> DomainT:
        # Explicit None check so a falsy-but-valid domain object is still returned.
        if (domain_obj := self._id_to_domain_obj.get(domain_id)) is not None:
            return domain_obj
        raise ValueError(f"Domain obj with id {domain_id} not found")

    def read_all(self) -> Set[DomainT]:
        return set(self._id_to_domain_obj.values())

    @classmethod
    def load_from_file(cls, file_path: str, domain_cls: Type[DomainT]) -> InMemDomainDAO[DomainT]:
        if not os.path.isfile(file_path):
            raise ValueError(f"File not found: {file_path}")
        with open(file_path, 'r') as f:
            # One JSON document per line; skip blank lines such as a trailing newline.
            domain_objs = [domain_cls.from_json(line) for line in f if line.strip()]
        dao = cls()
        dao.insert(domain_objs)
        return dao

    def save_to_file(self, file_path: str):
        # Guard against a bare filename, where os.path.dirname() returns '' and makedirs would fail.
        if dir_path := os.path.dirname(file_path):
            os.makedirs(dir_path, exist_ok=True)
        domain_jsons = [domain_obj.to_json() for domain_obj in self._id_to_domain_obj.values()]
        with open(file_path, 'w') as f:
            f.write('\n'.join(domain_jsons) + '\n')
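
# --- Usage sketch (illustrative only) ----------------------------------------
# Assumes a hypothetical DomainProtocol implementation named ExampleDomain that
# exposes `id`, `to_json()`, and `from_json()`; it is not part of this module.
#
#     dao = InMemDomainDAO[ExampleDomain]()
#     dao.insert([ExampleDomain(id="a1"), ExampleDomain(id="b2")])
#     obj = dao.read_by_id("a1")
#     dao.save_to_file("/tmp/example_domain.jsonl")
#     reloaded = InMemDomainDAO.load_from_file("/tmp/example_domain.jsonl", ExampleDomain)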


class CacheDomainDAO(DomainDAO[DomainT]):
    """DAO that mirrors every write to a JSON cache file stored under MAP_BIN."""

    _id_to_domain_obj: Dict[str, DomainT]
    _save_path: str

    def __init__(self, save_path: str, domain_cls: Type[DomainT]):
        self._id_to_domain_obj = {}
        self._save_path = os.path.join(MAP_BIN, save_path)
        self._load_cache(domain_cls)

    def __enter__(self):
        return self

    def __call__(self, element: DomainT) -> DomainT:
        # Pass-through recorder: insert the element and hand it back to the caller.
        self.insert([element])
        return element

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Flush the cache to disk even if the with-block raised.
        self._save_cache()

    def set(self, element: DomainT) -> uuid.UUID:
        # Store under a freshly generated UUID rather than the element's own id.
        new_id = uuid.uuid4()
        self._id_to_domain_obj[str(new_id)] = element
        self._save_cache()
        return new_id

    def _save_cache(self):
        os.makedirs(MAP_BIN, exist_ok=True)
        # Merge with any existing on-disk entries so they are preserved rather than overwritten.
        cache = {}
        if os.path.isfile(self._save_path):
            with open(self._save_path, 'r') as f:
                cache = json.load(f)
        domain_json_map = {
            obj_id: domain_obj.to_json()
            for obj_id, domain_obj in self._id_to_domain_obj.items()
        }
        cache.update(domain_json_map)
        with open(self._save_path, 'w') as f:
            json.dump(cache, f, indent=4)

    def _load_cache(self, domain_cls: Type[DomainT]):
        # A missing cache file just means nothing has been persisted yet.
        if not os.path.isfile(self._save_path):
            return
        with open(self._save_path, 'r') as f:
            domain_json_map = json.load(f)
        for obj_id, domain_json in domain_json_map.items():
            self._id_to_domain_obj[obj_id] = domain_cls.from_json(domain_json)

    def read_by_id(self, domain_id: str) -> DomainT:
        # Explicit None check so a falsy-but-valid domain object is still returned.
        if (domain_obj := self._id_to_domain_obj.get(domain_id)) is not None:
            return domain_obj
        raise ValueError(f"Domain obj with id {domain_id} not found")

    def read_all(self) -> Set[DomainT]:
        return set(self._id_to_domain_obj.values())

    def insert(self, domain_objs: List[DomainT]):
        new_id_to_domain_obj = {domain_obj.id: domain_obj for domain_obj in domain_objs}
        if len(new_id_to_domain_obj) != len(domain_objs):
            raise ValueError("Duplicate IDs exist within incoming domain_objs")
        if duplicate_ids := set(new_id_to_domain_obj.keys()) & set(self._id_to_domain_obj.keys()):
            raise ValueError(f"Duplicate IDs exist in DB: {duplicate_ids}")
        self._id_to_domain_obj.update(new_id_to_domain_obj)
        # Unlike InMemDomainDAO, every insert is persisted immediately.
        self._save_cache()
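
# --- Usage sketch (illustrative only) ----------------------------------------
# Assumes the same hypothetical ExampleDomain as above; `save_path` is resolved
# relative to MAP_BIN.
#
#     with CacheDomainDAO("example_cache.json", ExampleDomain) as cache_dao:
#         recorded = cache_dao(ExampleDomain(id="c3"))  # insert and return via __call__
#         anon_id = cache_dao.set(recorded)             # store again under a fresh UUID
#     # on exit, entries are flushed to <MAP_BIN>/example_cache.json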