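"""FastAPI service that micro-batches incoming classification requests and
runs them through a quantized ONNX sentiment model via Optimum / ONNX Runtime."""
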
import asyncio
import gc
import logging
import uuid
from contextlib import asynccontextmanager

import transformers
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.pipelines import pipeline
from pydantic import BaseModel
from transformers import AutoTokenizer

# Fix the random seed for reproducible results.
transformers.set_seed(42)

MODEL_NAME = "distilroberta-base-climate-sentiment-onnx-quantized"
BATCH_PROCESS_INTERVAL = 0.05  # seconds between batch-collection passes
MAX_BATCH_SIZE = 16  # maximum number of sentences per inference batch
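
# Shared state: `lock` guards `results`; `query_queue` carries pending
# queries from the request handlers to the background batch worker.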
lock = asyncio.Lock()
query_queue: asyncio.Queue = asyncio.Queue()
results: dict[str, dict] = {}
classifier = None

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)
logger.info("Starting the application...")
logger.info(f"Using model: {MODEL_NAME}")
logger.info(f"Batch process interval: {BATCH_PROCESS_INTERVAL}")
logger.info(f"Max batch size: {MAX_BATCH_SIZE}")


def load_classifier(model_name: str):
    """Load the tokenizer and quantized ONNX model, wrapped in an Optimum pipeline."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = ORTModelForSequenceClassification.from_pretrained(model_name)

    # Reclaim transient allocations made while loading the model.
    gc.collect()
    return pipeline(
        task="text-classification",
        accelerator="ort",
        model=model,
        tokenizer=tokenizer,
        framework="pt",
        batch_size=MAX_BATCH_SIZE,
        num_workers=1,
    )


class Query(BaseModel):
    sentence: str
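

# Micro-batching: request handlers enqueue single sentences; this worker
# wakes every BATCH_PROCESS_INTERVAL seconds, drains up to MAX_BATCH_SIZE
# queued queries, runs one batched inference, and publishes the predictions
# to `results` keyed by query id.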
async def process_queue():
    while True:
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)

        # This is the only consumer, so checking empty() before get() cannot race.
        batch = []
        while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
            batch.append(await query_queue.get())

        if not batch:
            continue

        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]
        # Run the synchronous pipeline in a worker thread so inference does
        # not block the event loop; only one batch is in flight at a time.
        predictions = await asyncio.to_thread(classifier, sentences)

        async with lock:
            results.update(
                {
                    query_id: {
                        "sentence": sentence,
                        "label": pred["label"],
                        "score": pred["score"],
                    }
                    for query_id, pred, sentence in zip(ids, predictions, sentences)
                }
            )


@asynccontextmanager
async def lifespan(_: FastAPI):
    global classifier
    classifier = load_classifier(MODEL_NAME)
    # Warm-up call so the first real request does not pay one-time setup costs.
    _ = classifier("Hi")
    logger.info("Model loaded successfully.")
    queue_task = asyncio.create_task(process_queue())
    yield
    logger.info("Shutting down the application...")
    queue_task.cancel()
    try:
        await queue_task
    except asyncio.CancelledError:
        pass
    classifier = None
    gc.collect()
    logger.info("Model unloaded successfully.")


app = FastAPI(lifespan=lifespan)


@app.post("/classify")
async def classify(query: Query):
    logger.info(f"{query.sentence}")
    query_id = uuid.uuid4().hex
    await query_queue.put({"id": query_id, "sentence": query.sentence})

    # Poll until the batch worker publishes this query's prediction.
    while True:
        async with lock:
            if query_id in results:
                return {"id": query_id, "result": results.pop(query_id)}
        await asyncio.sleep(0.1)
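
# Example request (assuming the default uvicorn port 8000):
#   curl -X POST http://localhost:8000/classify \
#        -H "Content-Type: application/json" \
#        -d '{"sentence": "Renewable energy adoption is accelerating."}'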


# Register the root route before the catch-all static mount; a route added
# after the mount would be shadowed and never reached.
@app.get("/")
def read_root():
    return FileResponse(path="static/index.html", media_type="text/html")


app.mount("/", StaticFiles(directory="static", html=True), name="static")
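
# To run locally (assuming this module is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 8000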