File size: 1,784 Bytes
df6c67d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import json
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import List, Union

from redis import Redis

# Status codes for task result messages. FAILURE_STATE is the "status" value
# that failure_handler publishes on the "results" channel. SUCCESS_STATE is not
# referenced in this part of the file; presumably it is the matching value for
# tasks that completed normally (confirm against the code that reports success).
SUCCESS_STATE = 1
FAILURE_STATE = -1


@contextmanager
def failure_handler(redis: Redis, *request_ids: str):
    """
    Report a failure for every request id when the wrapped block raises.

    If an ``Exception`` escapes the ``with`` body, one JSON message per id in
    ``request_ids`` is published on the ``"results"`` channel. Each message
    carries the id as ``task_id``, ``FAILURE_STATE`` as ``status`` and
    ``"<ExceptionType>: <message>"`` as ``payload``. The exception is then
    re-raised unchanged. Nothing is published when the body succeeds.
    """
    try:
        yield
    except Exception as exc:
        description = f"{type(exc).__name__}: {exc!s}"
        for task_id in request_ids:
            notice = json.dumps(
                {"task_id": task_id, "status": FAILURE_STATE, "payload": description}
            )
            redis.publish("results", notice)
        # Reporting is a side effect only; the caller still sees the error.
        raise


def _release_shms(shms, unlink):
    """Close (and optionally unlink) every segment, returning the errors hit.

    Each ``close()`` / ``unlink()`` call is attempted independently so that a
    failure on one segment (e.g. ``BufferError`` from ``close()`` or
    ``FileNotFoundError`` from ``unlink()``) cannot stop the remaining
    segments from being released.
    """
    cleanup_errors = []
    for shm in shms:
        operations = (shm.close, shm.unlink) if unlink else (shm.close,)
        for operation in operations:
            try:
                operation()
            except Exception as error:
                cleanup_errors.append(error)
    return cleanup_errors


@contextmanager
def shm_manager(
    *shms: Union[str, shared_memory.SharedMemory], unlink_on_success: bool = False
):
    """
    Context manager that closes and frees shared memory objects.

    Args:
        *shms: Shared memory segments to manage. A ``str`` is treated as the
            name of an existing segment and attached to; ``SharedMemory``
            objects are used as given.
        unlink_on_success: When true, segments are unlinked (not just closed)
            after the ``with`` body finishes without raising.

    Yields:
        The list of ``SharedMemory`` objects, in the order they were passed.

    Raises:
        Exception: If attaching to a named segment fails. ``args[0]`` is a
            one-element list holding the underlying error, which is also set
            as ``__cause__``. Segments loaded before the failure are closed
            and unlinked.
        Any exception raised by the ``with`` body: re-raised unchanged after
            every segment has been closed and unlinked. Errors hit during
            that cleanup are dropped so they cannot mask the original one.
        Any error from ``close()`` / ``unlink()`` on the success path: the
            first such error is raised, but only after cleanup has been
            attempted on every segment.
    """
    loaded_shms = []
    try:
        for shm in shms:
            if isinstance(shm, str):
                try:
                    shm = shared_memory.SharedMemory(name=shm)
                except Exception as error:
                    # Keep the historical wrapper type/args for callers, but
                    # let KeyboardInterrupt/SystemExit pass through untouched.
                    raise Exception([error]) from error
            loaded_shms.append(shm)

        yield loaded_shms
    except BaseException:
        # BaseException so that generator close / interrupts also free memory.
        _release_shms(loaded_shms, unlink=True)
        raise
    else:
        cleanup_errors = _release_shms(loaded_shms, unlink=unlink_on_success)
        if cleanup_errors:
            raise cleanup_errors[0]


@dataclass
class SharedMemoryMetadata:
    """
    Info needed to load array from shared memory.

    Plain data record: it only stores the three fields below and does not
    touch shared memory itself.
    """

    # Name of the shared memory segment; presumably the value passed to
    # ``SharedMemory(name=...)`` when attaching (confirm against callers).
    shm_name: str
    # Shape of the stored array as a list of ints (presumably one per axis).
    array_shape: List[int]
    # dtype of the stored array, kept as a string; presumably a NumPy dtype
    # name such as "float32" (TODO confirm the expected format).
    array_dtype: str