e-hossam96's picture
lower expectations for a 2 vCPU instance
39562f3
import gc
import uuid
import logging
import asyncio
import transformers
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import (
AutoTokenizer,
)
from optimum.pipelines import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification
# ----------------------------- #
# Configurations #
# ----------------------------- #
# Seed the transformers RNG helpers with a fixed value.
transformers.set_seed(42)
# Model identifier handed to `from_pretrained` (see load_classifier).
MODEL_NAME = "distilroberta-base-climate-sentiment-onnx-quantized"
# Seconds the queue worker sleeps between two batch-collection passes.
BATCH_PROCESS_INTERVAL = 0.05
# Upper bound on sentences gathered per pass; also the pipeline's batch_size.
MAX_BATCH_SIZE = 16
# ----------------------------- #
# Shared Storage #
# ----------------------------- #
# Guards every read/write of `results`.
# NOTE(review): asyncio primitives are created at import time here; on
# Python < 3.10 they may bind to a different event loop than the server's —
# confirm the minimum supported Python version.
lock = asyncio.Lock()
# Pending work items; each one is a dict with "id" and "sentence" keys.
query_queue: asyncio.Queue = asyncio.Queue()
# Finished predictions keyed by query id, removed by the request that owns them.
results: dict[str, dict] = {}
classifier = None  # will be initialized in lifespan
# Root logging configuration: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
# These messages are emitted at import time and record the active settings.
logger.info("Starting the application...")
logger.info(f"Using model: {MODEL_NAME}")
logger.info(f"Batch process interval: {BATCH_PROCESS_INTERVAL}")
logger.info(f"Max batch size: {MAX_BATCH_SIZE}")
# ----------------------------- #
# Model Initialization #
# ----------------------------- #
def load_classifier(model_name: str):
    """Build the ONNX Runtime text-classification pipeline for ``model_name``.

    The tokenizer and the ONNX sequence-classification model are both loaded
    via ``from_pretrained(model_name)``; a garbage-collection pass runs once
    they are in memory, and the pair is then wrapped in an Optimum
    ``pipeline`` configured with ``MAX_BATCH_SIZE`` as its batch size.

    Args:
        model_name: Identifier forwarded unchanged to both ``from_pretrained``
            calls.

    Returns:
        The pipeline object produced by ``optimum.pipelines.pipeline``.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)
    onnx_model = ORTModelForSequenceClassification.from_pretrained(model_name)
    gc.collect()

    pipeline_options = {
        "task": "text-classification",
        "accelerator": "ort",
        "model": onnx_model,
        "tokenizer": text_tokenizer,
        "framework": "pt",
        "batch_size": MAX_BATCH_SIZE,
        "num_workers": 1,
    }
    return pipeline(**pipeline_options)
# ----------------------------- #
# Pydantic Schema #
# ----------------------------- #
# Request body accepted by the POST /classify endpoint.
class Query(BaseModel):
    # Text to be classified.
    sentence: str
# ----------------------------- #
# Queue Processing Task #
# ----------------------------- #
async def process_queue():
    """Background worker: batch queued sentences, classify them, publish results.

    Runs until cancelled. On every pass it sleeps ``BATCH_PROCESS_INTERVAL``
    seconds, drains up to ``MAX_BATCH_SIZE`` items from ``query_queue``, runs
    the classifier once over the whole batch, and stores one entry per query
    id in ``results`` (under ``lock``) for the waiting request to pick up.

    Each published entry has the keys ``"sentence"``, ``"label"`` and
    ``"score"``. If inference raises, the error is logged and every query of
    that batch receives an entry with ``label``/``score`` set to ``None`` plus
    an ``"error"`` message, so waiting requests are released instead of
    polling forever and the worker keeps serving later batches.
    """
    while True:
        # Give concurrent requests a moment to pile up so they share one call.
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)

        batch = []
        # This task is the only consumer, so a non-empty queue guarantees
        # get_nowait() succeeds.
        while len(batch) < MAX_BATCH_SIZE and not query_queue.empty():
            batch.append(query_queue.get_nowait())
        if not batch:
            continue

        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]

        try:
            # Inference is a blocking, CPU-bound call; running it in a worker
            # thread keeps the event loop free to accept and answer requests.
            predictions = await asyncio.to_thread(classifier, sentences)
        except Exception:
            # CancelledError is not an Exception subclass, so shutdown still
            # propagates; anything else must not kill the worker.
            logger.exception(
                "Classification failed for a batch of %d queries", len(batch)
            )
            batch_results = {
                query_id: {
                    "sentence": sentence,
                    "label": None,
                    "score": None,
                    "error": "classification failed",
                }
                for query_id, sentence in zip(ids, sentences)
            }
        else:
            batch_results = {
                query_id: {
                    "sentence": sentence,
                    "label": pred["label"],
                    "score": pred["score"],
                }
                for query_id, pred, sentence in zip(ids, predictions, sentences)
            }

        async with lock:
            results.update(batch_results)
# ----------------------------- #
# Lifespan Handler #
# ----------------------------- #
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Load the model and run the queue worker for the application's lifetime.

    Startup: build the classifier, run one warm-up inference, and start
    ``process_queue`` as a background task.

    Shutdown: cancel the worker and wait for it to finish, then drop the
    classifier reference and run a garbage-collection pass. The shutdown
    steps sit in ``finally`` blocks so they also run when the application
    exits through an exception or cancellation raised at the ``yield``.
    """
    global classifier
    classifier = load_classifier(MODEL_NAME)
    classifier("Hi")  # warm-up inference; the result is intentionally discarded
    logger.info("Model loaded successfully.")
    queue_task = asyncio.create_task(process_queue())
    try:
        yield
    finally:
        logger.info("Shutting down the application...")
        queue_task.cancel()
        try:
            await queue_task
        except asyncio.CancelledError:
            # Expected outcome of cancelling the worker.
            pass
        finally:
            # Release the model even if the worker ended with another error,
            # and only report the unload once it has actually happened.
            classifier = None
            gc.collect()
            logger.info("Model unloaded successfully.")
# ----------------------------- #
# FastAPI Setup #
# ----------------------------- #
# ASGI application object; `lifespan` performs model loading at startup and
# worker/model cleanup at shutdown.
app = FastAPI(lifespan=lifespan)
# ----------------------------- #
# API Endpoints #
# ----------------------------- #
@app.post("/classify")
async def classify(query: Query):
    """Queue one sentence for classification and wait for its prediction.

    The sentence is logged, tagged with a fresh hexadecimal UUID and placed on
    ``query_queue``. The handler then checks ``results`` under ``lock`` every
    0.1 seconds until an entry for that id appears, removes it, and returns it.

    Returns:
        ``{"id": <query id>, "result": <entry stored by the queue worker>}``.
    """
    logger.info(f"{query.sentence}")
    request_id = uuid.uuid4().hex
    job = {"id": request_id, "sentence": query.sentence}
    await query_queue.put(job)

    # Unique marker so a missing entry is never confused with a stored value.
    not_ready = object()
    while True:
        async with lock:
            outcome = results.pop(request_id, not_ready)
        if outcome is not not_ready:
            return {"id": request_id, "result": outcome}
        await asyncio.sleep(0.1)
# Routes are matched in registration order and a mount on "/" matches every
# path, so the explicit "/" route has to be registered before the static
# mount; otherwise this handler can never be reached.
@app.get("/")
def read_root():
    """Serve the front-end entry page, ``static/index.html``, as HTML."""
    return FileResponse(path="static/index.html", media_type="text/html")


# Registered last so it only serves paths that no earlier route handled.
app.mount("/", StaticFiles(directory="static", html=True), name="static")