CarlosMalaga's picture
Upload 201 files
2f044c1 verified
import logging
import os
from pathlib import Path
from typing import List, Union
import psutil
import torch
from relik.common.utils import is_package_available
from relik.inference.annotator import Relik
if not is_package_available("fastapi"):
raise ImportError(
"FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
)
from fastapi import FastAPI, HTTPException
if not is_package_available("ray"):
raise ImportError(
"Ray is not installed. Please install Ray with `pip install relik[serve]`."
)
from ray import serve
from relik.common.log import get_logger
from relik.inference.serve.backend.utils import (
RayParameterManager,
ServerParameterManager,
)
logger = get_logger(__name__, level=logging.INFO)
VERSION = {} # type: ignore
with open(
Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
exec(version_file.read(), VERSION)
# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()
app = FastAPI(
title="ReLiK",
version=VERSION["VERSION"],
description="ReLiK REST API",
)
@serve.deployment(
ray_actor_options={
"num_gpus": RAY_MANAGER.num_gpus
if (
SERVER_MANAGER.device == "cuda"
or SERVER_MANAGER.retriever_device == "cuda"
or SERVER_MANAGER.reader_device == "cuda"
)
else 0
},
autoscaling_config={
"min_replicas": RAY_MANAGER.min_replicas,
"max_replicas": RAY_MANAGER.max_replicas,
},
)
@serve.ingress(app)
class RelikServer:
def __init__(
self,
relik_pretrained: str | None = None,
device: str = "cpu",
retriever_device: str | None = None,
document_index_device: str | None = None,
reader_device: str | None = None,
precision: str | int | torch.dtype = 32,
retriever_precision: str | int | torch.dtype | None = None,
document_index_precision: str | int | torch.dtype | None = None,
reader_precision: str | int | torch.dtype | None = None,
annotation_type: str = "char",
retriever_batch_size: int = 32,
reader_batch_size: int = 32,
relik_config_override: dict | None = None,
**kwargs,
):
num_threads = os.getenv("TORCH_NUM_THREADS", psutil.cpu_count(logical=False))
torch.set_num_threads(num_threads)
logger.info(f"Torch is running on {num_threads} threads.")
# parameters
logger.info(f"RELIK_PRETRAINED: {relik_pretrained}")
self.relik_pretrained = relik_pretrained
if relik_config_override is None:
relik_config_override = {}
logger.info(f"RELIK_CONFIG_OVERRIDE: {relik_config_override}")
self.relik_config_override = relik_config_override
logger.info(f"DEVICE: {device}")
self.device = device
if retriever_device is not None:
logger.info(f"RETRIEVER_DEVICE: {retriever_device}")
self.retriever_device = retriever_device or device
if document_index_device is not None:
logger.info(f"INDEX_DEVICE: {document_index_device}")
self.document_index_device = document_index_device or retriever_device
if reader_device is not None:
logger.info(f"READER_DEVICE: {reader_device}")
self.reader_device = reader_device
logger.info(f"PRECISION: {precision}")
self.precision = precision
if retriever_precision is not None:
logger.info(f"RETRIEVER_PRECISION: {retriever_precision}")
self.retriever_precision = retriever_precision or precision
if document_index_precision is not None:
logger.info(f"INDEX_PRECISION: {document_index_precision}")
self.document_index_precision = document_index_precision or precision
if reader_precision is not None:
logger.info(f"READER_PRECISION: {reader_precision}")
self.reader_precision = reader_precision or precision
logger.info(f"ANNOTATION_TYPE: {annotation_type}")
self.annotation_type = annotation_type
self.relik = Relik.from_pretrained(
self.relik_pretrained,
device=self.device,
retriever_device=self.retriever_device,
document_index_device=self.document_index_device,
reader_device=self.reader_device,
precision=self.precision,
retriever_precision=self.retriever_precision,
document_index_precision=self.document_index_precision,
reader_precision=self.reader_precision,
**self.relik_config_override,
)
self.retriever_batch_size = retriever_batch_size
self.reader_batch_size = reader_batch_size
# @serve.batch()
async def handle_batch(self, text: List[str]) -> List:
return self.relik(
text,
annotation_type=self.annotation_type,
retriever_batch_size=self.retriever_batch_size,
reader_batch_size=self.reader_batch_size,
)
@app.post("/api/relik")
async def relik_endpoint(self, text: Union[str, List[str]]):
try:
# get predictions for the retriever
return await self.handle_batch(text)
except Exception as e:
# log the entire stack trace
logger.exception(e)
raise HTTPException(status_code=500, detail=f"Server Error: {e}")
server = RelikServer.bind(**vars(SERVER_MANAGER))