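"""Parallel inference server.

Pulls batched inference requests from Redis, runs them through cached models,
writes output arrays to shared memory, and dispatches asynchronous
postprocessing tasks.
"""
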
import logging
import time
from dataclasses import asdict
from multiprocessing import shared_memory
from queue import Queue
from threading import Thread
from typing import Dict, List, Tuple

import numpy as np
import orjson
from redis import ConnectionPool, Redis

from inference.core.entities.requests.inference import (
InferenceRequest,
request_from_type,
)
from inference.core.env import MAX_ACTIVE_MODELS, MAX_BATCH_SIZE, REDIS_HOST, REDIS_PORT
from inference.core.managers.base import ModelManager
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.models.roboflow import RoboflowInferenceModel
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.enterprise.parallel.tasks import postprocess
from inference.enterprise.parallel.utils import (
SharedMemoryMetadata,
failure_handler,
shm_manager,
)
from inference.models.utils import ROBOFLOW_MODEL_TYPES

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger()

BATCH_SIZE = MAX_BATCH_SIZE
if BATCH_SIZE == float("inf"):
    BATCH_SIZE = 32

AGE_TRADEOFF_SECONDS_FACTOR = 30


class InferServer:
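    """Threaded inference server: polls Redis for queued requests, batches them
    per model, and pipelines batch loading, inference, and response writing."""
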
def __init__(self, redis: Redis) -> None:
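        """Wrap the model manager in a fixed-size cache and start the writer
        and inference threads."""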
self.redis = redis
model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
model_manager = ModelManager(model_registry)
self.model_manager = WithFixedSizeCache(
model_manager, max_size=MAX_ACTIVE_MODELS
)
self.running = True
self.response_queue = Queue()
self.write_thread = Thread(target=self.write_responses)
self.write_thread.start()
        # maxsize=1: stage at most one batch ahead of the inference thread
        self.batch_queue = Queue(maxsize=1)
self.infer_thread = Thread(target=self.infer)
        self.infer_thread.start()

def write_responses(self):
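        """Drain the response queue, writing each result's arrays to shared
        memory and launching its postprocessing task."""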
while True:
try:
response = self.response_queue.get()
write_infer_arrays_and_launch_postprocess(*response)
            except Exception as error:
                logger.warning(
                    "Encountered error while writing response:\n" + str(error)
                )

def infer_loop(self):
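        """Main loop: poll Redis for models with pending requests and stage the
        best available batch for the inference thread."""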
while self.running:
try:
model_names = get_requested_model_names(self.redis)
if not model_names:
time.sleep(0.001)
continue
self.get_batch(model_names)
except Exception as error:
logger.warning("Encountered error in infer loop:\n" + str(error))
                continue

def infer(self):
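        """Consume staged batches, run prediction, and hand each image's
        outputs to the writer thread."""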
while True:
model_id, images, batch, preproc_return_metadatas = self.batch_queue.get()
outputs = self.model_manager.predict(model_id, images)
for output, b, metadata in zip(
zip(*outputs), batch, preproc_return_metadatas
):
                self.response_queue.put_nowait((output, b["request"], metadata))

def get_batch(self, model_names):
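        """Pop the best available batch from Redis, rehydrate its requests and
        shared-memory metadata, and stage it on the batch queue."""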
start = time.perf_counter()
batch, model_id = get_batch(self.redis, model_names)
logger.info(f"Inferring: model<{model_id}> batch_size<{len(batch)}>")
with failure_handler(self.redis, *[b["request"]["id"] for b in batch]):
self.model_manager.add_model(model_id, batch[0]["request"]["api_key"])
model_type = self.model_manager.get_task_type(model_id)
for b in batch:
request = request_from_type(model_type, b["request"])
b["request"] = request
b["shm_metadata"] = SharedMemoryMetadata(**b["shm_metadata"])
metadata_processed = time.perf_counter()
            logger.info(
                f"Took {(metadata_processed - start):.3f} seconds to process metadata"
            )
with shm_manager(
*[b["shm_metadata"].shm_name for b in batch], unlink_on_success=True
) as shms:
images, preproc_return_metadatas = load_batch(batch, shms)
loaded = time.perf_counter()
                logger.info(
                    f"Took {(loaded - metadata_processed):.3f} seconds to load batch"
                )
                self.batch_queue.put(
                    (model_id, images, batch, preproc_return_metadatas)
                )


def get_requested_model_names(redis: Redis) -> List[str]:
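    """Return the ids of all models that currently have pending requests."""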
request_counts = redis.hgetall("requests")
model_names = [
model_name for model_name, count in request_counts.items() if int(count) > 0
]
    return model_names


def get_batch(redis: Redis, model_names: List[str]) -> Tuple[List[Dict], str]:
"""
Run a heuristic to select the best batch to infer on
redis[Redis]: redis client
model_names[List[str]]: list of models with nonzero number of requests
returns:
Tuple[List[Dict], str]
List[Dict] represents a batch of request dicts
str is the model id
"""
batch_sizes = [
RoboflowInferenceModel.model_metadata_from_memcache_endpoint(m)["batch_size"]
for m in model_names
]
batch_sizes = [b if not isinstance(b, str) else BATCH_SIZE for b in batch_sizes]
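    # Peek at up to batch_size of the lowest-scored requests per model without
    # removing them; scores are compared against wall-clock time by the
    # selection heuristic below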
batches = [
redis.zrange(f"infer:{m}", 0, b - 1, withscores=True)
for m, b in zip(model_names, batch_sizes)
]
model_index = select_best_inference_batch(batches, batch_sizes)
batch = batches[model_index]
selected_model = model_names[model_index]
    redis.zrem(f"infer:{selected_model}", *[b[0] for b in batch])
    redis.hincrby("requests", selected_model, -len(batch))
    batch = [orjson.loads(b[0]) for b in batch]
    return batch, selected_model


def select_best_inference_batch(batches, batch_sizes):
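    """Score each candidate batch by its requests' mean score relative to now,
    scaled by AGE_TRADEOFF_SECONDS_FACTOR, plus its fill fraction (batch length
    over the model's batch size), and return the index of the best model."""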
now = time.time()
average_ages = [np.mean([float(b[1]) - now for b in batch]) for batch in batches]
lengths = [
len(batch) / batch_size for batch, batch_size in zip(batches, batch_sizes)
]
fitnesses = [
age / AGE_TRADEOFF_SECONDS_FACTOR + length
for age, length in zip(average_ages, lengths)
]
model_index = fitnesses.index(max(fitnesses))
    return model_index


def load_batch(
batch: List[Dict[str, str]], shms: List[shared_memory.SharedMemory]
) -> Tuple[List[np.ndarray], List[Dict]]:
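    """Copy each request's image array out of shared memory and collect its
    preprocessing metadata."""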
images = []
preproc_return_metadatas = []
for b, shm in zip(batch, shms):
shm_metadata: SharedMemoryMetadata = b["shm_metadata"]
image = np.ndarray(
shm_metadata.array_shape, dtype=shm_metadata.array_dtype, buffer=shm.buf
).copy()
images.append(image)
preproc_return_metadatas.append(b["preprocess_metadata"])
    return images, preproc_return_metadatas


def write_infer_arrays_and_launch_postprocess(
arrs: Tuple[np.ndarray, ...],
request: InferenceRequest,
preproc_return_metadata: Dict,
):
"""Write inference results to shared memory and launch the postprocessing task"""
shms = [shared_memory.SharedMemory(create=True, size=arr.nbytes) for arr in arrs]
with shm_manager(*shms):
shm_metadatas = []
for arr, shm in zip(arrs, shms):
shared = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
shared[:] = arr[:]
shm_metadata = SharedMemoryMetadata(
shm_name=shm.name, array_shape=arr.shape, array_dtype=arr.dtype.name
)
shm_metadatas.append(asdict(shm_metadata))
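        # Dispatch the postprocessing task asynchronously, passing the shared
        # memory segments by name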
postprocess.s(
tuple(shm_metadatas), request.dict(), preproc_return_metadata
        ).delay()


if __name__ == "__main__":
pool = ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
redis = Redis(connection_pool=pool)
InferServer(redis).infer_loop()