Fucius's picture
Upload 422 files
df6c67d verified
import json
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import List, Union
from redis import Redis
# Integer status codes for task results. FAILURE_STATE is sent as the
# "status" field of the failure messages published on the redis "results"
# channel; SUCCESS_STATE presumably marks successful results elsewhere
# (its use is not visible here — confirm against callers).
SUCCESS_STATE, FAILURE_STATE = 1, -1
@contextmanager
def failure_handler(redis: Redis, *request_ids: str):
    """
    Publish a failure notice for every request if the managed block raises.

    When an ``Exception`` escapes the ``with`` body, one JSON message per id
    in ``request_ids`` is published on the redis ``"results"`` channel. Each
    message has the keys ``task_id`` (the request id), ``status``
    (``FAILURE_STATE``) and ``payload`` (``"<ExceptionClass>: <message>"``).
    The exception is then re-raised unchanged. Exceptions that do not derive
    from ``Exception`` (e.g. ``KeyboardInterrupt``) propagate without any
    message being published.
    """
    try:
        yield
    except Exception as exc:
        payload = f"{type(exc).__name__}: {exc}"
        for task_id in request_ids:
            notice = {"task_id": task_id, "status": FAILURE_STATE, "payload": payload}
            redis.publish("results", json.dumps(notice))
        # Re-raise so the caller still sees the original failure.
        raise
@contextmanager
def shm_manager(
    *shms: Union[str, shared_memory.SharedMemory], unlink_on_success: bool = False
):
    """
    Attach to shared memory segments and release them when the block exits.

    Each positional argument is either the name of an existing segment, which
    is opened with ``shared_memory.SharedMemory(name=...)``, or an already
    constructed ``SharedMemory`` object. The managed block receives the list
    of ``SharedMemory`` objects in argument order.

    If any name cannot be opened, every such failure is collected and a single
    ``Exception`` whose only argument is the list of those errors is raised
    before the block runs.

    On normal exit every segment is closed, and it is also unlinked when
    ``unlink_on_success`` is true. On any exception (a failed attach or an
    error raised inside the block), every segment that was obtained is closed
    and unlinked, and then the exception is re-raised.
    """
    loaded_shms = []
    # Collected across *all* arguments: resetting this per iteration would
    # forget a failed name whenever a later name attached successfully.
    errors = []
    try:
        for shm in shms:
            try:
                if isinstance(shm, str):
                    shm = shared_memory.SharedMemory(name=shm)
                loaded_shms.append(shm)
            except BaseException as error:
                errors.append(error)
        if errors:
            raise Exception(errors)
        yield loaded_shms
    except BaseException:
        # Best-effort cleanup. A failure releasing one segment (e.g. BufferError
        # from close() while views are still alive, or FileNotFoundError from
        # unlink()) must neither skip the remaining segments nor mask the
        # exception that brought us here, which is re-raised below.
        for shm in loaded_shms:
            for release in (shm.close, shm.unlink):
                try:
                    release()
                except Exception:
                    pass
        raise
    else:
        for shm in loaded_shms:
            shm.close()
            if unlink_on_success:
                shm.unlink()
@dataclass()
class SharedMemoryMetadata:
    """Describe where an array lives in shared memory and how to view it.

    Bundles the name of the shared memory segment with the shape and dtype
    needed to reinterpret the segment's bytes as an array.
    """

    # Name of the shared memory segment that holds the array data.
    shm_name: str
    # Dimensions of the array, one entry per axis.
    array_shape: List[int]
    # Element type as a string; presumably a NumPy dtype name such as
    # "float32" (TODO confirm against the producer).
    array_dtype: str